How did you solve the problem of jax.pmap hanging? #61
I have an embarrassingly parallel problem, and I seem to be running into the issue you raised a couple of years ago in jax-ml/jax#5065, which doesn't appear to have been fully resolved in JAX. Have you found any useful heuristic workarounds, and would you mind sharing them? I assume you have, since there is plenty of pmap use in your code. Did you have to settle for the (generally inadvisable) method of wrapping pmaps in jits, or did you find something better?
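To make the workaround mentioned in the question concrete: "wrapping a pmap in a jit" means applying `jax.jit` to a `jax.pmap`-ed function. The snippet is only an illustration, assuming a toy function `step` (a hypothetical name) and an input whose leading axis equals `jax.local_device_count()`. As far as I know, JAX discourages this pattern and may warn about it, because the outer `jit` gathers inputs and outputs onto a single device instead of keeping them sharded.

```python
import jax
import jax.numpy as jnp


def step(row):
    # Hypothetical per-device computation; each device receives one row.
    return jnp.tanh(row) * 2.0


n_dev = jax.local_device_count()
# pmap maps over the leading axis, which must not exceed the device count.
x = jnp.ones((n_dev, 3))

plain = jax.pmap(step)             # pmap on its own
wrapped = jax.jit(jax.pmap(step))  # the "pmap wrapped in jit" workaround

out_plain = plain(x)
out_wrapped = wrapped(x)  # JAX may emit a jit-of-pmap performance warning here
```

Both calls compute the same values; only the compilation and dispatch path differs.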
Hi @halgorthim, you can use https://github.com/Joshuaalbert/jaxns/blob/master/jaxns/internals/maps.py#L101. It behaves just like `pmap`, with a few extra parameters. Use it like this:

```python
from jaxns.internals.maps import chunked_pmap

def embarassingly_parallel_func(*args, **kwargs):
    pass  # your per-item computation goes here

parallel_func = chunked_pmap(embarassingly_parallel_func, chunksize=...)
results = parallel_func(*args, **kwargs)

# If args[0] is a pytree, you need to specify the batch_size parameter too.
parallel_func = chunked_pmap(embarassingly_parallel_func, chunksize=..., batch_size=...)
```

`chunksize` is slightly misnamed: it means how many parallel workers to use, and it must be less than the number of devices. `batch_size` is…

Behind the scenes, it pads and reshapes the arguments and then puts the items into a number of queues, each of which is executed in parallel with … Hope this helps!
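The pad, reshape, and queue steps described in the reply can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not jaxns's actual implementation: `chunked_pmap_sketch` and `num_workers` are hypothetical names; it only handles a single array argument whose leading axis is the batch (one plausible reason the real function asks for `batch_size` when `args[0]` is a pytree); and it assumes each device works through its queue sequentially with `jax.lax.map`.

```python
import jax
import jax.numpy as jnp


def chunked_pmap_sketch(f, num_workers):
    # Hypothetical helper, NOT jaxns's chunked_pmap. Handles one array
    # argument whose leading axis is the batch. `num_workers` plays the
    # role of `chunksize`; for this sketch's pmap it must not exceed
    # jax.local_device_count().
    # Assumption: each device runs its queue sequentially with lax.map.
    run_queues = jax.pmap(lambda queue: jax.lax.map(f, queue))

    def wrapped(x):
        n = x.shape[0]
        # Pad by repeating the last item so the batch splits evenly
        # into `num_workers` equally long queues.
        pad = (-n) % num_workers
        x_padded = jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0)
        # Reshape to (num_workers, items_per_queue, ...): one queue per device.
        queues = x_padded.reshape((num_workers, -1) + x.shape[1:])
        out = run_queues(queues)
        # Merge the queues back into one batch and drop the padded items.
        return out.reshape((-1,) + out.shape[2:])[:n]

    return wrapped


# Usage: per-row sum of squares, spread over all local devices.
f = lambda row: jnp.sum(row ** 2)
x = jnp.arange(10.0).reshape(5, 2)
parallel_f = chunked_pmap_sketch(f, num_workers=jax.local_device_count())
result = parallel_f(x)  # per-row sums of squares: 1, 13, 41, 85, 145
```

Padding by repeating the last item rather than with zeros is a design choice in this sketch: it avoids feeding `f` values it may not handle (for example, a zero passed to a log), and the padded results are discarded anyway.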