Skip to content

Commit

Permalink
[METAL] Fix issue with GPU fails (#7819)
Browse files Browse the repository at this point in the history
* [METAL] Fix issue with GPU fails

Added first run to auto scheduler. This run is necessary for checking
that the generated kernel is correct. When we just run time evaluator
with incorrect kernel then it is possible that our application on iOS
device will be added to ignore list because of big number of committed
incorrect kernels. One run before running auto scheduling helps us to
avoid this problem.

Added complete handlers to all command buffers in Metal runtime. It
helps to handle GPU errors and report about this error to the host
application.

In case when error happened, we have to create a new stream. Added
mechanism for error handling and streams creating from python interface.

* Try to fix QEMU build

* Apply comment

* Apply comments and fix build

* Apply comments and fix lint

* Fix CI
  • Loading branch information
echuraev authored Apr 16, 2021
1 parent 6aefc26 commit 899bc06
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 32 deletions.
46 changes: 43 additions & 3 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,49 @@ def max_thread_dimensions(self):
"""
return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8))

def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
def create_raw_stream(self):
"""Create a new runtime stream at the context.
User should free the stream after use.
Returns
-------
stream : TVMStreamHandle
The created runtime stream.
"""
stream = ctypes.c_void_p()
check_call(_LIB.TVMStreamCreate(self.device_type, self.device_id, ctypes.byref(stream)))
return stream

def free_raw_stream(self, stream):
"""Free a created stream handle.
Parameters
----------
stream : TVMStreamHandle
The stream which should to be released.
"""
check_call(_LIB.TVMStreamFree(self.device_type, self.device_id, stream))

def set_raw_stream(self, stream):
"""Set a created stream handle.
Parameters
----------
stream : TVMStreamHandle
The stream which should to be set to the device.
"""
check_call(_LIB.TVMSetStream(self.device_type, self.device_id, stream))

def sync(self, stream=None):
"""Synchronize until jobs finished at the context.
Parameters
----------
stream : TVMStreamHandle
Jobs in this stream should be finished.
"""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, stream))

def __eq__(self, other):
return (
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,8 @@ def _timed_rpc_run(

if error_no == 0:
try:
stream = dev.create_raw_stream()
dev.set_raw_stream(stream)
random_fill = remote.get_function("tvm.contrib.random.random_fill")
assert (
random_fill
Expand Down Expand Up @@ -1108,14 +1110,21 @@ def _timed_rpc_run(
"task_inputs not fully matched, check if there's any unexpected error"
)
dev.sync()

# First run for check that the kernel is correct
func.entry_func(*args)
dev.sync()

costs = time_f(*args).results

# clean up remote files
remote.remove(build_res.filename)
remote.remove(os.path.splitext(build_res.filename)[0] + ".so")
remote.remove("")
dev.free_raw_stream(stream)
# pylint: disable=broad-except
except Exception:
dev.free_raw_stream(stream)
costs = (MAX_FLOAT,)
error_no = MeasureErrorNo.RUNTIME_DEVICE
error_msg = make_traceback_info()
Expand Down
10 changes: 2 additions & 8 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,11 @@ void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, s

void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); }

TVMStreamHandle DeviceAPI::CreateStream(Device dev) {
LOG(FATAL) << "Device does not support stream api.";
return nullptr;
}
TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; }

void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {
LOG(FATAL) << "Device does not support stream api.";
}
void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}

void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}

//--------------------------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream
return 0;
}

int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
out = NULL;
return 0;
}

int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) { return 0; }

int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { return 0; }

int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { return 0; }

static TVMMutableFuncRegistry global_func_registry;
Expand Down
45 changes: 36 additions & 9 deletions src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,39 @@
namespace tvm {
namespace runtime {
namespace metal {
/*!
* \brief Structure for error handling in queues
*/
class Stream {
public:
explicit Stream(id<MTLDevice> device) : error_happened_(false) {
queue_ = [device newCommandQueue];
}
~Stream() { [queue_ release]; }
id<MTLCommandBuffer> GetCommandBuffer() {
id<MTLCommandBuffer> cb = [queue_ commandBuffer];
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus();
}];
return cb;
}
bool HasErrorHappened() { return error_happened_; }

private:
void SetErrorStatus() { error_happened_ = true; }
// Queue
id<MTLCommandQueue> queue_;
// Check if error happened in one previous run
bool error_happened_;
};

/*!
* \brief Process global Metal workspace.
*/
class MetalWorkspace final : public DeviceAPI {
public:
// the devices
std::vector<id<MTLDevice> > devices;
// the queues
std::vector<id<MTLCommandQueue> > queues;
// Warp size constant
std::vector<int> warp_size;
// Whether it is initialized.
Expand All @@ -62,13 +86,6 @@ class MetalWorkspace final : public DeviceAPI {
std::mutex mutex;
// Destructor
~MetalWorkspace();
// Get command queue for given device.
id<MTLCommandQueue> GetCommandQueue(Device dev) {
ICHECK_EQ(dev.device_type, kDLMetal);
ICHECK(dev.device_id >= 0 && static_cast<size_t>(dev.device_id) < queues.size())
<< "Invalid Metal device_id=" << dev.device_id;
return queues[dev.device_id];
}
// Get device for given device
id<MTLDevice> GetDevice(Device dev) {
ICHECK_EQ(dev.device_type, kDLMetal);
Expand All @@ -84,23 +101,33 @@ class MetalWorkspace final : public DeviceAPI {
void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final;
void FreeDataSpace(Device dev, void* ptr) final;
TVMStreamHandle CreateStream(Device dev) final;
void FreeStream(Device dev, TVMStreamHandle stream) final;
void StreamSync(Device dev, TVMStreamHandle stream) final;
void SetStream(Device dev, TVMStreamHandle stream) final;
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;

// get the global workspace
static MetalWorkspace* Global();

protected:
void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size,
Device dev_from, Device dev_to, DLDataType type_hint,
TVMStreamHandle stream) final;

private:
// Pointers to default allocated streams
std::vector<Stream*> default_streams_;
};

/*! \brief Thread local workspace */
class MetalThreadEntry {
public:
/*! \brief The current device */
Device device;
/*! \brief The current stream */
std::vector<Stream*> stream;
/*! \brief The shared buffer used for copy. */
std::vector<id<MTLBuffer> > temp_buffer_;
/*! \brief workspace pool */
Expand Down
50 changes: 40 additions & 10 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ int GetWarpSize(id<MTLDevice> dev) {
for (auto x : devices) {
[x release];
}
for (auto x : queues) {
[x release];
for (auto x : default_streams_) {
delete x;
}
}

Expand All @@ -136,13 +136,17 @@ int GetWarpSize(id<MTLDevice> dev) {
// on iPhone
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
devices.push_back(d);
queues.push_back([d newCommandQueue]);
Stream* stream = new Stream(d);
MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
default_streams_.push_back(stream);
#else
NSArray<id<MTLDevice> >* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back(d);
queues.push_back([d newCommandQueue]);
Stream* stream = new Stream(d);
MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
default_streams_.push_back(stream);
LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String];
warp_size.push_back(GetWarpSize(d));
}
Expand Down Expand Up @@ -183,16 +187,25 @@ int GetWarpSize(id<MTLDevice> dev) {
}
}

Stream* GetStream(TVMStreamHandle stream, int device_id) {
if (stream != nullptr)
return static_cast<Stream*>(stream);
else
return MetalThreadEntry::ThreadLocal()->stream[device_id];
}

void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, Device dev_from, Device dev_to,
DLDataType type_hint, TVMStreamHandle stream) {
@autoreleasepool {
this->Init();
ICHECK(stream == nullptr);
Device dev = dev_from;
Stream* s = GetStream(stream, dev.device_id);
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
}
if (dev_from.device_type == kDLCPU) dev = dev_to;
id<MTLCommandQueue> queue = GetCommandQueue(dev);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
int from_dev_type = static_cast<int>(dev_from.device_type);
int to_dev_type = static_cast<int>(dev_to.device_type);

Expand Down Expand Up @@ -249,17 +262,34 @@ int GetWarpSize(id<MTLDevice> dev) {
}
}

TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
Stream* stream = new Stream(devices[dev.device_id]);
return static_cast<TVMStreamHandle>(stream);
}

void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
ICHECK(stream != nullptr);
Stream* s = static_cast<Stream*>(stream);
delete s;
}

void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
@autoreleasepool {
ICHECK(stream == nullptr);
Stream* s = GetStream(stream, dev.device_id);
// commit an empty command buffer and wait until it completes.
id<MTLCommandQueue> queue = GetCommandQueue(dev);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
[cb commit];
[cb waitUntilCompleted];
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned!";
}
}
}

void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast<Stream*>(stream);
}

void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
Expand Down
5 changes: 3 additions & 2 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,16 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
@autoreleasepool {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->device.device_id;
auto stream = static_cast<metal::Stream*>(t->stream[device_id]);
if (stream->HasErrorHappened()) return;
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
id<MTLCommandQueue> queue = w_->GetCommandQueue(t->device);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = stream->GetCommandBuffer();
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
[encoder setComputePipelineState:scache_[device_id]];
for (size_t i = 0; i < num_buffer_args_; ++i) {
Expand Down
Loading

0 comments on commit 899bc06

Please sign in to comment.