From 899bc064e1bf8df915bcadc979a6f37210cdce33 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Sat, 17 Apr 2021 01:48:08 +0300 Subject: [PATCH] [METAL] Fix issue with GPU fails (#7819) * [METAL] Fix issue with GPU fails Added first run to auto scheduler. This run is necessary for checking that the generated kernel is correct. When we just run time evaluator with incorrect kernel then it is possible that our application on iOS device will be added to ignore list because of big number of committed incorrect kernels. One run before running auto scheduling helps us to avoid this problem. Added complete handlers to all command buffers in Metal runtime. It helps to handle GPU errors and report about this error to the host application. In case when error happened, we have to create a new stream. Added mechanism for error handling and streams creating from python interface. * Try to fix QEMU build * Apply comment * Apply comments and fix build * Apply comments and fix lint * Fix CI --- python/tvm/_ffi/runtime_ctypes.py | 46 ++++++++++++++++-- python/tvm/auto_scheduler/measure.py | 9 ++++ src/runtime/c_runtime_api.cc | 10 +--- src/runtime/crt/common/crt_runtime_api.c | 9 ++++ src/runtime/metal/metal_common.h | 45 +++++++++++++---- src/runtime/metal/metal_device_api.mm | 50 +++++++++++++++---- src/runtime/metal/metal_module.mm | 5 +- src/runtime/minrpc/minrpc_server.h | 62 ++++++++++++++++++++++++ src/runtime/minrpc/rpc_reference.h | 9 ++++ src/runtime/rpc/rpc_device_api.cc | 15 ++++++ src/runtime/rpc/rpc_endpoint.cc | 39 +++++++++++++++ 11 files changed, 267 insertions(+), 32 deletions(-) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 59dc652aeb0b..49a86fc92d46 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -262,9 +262,49 @@ def max_thread_dimensions(self): """ return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8)) - def sync(self): - """Synchronize until jobs finished at the context.""" - check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None)) + def create_raw_stream(self): + """Create a new runtime stream at the context. + + User should free the stream after use. + + Returns + ------- + stream : TVMStreamHandle + The created runtime stream. + """ + stream = ctypes.c_void_p() + check_call(_LIB.TVMStreamCreate(self.device_type, self.device_id, ctypes.byref(stream))) + return stream + + def free_raw_stream(self, stream): + """Free a created stream handle. + + Parameters + ---------- + stream : TVMStreamHandle + The stream which should to be released. + """ + check_call(_LIB.TVMStreamFree(self.device_type, self.device_id, stream)) + + def set_raw_stream(self, stream): + """Set a created stream handle. + + Parameters + ---------- + stream : TVMStreamHandle + The stream which should to be set to the device. + """ + check_call(_LIB.TVMSetStream(self.device_type, self.device_id, stream)) + + def sync(self, stream=None): + """Synchronize until jobs finished at the context. + + Parameters + ---------- + stream : TVMStreamHandle + Jobs in this stream should be finished. + """ + check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, stream)) def __eq__(self, other): return ( diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 83f1bcec7ebc..84dff157aa50 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -1076,6 +1076,8 @@ def _timed_rpc_run( if error_no == 0: try: + stream = dev.create_raw_stream() + dev.set_raw_stream(stream) random_fill = remote.get_function("tvm.contrib.random.random_fill") assert ( random_fill @@ -1108,14 +1110,21 @@ def _timed_rpc_run( "task_inputs not fully matched, check if there's any unexpected error" ) dev.sync() + + # First run for check that the kernel is correct + func.entry_func(*args) + dev.sync() + costs = time_f(*args).results # clean up remote files remote.remove(build_res.filename) remote.remove(os.path.splitext(build_res.filename)[0] + ".so") remote.remove("") + dev.free_raw_stream(stream) # pylint: disable=broad-except except Exception: + dev.free_raw_stream(stream) costs = (MAX_FLOAT,) error_no = MeasureErrorNo.RUNTIME_DEVICE error_msg = make_traceback_info() diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index b9e8c2549fd5..d042cb406089 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -190,17 +190,11 @@ void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, s void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } -TVMStreamHandle DeviceAPI::CreateStream(Device dev) { - LOG(FATAL) << "Device does not support stream api."; - return nullptr; -} +TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } -void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) { - LOG(FATAL) << "Device does not support stream api."; -} +void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { - LOG(FATAL) << "Device does not support stream api."; } //-------------------------------------------------------- diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index c8044b49a8d0..172972cc57b9 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -129,6 +129,15 @@ int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream return 0; } +int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) { + out = NULL; + return 0; +} + +int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) { return 0; } + +int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { return 0; } + int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { return 0; } static TVMMutableFuncRegistry global_func_registry; diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 55f9022a6b96..9ebe04efbe4c 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -45,6 +45,32 @@ namespace tvm { namespace runtime { namespace metal { +/*! + * \brief Structure for error handling in queues + */ +class Stream { + public: + explicit Stream(id device) : error_happened_(false) { + queue_ = [device newCommandQueue]; + } + ~Stream() { [queue_ release]; } + id GetCommandBuffer() { + id cb = [queue_ commandBuffer]; + [cb addCompletedHandler:^(id buffer) { + if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus(); + }]; + return cb; + } + bool HasErrorHappened() { return error_happened_; } + + private: + void SetErrorStatus() { error_happened_ = true; } + // Queue + id queue_; + // Check if error happened in one previous run + bool error_happened_; +}; + /*! * \brief Process global Metal workspace. */ @@ -52,8 +78,6 @@ class MetalWorkspace final : public DeviceAPI { public: // the devices std::vector > devices; - // the queues - std::vector > queues; // Warp size constant std::vector warp_size; // Whether it is initialized. @@ -62,13 +86,6 @@ class MetalWorkspace final : public DeviceAPI { std::mutex mutex; // Destructor ~MetalWorkspace(); - // Get command queue for given device. - id GetCommandQueue(Device dev) { - ICHECK_EQ(dev.device_type, kDLMetal); - ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < queues.size()) - << "Invalid Metal device_id=" << dev.device_id; - return queues[dev.device_id]; - } // Get device for given device id GetDevice(Device dev) { ICHECK_EQ(dev.device_type, kDLMetal); @@ -84,9 +101,13 @@ class MetalWorkspace final : public DeviceAPI { void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(Device dev, void* ptr) final; + TVMStreamHandle CreateStream(Device dev) final; + void FreeStream(Device dev, TVMStreamHandle stream) final; void StreamSync(Device dev, TVMStreamHandle stream) final; + void SetStream(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + // get the global workspace static MetalWorkspace* Global(); @@ -94,6 +115,10 @@ class MetalWorkspace final : public DeviceAPI { void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) final; + + private: + // Pointers to default allocated streams + std::vector default_streams_; }; /*! \brief Thread local workspace */ @@ -101,6 +126,8 @@ class MetalThreadEntry { public: /*! \brief The current device */ Device device; + /*! \brief The current stream */ + std::vector stream; /*! \brief The shared buffer used for copy. */ std::vector > temp_buffer_; /*! \brief workspace pool */ diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index cf8520864e99..85b427509133 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -121,8 +121,8 @@ int GetWarpSize(id dev) { for (auto x : devices) { [x release]; } - for (auto x : queues) { - [x release]; + for (auto x : default_streams_) { + delete x; } } @@ -136,13 +136,17 @@ int GetWarpSize(id dev) { // on iPhone id d = MTLCreateSystemDefaultDevice(); devices.push_back(d); - queues.push_back([d newCommandQueue]); + Stream* stream = new Stream(d); + MetalThreadEntry::ThreadLocal()->stream.push_back(stream); + default_streams_.push_back(stream); #else NSArray >* devs = MTLCopyAllDevices(); for (size_t i = 0; i < devs.count; ++i) { id d = [devs objectAtIndex:i]; devices.push_back(d); - queues.push_back([d newCommandQueue]); + Stream* stream = new Stream(d); + MetalThreadEntry::ThreadLocal()->stream.push_back(stream); + default_streams_.push_back(stream); LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; warp_size.push_back(GetWarpSize(d)); } @@ -183,16 +187,25 @@ int GetWarpSize(id dev) { } } +Stream* GetStream(TVMStreamHandle stream, int device_id) { + if (stream != nullptr) + return static_cast(stream); + else + return MetalThreadEntry::ThreadLocal()->stream[device_id]; +} + void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { @autoreleasepool { this->Init(); - ICHECK(stream == nullptr); Device dev = dev_from; + Stream* s = GetStream(stream, dev.device_id); + if (s->HasErrorHappened()) { + LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream"; + } if (dev_from.device_type == kDLCPU) dev = dev_to; - id queue = GetCommandQueue(dev); - id cb = [queue commandBuffer]; + id cb = s->GetCommandBuffer(); int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); @@ -249,17 +262,34 @@ int GetWarpSize(id dev) { } } +TVMStreamHandle MetalWorkspace::CreateStream(Device dev) { + Stream* stream = new Stream(devices[dev.device_id]); + return static_cast(stream); +} + +void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) { + ICHECK(stream != nullptr); + Stream* s = static_cast(stream); + delete s; +} + void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { @autoreleasepool { - ICHECK(stream == nullptr); + Stream* s = GetStream(stream, dev.device_id); // commit an empty command buffer and wait until it completes. - id queue = GetCommandQueue(dev); - id cb = [queue commandBuffer]; + id cb = s->GetCommandBuffer(); [cb commit]; [cb waitUntilCompleted]; + if (s->HasErrorHappened()) { + LOG(FATAL) << "Error! Some problems on GPU happaned!"; + } } } +void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { + MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast(stream); +} + void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size); } diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index a8b01815bf68..e22caa21a81e 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -185,6 +185,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons @autoreleasepool { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->device.device_id; + auto stream = static_cast(t->stream[device_id]); + if (stream->HasErrorHappened()) return; if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); } @@ -192,8 +194,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2); auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); - id queue = w_->GetCommandQueue(t->device); - id cb = [queue commandBuffer]; + id cb = stream->GetCommandBuffer(); id encoder = [cb computeCommandEncoder]; [encoder setComputePipelineState:scache_[device_id]]; for (size_t i = 0; i < num_buffer_args_; ++i) { diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 732e1e49d4a4..1dfee70a20e2 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -297,10 +297,22 @@ class MinRPCServer { this->SyscallDevFreeData(values, tcodes, num_args); break; } + case RPCCode::kDevCreateStream: { + this->SyscallDevCreateStream(values, tcodes, num_args); + break; + } + case RPCCode::kDevFreeStream: { + this->SyscallDevFreeStream(values, tcodes, num_args); + break; + } case RPCCode::kDevStreamSync: { this->SyscallDevStreamSync(values, tcodes, num_args); break; } + case RPCCode::kDevSetStream: { + this->SyscallDevSetStream(values, tcodes, num_args); + break; + } case RPCCode::kCopyAmongRemote: { this->SyscallCopyAmongRemote(values, tcodes, num_args); break; @@ -444,6 +456,39 @@ class MinRPCServer { } } + void SyscallDevCreateStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 1); + MINRPC_CHECK(tcodes[0] == kDLDevice); + + DLDevice dev = values[0].v_device; + void* handle; + + int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevFreeStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kDLDevice); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + DLDevice dev = values[0].v_device; + void* handle = values[1].v_handle; + + int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) { MINRPC_CHECK(num_args == 2); MINRPC_CHECK(tcodes[0] == kDLDevice); @@ -461,6 +506,23 @@ class MinRPCServer { } } + void SyscallDevSetStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kDLDevice); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + DLDevice dev = values[0].v_device; + void* handle = values[1].v_handle; + + int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { io_->Exit(static_cast(code)); } diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index e42508a73959..ace3e2bbb1b8 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -52,6 +52,9 @@ enum class RPCCode : int { kDevStreamSync, kCopyAmongRemote, kDevAllocDataWithScope, + kDevCreateStream, + kDevFreeStream, + kDevSetStream, }; /*! @@ -104,8 +107,14 @@ inline const char* RPCCodeToString(RPCCode code) { return "kDevAllocData"; case RPCCode::kDevFreeData: return "kDevFreeData"; + case RPCCode::kDevCreateStream: + return "kDevCreateStream"; + case RPCCode::kDevFreeStream: + return "kDevFreeStream"; case RPCCode::kDevStreamSync: return "kDevStreamSync"; + case RPCCode::kDevSetStream: + return "kDevSetStream"; case RPCCode::kCopyAmongRemote: return "kCopyAmongRemote"; case RPCCode::kDevAllocDataWithScope: diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 1d6fb85d9495..a2d1ac17ef7f 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -111,11 +111,26 @@ class RPCDeviceAPI final : public DeviceAPI { } } + TVMStreamHandle CreateStream(Device dev) { + auto remote_dev = RemoveRPCSessionMask(dev); + return GetSess(dev)->GetDeviceAPI(remote_dev)->CreateStream(remote_dev); + } + + void FreeStream(Device dev, TVMStreamHandle stream) { + auto remote_dev = RemoveRPCSessionMask(dev); + GetSess(dev)->GetDeviceAPI(remote_dev)->FreeStream(remote_dev, stream); + } + void StreamSync(Device dev, TVMStreamHandle stream) final { auto remote_dev = RemoveRPCSessionMask(dev); GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream); } + void SetStream(Device dev, TVMStreamHandle stream) { + auto remote_dev = RemoveRPCSessionMask(dev); + GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream); + } + protected: void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint, diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 236be8e56a70..28f93f641b4b 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -921,6 +921,24 @@ void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { handler->GetDeviceAPI(dev)->CopyDataFromTo(from, to, stream); } +void RPCDevCreateStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + void* data = handler->GetDeviceAPI(dev)->CreateStream(dev); + *rv = data; +} + +void RPCDevFreeStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + TVMStreamHandle stream = args[1]; + handler->GetDeviceAPI(dev)->FreeStream(dev, stream); +} + +void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + TVMStreamHandle stream = args[1]; + handler->GetDeviceAPI(dev)->SetStream(dev, stream); +} + void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { // Event handler sit at clean state at this point. switch (code) { @@ -946,9 +964,18 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break; + case RPCCode::kDevCreateStream: + SysCallHandler(RPCDevCreateStream); + break; + case RPCCode::kDevFreeStream: + SysCallHandler(RPCDevFreeStream); + break; case RPCCode::kDevStreamSync: this->HandleSyscallStreamSync(); break; + case RPCCode::kDevSetStream: + SysCallHandler(RPCDevSetStream); + break; case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break; @@ -1034,10 +1061,22 @@ class RPCClientSession : public RPCSession, public DeviceAPI { endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, from, to, stream); } + TVMStreamHandle CreateStream(Device dev) final { + return endpoint_->SysCallRemote(RPCCode::kDevCreateStream, dev); + } + + void FreeStream(Device dev, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevFreeStream, dev, stream); + } + void StreamSync(Device dev, TVMStreamHandle stream) final { endpoint_->SysCallRemote(RPCCode::kDevStreamSync, dev, stream); } + void SetStream(Device dev, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream); + } + DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final { return this; } bool IsLocalSession() const final { return false; }