Skip to content

Commit

Permalink
Don't specialize the executable for the current device (#3)
Browse files Browse the repository at this point in the history
Co-authored-by: Denis Vieriu <104024078+DenisVieriu97@users.noreply.github.com>
  • Loading branch information
georgepaw and DenisVieriu97 committed Dec 12, 2023
1 parent a645df9 commit 021f6ea
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 18 deletions.
5 changes: 1 addition & 4 deletions backends/apple/mps/runtime/MPSCompiler.mm
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ void printLoadedGraph(MPSGraphExecutable* executable) {
size_t num_bytes) {
ExirMPSGraphPackage* exirMPSGraphPackage = (ExirMPSGraphPackage*)buffer_pointer;
NSData *new_manifest_plist_data = [NSData dataWithBytes:exirMPSGraphPackage->data length:exirMPSGraphPackage->model_0_offset];
NSData *new_model_0_data = [NSData dataWithBytes:exirMPSGraphPackage->data + exirMPSGraphPackage->model_0_offset length:exirMPSGraphPackage->model_1_offset - exirMPSGraphPackage->model_0_offset];
NSData *new_model_1_data = [NSData dataWithBytes:exirMPSGraphPackage->data + exirMPSGraphPackage->model_1_offset length:exirMPSGraphPackage->total_bytes - sizeof(ExirMPSGraphPackage) - exirMPSGraphPackage->model_1_offset];
NSData *new_model_0_data = [NSData dataWithBytes:exirMPSGraphPackage->data + exirMPSGraphPackage->model_0_offset length:exirMPSGraphPackage->total_bytes - sizeof(ExirMPSGraphPackage) - exirMPSGraphPackage->model_0_offset];

NSError* error = nil;
NSString* packageName = [NSString stringWithUTF8String:(
Expand All @@ -52,14 +51,12 @@ void printLoadedGraph(MPSGraphExecutable* executable) {

NSString* manifestFileStr = [NSString stringWithFormat:@"%@/manifest.plist", dataFileNSStr];
NSString* model0FileStr = [NSString stringWithFormat:@"%@/model_0.mpsgraph", dataFileNSStr];
NSString* model1FileStr = [NSString stringWithFormat:@"%@/model_1.mpsgraph", dataFileNSStr];

NSFileManager *fileManager= [NSFileManager defaultManager];
[fileManager createDirectoryAtPath:dataFileNSStr withIntermediateDirectories:NO attributes:nil error:&error];

[new_manifest_plist_data writeToFile:manifestFileStr options:NSDataWritingAtomic error:&error];
[new_model_0_data writeToFile:model0FileStr options:NSDataWritingAtomic error:&error];
[new_model_1_data writeToFile:model1FileStr options:NSDataWritingAtomic error:&error];

NSURL *bundleURL = [NSURL fileURLWithPath:dataFileNSStr];
MPSGraphCompilationDescriptor *compilationDescriptor = [MPSGraphCompilationDescriptor new];
Expand Down
3 changes: 0 additions & 3 deletions backends/apple/mps/utils/MPSGraphInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,6 @@ class MPSGraphModule {
std::vector<MPSGraphTensor*> outputTensors_;
std::vector<MPSGraphTensor*> inputTensors_;
MPSGraphExecutable* executable_;

id<MTLDevice> device_;
id<MTLCommandQueue> commandQueue_;
};

} // namespace mps
12 changes: 2 additions & 10 deletions backends/apple/mps/utils/MPSGraphInterface.mm
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
"MPS Executorch backend is supported only from macOS 14.0 and above.");

mpsGraph = [MPSGraph new];
device_ = MTLCreateSystemDefaultDevice();
commandQueue_ = [device_ newCommandQueue];
}

MPSGraphModule::~MPSGraphModule() {
Expand Down Expand Up @@ -95,7 +93,7 @@
[targetTensors addObject:outputTensor];
});

MPSGraphExecutable *exec = [mpsGraph compileWithDevice:[MPSGraphDevice deviceWithMTLDevice:device_]
MPSGraphExecutable *exec = [mpsGraph compileWithDevice:nil
feeds:feeds
targetTensors:targetTensors
targetOperations:nil
Expand All @@ -111,7 +109,6 @@

std::string name = "mpsgraphmodule_" + std::to_string(arc4random_uniform(INT_MAX));
std::string mpsgraphpackagePath = dataFolder + name + ".mpsgraphpackage";

NSString *mpsgraphpackageFileStr = [NSString stringWithUTF8String:mpsgraphpackagePath.c_str()];
NSURL *bundleURL = [NSURL fileURLWithPath:mpsgraphpackageFileStr];

Expand All @@ -122,28 +119,23 @@

NSString* mpsgraphpackage_manifest_file = [NSString stringWithUTF8String:(mpsgraphpackagePath + "/manifest.plist").c_str()];
NSString* mpsgraphpackage_model_0_file = [NSString stringWithUTF8String:(mpsgraphpackagePath + "/model_0.mpsgraph").c_str()];
NSString* mpsgraphpackage_model_1_file = [NSString stringWithUTF8String:(mpsgraphpackagePath + "/model_1.mpsgraph").c_str()];

NSURL* manifestPlistURL = [NSURL fileURLWithPath:mpsgraphpackage_manifest_file];
NSURL* model0URL = [NSURL fileURLWithPath:mpsgraphpackage_model_0_file];
NSURL* model1URL = [NSURL fileURLWithPath:mpsgraphpackage_model_1_file];

NSData* manifest_plist_data = [NSData dataWithContentsOfURL:manifestPlistURL];
NSData* model_0_data = [NSData dataWithContentsOfURL:model0URL];
NSData* model_1_data = [NSData dataWithContentsOfURL:model1URL];

int64_t total_package_size = sizeof(ExirMPSGraphPackage) + [manifest_plist_data length] + [model_0_data length] + [model_1_data length];
int64_t total_package_size = sizeof(ExirMPSGraphPackage) + [manifest_plist_data length] + [model_0_data length];
ExirMPSGraphPackage *exirMPSGraphPackage = (ExirMPSGraphPackage*)malloc(total_package_size);
assert(exirMPSGraphPackage != nil);

exirMPSGraphPackage->manifest_plist_offset = 0;
exirMPSGraphPackage->model_0_offset = [manifest_plist_data length];
exirMPSGraphPackage->model_1_offset = exirMPSGraphPackage->model_0_offset + [model_0_data length];
exirMPSGraphPackage->total_bytes = total_package_size;

memcpy(exirMPSGraphPackage->data, [manifest_plist_data bytes], [manifest_plist_data length]);
memcpy(exirMPSGraphPackage->data + exirMPSGraphPackage->model_0_offset, [model_0_data bytes], [model_0_data length]);
memcpy(exirMPSGraphPackage->data + exirMPSGraphPackage->model_1_offset, [model_1_data bytes], [model_1_data length]);

std::vector<uint8_t> data((uint8_t*)exirMPSGraphPackage, (uint8_t*)exirMPSGraphPackage + total_package_size);
free(exirMPSGraphPackage);
Expand Down
1 change: 0 additions & 1 deletion backends/apple/mps/utils/MPSGraphPackageExport.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
struct ExirMPSGraphPackage {
int64_t manifest_plist_offset;
int64_t model_0_offset;
int64_t model_1_offset;
int64_t total_bytes;
uint8_t data[];
};

0 comments on commit 021f6ea

Please sign in to comment.