jax multithreading pitfalls #1661
Comments
Can you comment on the linked jax issue, since they seem to think that the problem is resolved?
Are you seeing this with the JAX wrapper that was recently introduced in #1569? The wrapper itself should be pretty lightweight in terms of the JAX calculations that are performed, though it's notable that simply importing JAX may be causing some issues here. I would be interested to learn more about how you're using it.

Given that the full python interpreter and its state get replicated across each MPI process, I guess this isn't too surprising though. The scipy and nlopt optimizers also have non-negligible compute and memory loads. Have we seen any issues with these getting replicated across all MPI workers?
It's not unique to what was introduced in #1569. Rather, anytime we use jax in a multi-process environment, we are going to see the behavior I described above.
I've never had issues with nlopt's MMA algorithm. And that's with large 3D designs (including 3D degrees of freedom) across several nodes (each with ~36 procs). I know @mochen4 saw some issues when he pivoted to design regions with 3D DoF, but there are ways to mitigate that. But as soon as you start using jax, you limit the number of processes you can launch on a single node.
Should the total memory footprint scale linearly with the number of processes? Just curious if you've ever measured this. The issue with JAX and multithreading under MPI is unfortunate. I agree that it's a good motivation for thread-level parallelism.
Yes, because we are duplicating the DOFs and the optimization computations on every process. (The optimization internally requires storage ...)
As Alec mentioned today, this affects you even if you aren't using MPI at all — even if you just want to launch multiple independent jobs (e.g. for a parameter sweep) on a single multi-core node, the JAX instances in the independent processes will fight with each other over the cores. To me, it's just not acceptable for any library to assume that it has exclusive use of all the cores on a machine, because it makes the library non-composable with any other form of parallelism. I don't know if @ianwilliamson has any inside track with the Jax developers on how to limit Jax's parallelism here (e.g. to limit its threadpool size)? The linked issue already has a way to limit XLA's threadpool, but it seems that's not sufficient. What is the downside of just sticking with autograd?
@smartalecH Can we try to document here which approaches for controlling the number of threads were tried and failed? There are a few different issues linked in your OP, each with several different proposals and approaches described. For example, I see one possible solution proposed here, and I also see another solution proposed here.
Yes, good idea. Here is a list of things I tried:

**Flags unique to Jax/LA**
- Limit the number of XLA threads (this actually doesn't truly limit the number of threads, especially during JIT...)
- Limit the number of threads for the other LA libraries (a sketch of the kind of environment variables I mean is shown after this comment)

**Other**
I tried various cpu binding flags for multiple flavors of MPI (openMPI, mpich, etc.), although in hindsight this shouldn't completely solve the issue, as it still crashed with parallel jobs launched by slurm. I also tried various binding options with slurm (and tried different process engines too) and still saw issues. FWIW, my current cluster using moab/torque doesn't seem to have issues with this.

I realize this isn't incredibly comprehensive -- unfortunately, I no longer have access to all the methods I tried (and probably couldn't share them anyway).
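For concreteness, here is a minimal sketch of the kind of thread-limiting settings referred to above. The specific names (`XLA_FLAGS` with `intra_op_parallelism_threads` and `--xla_cpu_multi_thread_eigen`, plus `OMP_NUM_THREADS`, `MKL_NUM_THREADS`, `OPENBLAS_NUM_THREADS`) are the ones commonly suggested in the linked issues, not a confirmed fix, and they must be set before `jax` is imported:

```python
import os

# XLA-specific flags (reportedly do not fully cap the thread count,
# especially during JIT compilation).
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
)

# Generic thread caps for the other linear-algebra / OpenMP backends.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import jax.numpy as jnp  # imported only after the environment is configured

print(jnp.arange(4) + 1)  # forces the CPU backend to initialize
```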
We recently pivoted from `autograd` to `jax` for our adjoint autodifferentiation needs. While `jax` supports gradients for more functions than `autograd`, it comes with a price. `jax` relies on a sophisticated backend (XLA) to implement all of its heavy computations. `jax` also has some sophisticated machinery to translate the code (and even compile it, if you request a JIT) for forward and backward computations. To keep things efficient, `jax` will spawn multiple threads throughout the pipeline.

Normally this isn't a problem -- multithreading is great! But `jax` is a bit cavalier with how it spawns threads. For example, a single process can spawn upwards of 100 threads just to initialize `jax` (depending on your version, hardware, imports, etc.)! As you start using its libraries in your code (e.g. `vmap`s, loops, FFTs, etc.) that thread count tends to increase significantly.

For small meep simulations (e.g. 1-24 cores) this isn't an issue. But as you start to increase the number of procs, you linearly scale the number of threads that are spawned, often exceeding what's allowed on your system (you can check using `ulimit -u`). Consequently, your simulation will die because the XLA engine cannot spawn any additional pthreads. This is especially relevant when running job arrays for a parameter sweep.
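To see this effect in isolation (without meep or MPI in the picture), here is a tiny Linux-only diagnostic sketch, not part of the original report, that counts the OS threads owned by the process before and after `jax` spins up:

```python
import os

def thread_count() -> int:
    # Each entry in /proc/self/task corresponds to one OS thread of this process.
    return len(os.listdir("/proc/self/task"))

print("threads before importing jax:", thread_count())

import jax
import jax.numpy as jnp

# Trigger JIT compilation and backend startup with a small FFT.
_ = jax.jit(lambda x: jnp.fft.fft(x))(jnp.ones(64))
print("threads after a jitted call :", thread_count())
```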
`jax` does respect some CPU affinity settings (see here for a great discussion) but it will still spawn multiple threads per core. Even when you explicitly tell `jax` to only use 1 thread for computations like convolutions, it will still use multiple threads for its many other tasks in the pipeline! Furthermore, you have to make sure your job manager and your MPI build are co-configured to support this kind of task scheduling (otherwise you can get weird, inefficient behavior, like binding everything to the same core).

While I've successfully run thousands of adjoint optimizations (in tandem) on one particular cluster with no thread issues, I'm unable to run more than 16 cores at a time on another cluster! I imagine many users are going to experience something similar.
I think this further motivates the heterogeneous MPI-OpenMP paradigm we started, where we simply spawn one process per node. This will force the thread count down to a bare minimum.