Add MPS as default if available #311

Open · jkoudys opened this issue Dec 23, 2022 · 10 comments

@jkoudys commented Dec 23, 2022

Can we show some love for the Mac M1 people out there? MPS doesn't seem any harder to select when available than CUDA, and tch-rs appears to include it in its `Device` enum.

I'm happy to open a PR myself if anyone can suggest the right approach. Should it simply choose MPS the same way `cuda_if_available` does, and default to it when available? Or should we check for CUDA first, then MPS, and only then fall back to CPU?
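
For illustration, a minimal sketch of the second option (CUDA first, then MPS, then CPU). The helper name is hypothetical, and it assumes the tch version in use exposes `utils::has_cuda` / `utils::has_mps` availability checks:

```rust
use tch::Device;

/// Hypothetical helper: prefer CUDA, then MPS, then fall back to CPU.
/// Assumes `tch::utils::has_cuda` / `tch::utils::has_mps` exist in the
/// tch version in use; adjust to whatever availability checks it offers.
fn best_available_device() -> Device {
    if tch::utils::has_cuda() {
        // Picks CUDA device 0 when CUDA is usable.
        Device::cuda_if_available()
    } else if tch::utils::has_mps() {
        Device::Mps
    } else {
        Device::Cpu
    }
}
```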

@dimfeld commented Dec 28, 2022

Were you able to get MPS to actually work? I haven't been successful yet. I tried altering the device in the model config to MPS, but I get the error `Internal torch error: supported devices include CPU, CUDA and HPU, however got MPS`. When I load torch in Python it says MPS is supported, so I think I have everything installed properly.

My best guess right now is that this is related to pytorch/pytorch#88820, where JIT models created in a certain way won't load on MPS. I'm pretty new to all this though, so I'm not completely sure.

@jkoudys (Author) commented Dec 28, 2022 via email

@dimfeld commented Dec 28, 2022

OK, I actually did get it working! It's a little haphazard, though. The trick is to load the `VarStore` using the CPU device and then migrate it to the GPU device. From my other research, it appears that saving the migrated `VarStore` to disk would then allow it to be used directly with MPS, but I haven't tried that yet.

On my M1 Pro this runs about 2-3 times as fast as with CPU/AMX.

I'm using the sentence embeddings pipeline. Here's the relevant change there:

```diff
  let transformer =
      SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
  var_store.load(transformer_weights_resource.get_local_path()?)?;
+ var_store.set_device(tch::Device::Mps);
```
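
In general form, the workaround is: create the `VarStore` on CPU, load the weights, then move it to MPS. A minimal sketch, with a placeholder weights path and the model construction elided:

```rust
use tch::{nn::VarStore, Device};

fn load_weights_on_mps() -> Result<(), tch::TchError> {
    // Create the VarStore on CPU so the weights load without the
    // "supported devices" error, then migrate the tensors to MPS.
    let mut var_store = VarStore::new(Device::Cpu);
    // ... build the model against var_store.root() here ...
    var_store.load("path/to/model_weights.ot")?; // placeholder path
    var_store.set_device(Device::Mps);
    // Optionally persist the migrated weights; per the note above, the
    // saved file might then load directly on MPS (untested).
    // var_store.save("path/to/mps_weights.ot")?;
    Ok(())
}
```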

@chrisvander commented

Is there any way to access `var_store` after initializing the default model? I'm using the zero-shot classification pipeline.

@dimfeld commented Jan 20, 2023

Currently, I don't think so. I ended up just copying the pipeline code into my own project and modifying it for my purposes.

@guillaume-be (Owner) commented

Apologies, I don't have an MPS device at hand for testing. It seems the issue is that creating the model with `Device::Mps` fails, but loading the weights with `Device::Cpu` and then changing the `VarStore` device works.
It would make sense to raise an issue with the upstream tch-rs, as to my understanding it should be possible to work with `Device::Mps` the same way as with `Device::Cuda`.

Would accessing the `var_store` of pipelines via a `get_mut_var_store(&mut self) -> &mut VarStore` interface help for your use case in the meantime?
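
Usage of such an accessor might look like the following sketch (the method is only proposed here and does not exist in the crate yet):

```rust
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
use tch::Device;

// Hypothetical: `get_mut_var_store` is the interface proposed above,
// not yet part of the library.
let mut model = ZeroShotClassificationModel::new(Default::default()).unwrap();
model.get_mut_var_store().set_device(Device::Mps);
```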

@chrisvander commented

That would help, yes! Any way to access the underlying `var_store` without having to rewrite the whole initialization would work.

@guillaume-be (Owner) commented

@dimfeld @chrisvander I have opened a fix on the upstream library (LaurentMazare/tch-rs#623). If you have the time, it would be great if you could do some testing and see whether this addresses the issue.

@chrisvander commented

Hmm... I did get a run going with MPS, but it seems to be ignoring the `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable and fails with `TorchError: "sparse_grad not supported in MPS yet"`. It should fall back to CPU there. Any tips?

```rust
use rust_bert::pipelines::zero_shot_classification::{
    ZeroShotClassificationConfig, ZeroShotClassificationModel,
};
use tch::Device;

let config = ZeroShotClassificationConfig {
    device: Device::Mps,
    ..Default::default()
};
let model = ZeroShotClassificationModel::new(config).unwrap();
```

@guillaume-be (Owner) commented

Hello @chrisvander,

Apologies for the late response. I am pushing changes that set the `sparse_grad` parameter to `false` across the library to improve device compatibility (#404); that should solve the issue.
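
For context, `sparse_grad` is the final argument to tch's `Tensor::gather`. A rough illustration of the kind of call the fix changes (the tensors here are made up; the real call sites are inside the library):

```rust
use tch::{Device, Kind, Tensor};

// Illustrative tensors; the real gather call sites live inside rust-bert.
let logits = Tensor::rand(&[2, 5], (Kind::Float, Device::Mps));
let index = Tensor::of_slice(&[1i64, 3]).view([2, 1]).to_device(Device::Mps);

// Passing `sparse_grad = false` as the last argument avoids the
// "sparse_grad not supported in MPS yet" error quoted above.
let gathered = logits.gather(1, &index, false);
```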
