Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX] Replace uses of jax.devices("cpu") with jax.local_devices(backend="cpu") #27593

Merged
merged 1 commit into from
Dec 4, 2023

Conversation

hvaara
Copy link
Contributor

@hvaara hvaara commented Nov 19, 2023

An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes.

This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future.

This change is always safe (i.e., it should always preserve the previous behavior), but it may sometimes be unnecessary if code is never used in a multicontroller setting.

For a similar PR in diffusers see huggingface/diffusers#5864.

…nd="cpu")

An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes.

This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future.

This change is always safe (i.e., it should always preserve the previous behavior), but it may sometimes be unnecessary if code is never used in a multicontroller setting.

Co-authored-by: Peter Hawkins <phawkins@google.com>
@hvaara
Copy link
Contributor Author

hvaara commented Nov 19, 2023

/cc @sanchit-gandhi

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @hvaara!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks 😉

@ArthurZucker ArthurZucker merged commit bcd0a91 into huggingface:main Dec 4, 2023
3 checks passed
@hvaara hvaara deleted the jax-local-cpu branch December 4, 2023 21:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants