From 1d761dac458f4083284732b23da7e1155bd0d6bf Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 7 Jun 2024 21:03:49 +0800 Subject: [PATCH] [Metal] Enable Debug Label (#17059) This PR adds a label to MTLCommandBuffer to enable Instruments profiling. --- src/runtime/metal/metal_common.h | 5 ++++- src/runtime/metal/metal_device_api.mm | 6 ++++-- src/runtime/metal/metal_module.mm | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index e5339e636612..d68dd0b2cd3b 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -109,8 +109,11 @@ class Stream { public: explicit Stream(id device) { queue_ = [device newCommandQueue]; } ~Stream() { [queue_ release]; } - id GetCommandBuffer(bool attach_error_callback = true) { + id GetCommandBuffer(std::string label = "", bool attach_error_callback = true) { id cb = [queue_ commandBuffer]; + if (!label.empty()) { + cb.label = [NSString stringWithUTF8String:label.c_str()]; + } [cb addCompletedHandler:^(id buffer) { if (buffer.status == MTLCommandBufferStatusError) { ICHECK(buffer.error != nil); diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 42dd249630ff..f2e8c4ab0b75 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -89,6 +89,8 @@ return; case kL2CacheSizeBytes: return; + case kAvailableGlobalMemory: + return; case kTotalGlobalMemory: { *rv = static_cast([devices[dev.device_id] recommendedMaxWorkingSetSize]); return; @@ -225,7 +227,7 @@ int GetWarpSize(id dev) { if (s->HasErrorHappened()) { LOG(FATAL) << "GPUError: " << s->ErrorDescription(); } - id cb = s->GetCommandBuffer(); + id cb = s->GetCommandBuffer(/*label=*/"TVMCopyDataFromTo"); int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); @@ -298,7 +300,7 @@ int GetWarpSize(id dev) { AUTORELEASEPOOL { Stream* s = 
CastStreamOrGetDefault(stream, dev.device_id); // commit an empty command buffer and wait until it completes. - id cb = s->GetCommandBuffer(); + id cb = s->GetCommandBuffer(/*label=*/"TVMStreamSync"); [cb commit]; [cb waitUntilCompleted]; if (s->HasErrorHappened()) { diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 16956ed6118b..b33827423180 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -206,7 +206,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); // attach error message directly in this functio - id cb = stream->GetCommandBuffer(/* attach_error_callback= */ false); + id cb = stream->GetCommandBuffer(/*label=*/"TVMKernel:" + func_name_, + /*attach_error_callback=*/false); id encoder = [cb computeCommandEncoder]; [encoder setComputePipelineState:scache_[device_id]]; for (size_t i = 0; i < num_buffer_args_; ++i) {