Skip to content

Commit

Permalink
[Metal] Enable Debug Label (#17059)
Browse files Browse the repository at this point in the history
This PR adds label to MTLCommandBuffer, to enable instruments profiling.
  • Loading branch information
Hzfengsy authored Jun 7, 2024
1 parent 2f800df commit 1d761da
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,11 @@ class Stream {
public:
explicit Stream(id<MTLDevice> device) { queue_ = [device newCommandQueue]; }
~Stream() { [queue_ release]; }
id<MTLCommandBuffer> GetCommandBuffer(bool attach_error_callback = true) {
id<MTLCommandBuffer> GetCommandBuffer(std::string label = "", bool attach_error_callback = true) {
id<MTLCommandBuffer> cb = [queue_ commandBuffer];
if (!label.empty()) {
cb.label = [NSString stringWithUTF8String:label.c_str()];
}
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
if (buffer.status == MTLCommandBufferStatusError) {
ICHECK(buffer.error != nil);
Expand Down
6 changes: 4 additions & 2 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@
return;
case kL2CacheSizeBytes:
return;
case kAvailableGlobalMemory:
return;
case kTotalGlobalMemory: {
*rv = static_cast<int64_t>([devices[dev.device_id] recommendedMaxWorkingSetSize]);
return;
Expand Down Expand Up @@ -225,7 +227,7 @@ int GetWarpSize(id<MTLDevice> dev) {
if (s->HasErrorHappened()) {
LOG(FATAL) << "GPUError: " << s->ErrorDescription();
}
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
id<MTLCommandBuffer> cb = s->GetCommandBuffer(/*label=*/"TVMCopyDataFromTo");
int from_dev_type = static_cast<int>(dev_from.device_type);
int to_dev_type = static_cast<int>(dev_to.device_type);

Expand Down Expand Up @@ -298,7 +300,7 @@ int GetWarpSize(id<MTLDevice> dev) {
AUTORELEASEPOOL {
Stream* s = CastStreamOrGetDefault(stream, dev.device_id);
// commit an empty command buffer and wait until it completes.
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
id<MTLCommandBuffer> cb = s->GetCommandBuffer(/*label=*/"TVMStreamSync");
[cb commit];
[cb waitUntilCompleted];
if (s->HasErrorHappened()) {
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
// attach error message directly in this functio
id<MTLCommandBuffer> cb = stream->GetCommandBuffer(/* attach_error_callback= */ false);
id<MTLCommandBuffer> cb = stream->GetCommandBuffer(/*label=*/"TVMKernel:" + func_name_,
/*attach_error_callback=*/false);
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
[encoder setComputePipelineState:scache_[device_id]];
for (size_t i = 0; i < num_buffer_args_; ++i) {
Expand Down

0 comments on commit 1d761da

Please sign in to comment.