Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Apples' Metal Performance Shaders (MPS) in pytorch #2037

Merged
merged 3 commits into from
Sep 28, 2022

Conversation

demq
Copy link
Contributor

@demq demq commented Sep 27, 2022

Description

Pytorch 1.12 supports Apple’s Metal Performance Shaders (MPS) for accelerated training/inference. A simple test on HuggingFace models shows 2x-3x improvements on the inference latency over the CPU cores on an Apple M1Pro with 8 CPU and 14 GPU cores:

--------------------------------------------------------------------------------
Model:              roberta-hf
Device:             cpu()
Samples:            34
Model Input Length: 384
Metrics:
	Total: 140.206 ± 5.880      msec
	Inference: 138.794 ± 5.713      msec
	Preprocess: 0.441 ± 0.553      msec
	Postprocess: 0.824 ± 0.567      msec
--------------------------------------------------------------------------------
Model:              roberta-hf
Device:             mps(-1)
Samples:            34
Model Input Length: 384
Metrics:
	Total: 65.500 ± 9.166      msec
	Inference: 41.235 ± 6.198      msec
	Preprocess: 1.059 ± 0.338      msec
	Postprocess: 23.206 ± 6.794      msec

--------------------------------------------------------------------------------

This improvement has been requested by me in the DJL help slack channel, as well as more recently in #2018 .

Brief description of what this PR is about

  • Adds a new "MPS" device type and maps to the corresponding pytorch device type.

Some caveats:

  • The "MPS" mode only works if torch::jit::load(path, device, map); is called with device= torch::nullopt as in torch the model gets deserialized in 'legacy' mode, which only support CPU and GPU devices. The model get desiralized on the CPU and then can be converted to MPS. For this, in traced torch model's directory in the file "serving.properties" the option.mapLocation=false should be set, which triggers this implementation already present in DJL native module.
    Alternatively, the the following check can be added the pytorch-native module:
    if (jmap_location) {
    module = torch::jit::load(path, device, map);
    module.eval();
    } else {
if (jmap_location && (device.is_cpu() || device.is_cuda())) {

@@ -36,6 +36,7 @@ public final class Device {

private static final Device CPU = new Device(Type.CPU, -1);
private static final Device GPU = Device.of(Type.GPU, 0);
private static final Device MPS = Device.of(Type.MPS, -1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MPS is PyTorch specific, for the time being, it's better keep it PyTorch only. We can add to Device class when it become standard.

Copy link
Contributor Author

@demq demq Sep 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not too sure how to avoid this part "cleanly" without having to rewrite a bunch of code in pytorche's JniUtils.java, as most of the functions there use a Device object to infer the PtDeviceType:

public static PtNDArray createZerosNdArray(
PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) {
int layoutVal = layoutMapper(fmt, device);
return new PtNDArray(
manager,
PyTorchLibrary.LIB.torchZeros(
shape.getShape(),
dType.ordinal(),
layoutVal,
new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
false));
}

One way is to change the PtDeviceType:: toDeviceType(Device device) to somehow return the code for MPS device on osx-aarm64 systems, but I don't know how to do that:

public static int toDeviceType(Device device) {
String deviceType = device.getDeviceType();
if (Device.Type.CPU.equals(deviceType)) {
return 0;
} else if (Device.Type.GPU.equals(deviceType)) {
return 1;
} else {
throw new IllegalArgumentException("Unsupported device: " + device.toString());
}
}

/** Contains device type string constants. */
public interface Type {
String CPU = "cpu";
String GPU = "gpu";
String MPS = "mps";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need make changes in Device class. Let's keep it private in PyTorch for now

Copy link
Contributor Author

@demq demq Sep 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that this might create some false expectations form users that Apple MPS is supported for all the engines, but the current solution means that apple silicone users don't need to rewrite any of their existing code. Perhaps adding a documentation that MPS will only work for pytorch?

The users can simply specify "mps" as their device type in their model Criteria when running on Mac M1/M2 and get the acceleration working. Otherwise the developers would need to add pytorch and OSX-specific code to support "mps", making it less likely to be used. The device type, on the other hand, is usually just a system setting / cli argument to ensure let's say the GPU acceleration can be switched on/off.

@codecov-commenter
Copy link

Codecov Report

Base: 72.08% // Head: 69.53% // Decreases project coverage by -2.55% ⚠️

Coverage data is based on head (b9df963) compared to base (bb5073f).
Patch coverage: 68.32% of modified lines in pull request are covered.

Additional details and impacted files
@@             Coverage Diff              @@
##             master    #2037      +/-   ##
============================================
- Coverage     72.08%   69.53%   -2.56%     
- Complexity     5126     5953     +827     
============================================
  Files           473      597     +124     
  Lines         21970    26498    +4528     
  Branches       2351     2880     +529     
============================================
+ Hits          15838    18426    +2588     
- Misses         4925     6687    +1762     
- Partials       1207     1385     +178     
Impacted Files Coverage Δ
api/src/main/java/ai/djl/modality/cv/Image.java 69.23% <ø> (-4.11%) ⬇️
...rc/main/java/ai/djl/modality/cv/MultiBoxPrior.java 76.00% <ø> (ø)
...rc/main/java/ai/djl/modality/cv/output/Joints.java 71.42% <ø> (ø)
.../main/java/ai/djl/modality/cv/output/Landmark.java 100.00% <ø> (ø)
...main/java/ai/djl/modality/cv/output/Rectangle.java 72.41% <0.00%> (ø)
...i/djl/modality/cv/translator/BigGANTranslator.java 21.42% <0.00%> (-5.24%) ⬇️
...odality/cv/translator/BigGANTranslatorFactory.java 33.33% <0.00%> (+8.33%) ⬆️
.../cv/translator/InstanceSegmentationTranslator.java 0.00% <0.00%> (-86.59%) ⬇️
...nslator/InstanceSegmentationTranslatorFactory.java 7.14% <0.00%> (-11.04%) ⬇️
.../cv/translator/SemanticSegmentationTranslator.java 0.00% <0.00%> (ø)
... and 512 more

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@frankfliu frankfliu merged commit 1d0bc77 into deepjavalibrary:master Sep 28, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants