
How to force random numbers to be generated on CPU? #9691

Answered by ayaka14732
ayaka14732 asked this question in Q&A

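This can be solved with the `jax.default_device` context manager introduced in #9118. Computations run inside the `with` block are placed on the given device (here the first CPU device) instead of the default accelerator: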

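import jax

# First CPU device, even when a GPU/TPU backend is the default.
device_cpu = jax.devices('cpu')[0]
with jax.default_device(device_cpu):
    a = jax.random.poisson(jax.random.PRNGKey(0), 3, shape=(1000,))
    # Older JAX releases spelled this `a.device()`; current releases use `a.devices()`.
    print(a.devices())  # the set of devices holding `a`, i.e. the CPU device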
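A common follow-up is to draw the samples on the CPU and then move them to the accelerator for the rest of the computation. The sketch below assumes a JAX release that has `jax.Array.devices()`; the helper `sample_on_cpu` and its `seed`/`n` parameters are hypothetical names for illustration. `jax.device_put` performs the explicit transfer; on a CPU-only machine both arrays simply stay on the CPU.

```python
import jax

cpu = jax.devices("cpu")[0]
accel = jax.devices()[0]  # default backend: GPU/TPU if present, otherwise CPU

def sample_on_cpu(seed, n):
    """Hypothetical helper: draw n Poisson(3) samples with the work placed on the CPU."""
    with jax.default_device(cpu):
        key = jax.random.PRNGKey(seed)
        return jax.random.poisson(key, 3, shape=(n,))

samples = sample_on_cpu(0, 1000)
# Explicitly copy the CPU-generated samples to the default device for downstream work.
samples_accel = jax.device_put(samples, accel)
print(samples.devices(), samples_accel.devices())
```

Keeping the `with` block limited to random-number generation leaves every other computation on the default device.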


ayaka14732
Feb 25, 2022
Collaborator Author
@YouJiacheng
@ayaka14732
ayaka14732 Feb 25, 2022
Collaborator Author
ayaka14732
Sep 15, 2022
Collaborator Author
Answer selected by ayaka14732