How to force random numbers to be generated on CPU? #9691
-
I would like to generate random numbers with

>>> a = jax.random.poisson(jax.random.PRNGKey(0), 3, shape=(1000,))
>>> a.device()
TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)

The results are generated on the default device (here, the TPU). However, I would like the results to be generated on the CPU. One way I can think of is to use

>>> a = jax.jit(lambda: jax.random.poisson(jax.random.PRNGKey(0), 3, shape=(1000,)), backend='cpu')()
>>> a.device()
CpuDevice(id=0)

However, this method is slower due to the jit compilation. What is the best way of doing this?
Replies: 3 comments 2 replies
-
What about jitting the generator once and reusing it?

generate = jax.jit(lambda key: jax.random.poisson(key, 3, shape=(1000,)), backend='cpu')
key = jax.random.PRNGKey(0)
for _ in range(100):
    key, subkey = jax.random.split(key)
    generate(subkey)

This only needs one compilation.
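A minimal sketch of how one might check the "one compilation" claim. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so appending to a list counts traces. Note that this sketch does not use `backend='cpu'`; instead it commits the key to the CPU with `jax.device_put` and assumes that the jitted computation then follows its committed input. The names `trace_log` and `cpu` are illustrative only.

```python
import jax

cpu = jax.devices("cpu")[0]
trace_log = []  # illustrative: grows by one entry each time JAX traces generate


@jax.jit
def generate(key):
    # Python side effect: executed during tracing, not on cached calls.
    trace_log.append("traced")
    return jax.random.poisson(key, 3, shape=(1000,))


# Assumption: committing the key to the CPU makes the jitted call run there.
key = jax.device_put(jax.random.PRNGKey(0), cpu)
for _ in range(100):
    key, subkey = jax.random.split(key)
    sample = generate(subkey)

print(len(trace_log))    # number of traces across the 100 calls
print(sample.devices())  # devices holding the last sample
```

Every call passes a key with the same shape and dtype, so the cached compiled function should be reused after the first call.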
-
I am expecting something like this:

>>> key = jax.random.PRNGKey(0, backend='cpu')
>>> key, subkey = jax.random.split(key)
>>> a = jax.random.poisson(subkey, 3, shape=(1000,), backend='cpu')
-
This issue can be resolved by using the default device context manager introduced in #9118:

import jax

device_cpu = jax.devices('cpu')[0]
with jax.default_device(device_cpu):
    # Arrays created inside the block are placed on the CPU.
    a = jax.random.poisson(jax.random.PRNGKey(0), 3, shape=(1000,))
    print(a.device())  # TFRT_CPU_0
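If this is needed in several places, the context manager can be wrapped in a small function. The following is only a sketch: `poisson_on_cpu` is a hypothetical helper name, not part of JAX, and it inspects placement with `a.devices()` rather than `a.device()`, on the assumption that a recent JAX release is installed.

```python
import jax

cpu = jax.devices("cpu")[0]


def poisson_on_cpu(seed, lam, shape):
    """Hypothetical helper: draw Poisson samples with the CPU as default device."""
    with jax.default_device(cpu):
        # Both the key and the samples are created while the CPU is the default.
        key = jax.random.PRNGKey(seed)
        return jax.random.poisson(key, lam, shape=shape)


a = poisson_on_cpu(0, 3, (1000,))
print(a.devices())  # set of devices holding the result
```

Code outside the `with` block keeps using the normal default device, so only the sampling is redirected to the CPU.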