
jax multithreading pitfalls #1661

Open

smartalecH opened this issue Jul 8, 2021 · 8 comments

Comments

@smartalecH
Collaborator

We recently pivoted from autograd to jax for our adjoint autodifferentiation needs. While jax supports gradients for more functions than autograd, it comes at a price.

jax relies on a sophisticated backend (XLA) to implement all of its heavy computations. jax also has sophisticated machinery to translate the code (and even compile it, if you request a JIT) for the forward and backward computations. To keep things efficient, jax spawns multiple threads throughout this pipeline.

Normally this isn't a problem -- multithreading is great! But jax is a bit cavalier with how it spawns threads. For example, a single process can spawn upwards of 100 threads just to initialize jax (depending on your version, hardware, imports, etc.)! As you start using its libraries in your code (e.g. vmaps, loops, ffts, etc.) that thread count tends to increase significantly.
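
Here's a minimal sketch of how to observe this for yourself (assuming the psutil package is available; exact counts vary with jax version, hardware, and imports):

import psutil

proc = psutil.Process()
print("threads before importing jax:", proc.num_threads())

import jax  # import placed mid-script on purpose, to measure the jump
jax.numpy.zeros(1).block_until_ready()  # force XLA backend initialization

print("threads after jax init:", proc.num_threads())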

For small meep simulations (e.g. 1-24 cores) this isn't an issue. But as you increase the number of procs, you linearly scale the number of threads that are spawned, often exceeding the per-user limit on your system (you can check it with ulimit -u). Consequently, your simulation dies because the XLA engine cannot spawn any additional pthreads. This is especially relevant when running job arrays for a parameter sweep.
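
The same limit can be queried from Python (a Linux sketch; ulimit -u corresponds to RLIMIT_NPROC, which counts threads as well as processes):

import resource

# Per-user limit on processes/threads; hitting the soft limit is what
# kills the XLA thread spawn described above.
soft, hard = resource.getrlimit(resource.RLIMIT_NPROC)
print("per-user process/thread limit (soft, hard):", soft, hard)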

jax does respect some CPU affinity settings (see here for a great discussion), but it will still spawn multiple threads per core. Even when you explicitly tell jax to use only 1 thread for computations like convolutions, it will still use multiple threads for the many other tasks in its pipeline! Furthermore, you have to make sure your job manager and your MPI build are co-configured to support this kind of task scheduling (otherwise you can get weird, inefficient behavior, like everything being bound to the same core).
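
For what it's worth, one way to at least confine all of a process's threads (jax's included) to a fixed set of cores is to restrict the process's affinity before jax initializes (a Linux-only sketch; as noted above, this does not cap the total thread count):

import os

# Pin this process to core 0 before jax spins up its threadpools
# (child threads inherit the affinity mask).
os.sched_setaffinity(0, {0})

import jax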

While I've successfully run thousands of adjoint optimizations (in tandem) on one particular cluster with no thread issues, I'm unable to run more than 16 cores at a time on another cluster! I imagine many users are going to experience something similar.

I think this further motivates the heterogeneous MPI-OpenMP paradigm we started, where we simply spawn one process per node. This will force the thread count down to a bare minimum.

@stevengj
Collaborator

stevengj commented Jul 8, 2021

Can you comment on the linked jax issue, since they seem to think that the problem is resolved?

@ianwilliamson
Contributor

Are you seeing this with the JAX wrapper that was recently introduced in #1569?

The wrapper itself should be pretty lightweight in terms of the JAX calculations that are performed, though it's notable that simply importing JAX may be causing some issues here. I would be interested to learn more about how you're using vmap and jit in this context.

Given that the full python interpreter and its state get replicated across each MPI process, I guess this isn't too surprising though. The scipy and nlopt optimizers also have non-negligible compute and memory loads. Have we seen any issues with these getting replicated across all MPI workers?

@smartalecH
Collaborator Author

Are you seeing this with the JAX wrapper that was recently introduced in #1569?

It's not unique to what was introduced in #1569. Rather, anytime we use jax in a multi-process environment, we are going to see the behavior I described above.

The scipy and nlopt optimizers also have non-negligible compute and memory loads. Have we seen any issues with these getting replicated across all MPI workers?

I've never had issues with nlopt's mma algorithm. And that's with large 3D designs (including 3D degrees of freedom) across several nodes (each with ~36 procs). I know @mochen4 saw some issues when he pivoted to design regions with 3D DoF, but there are ways to mitigate that.

But as soon as you start using jax, you limit the number of processes you can launch on a single node.

@ianwilliamson
Contributor

And that's with large 3D designs (including 3D degrees of freedom) across several nodes (each with ~36 procs).

Should the total memory footprint scale linearly with the number of processes? Just curious if you've ever measured this.

The issue with JAX and multithreading under MPI is unfortunate. I agree that it's a good motivation for thread-level parallelism.

@stevengj
Collaborator

stevengj commented Jul 30, 2021

Should the total memory footprint scale linearly with the number of processes? Just curious if you've ever measured this.

Yes, because we are duplicating the DOFs and the optimization computations on every process. (The optimization internally requires O(DOFs * (1 + #constraints)) storage.) For 3d DOFs (e.g. for 3d printing), as in @mochen4's work, that starts to be a substantial amount of memory per process, which is why @mochen4 is currently using the multi-threading branch so that we only have one process per multi-core node.
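
To make the scaling concrete, a back-of-the-envelope sketch (the sizes here are purely illustrative, not measured):

n_dof = 200**3          # hypothetical 3d design region, 200 voxels per side
n_constraints = 2       # hypothetical constraint count
bytes_per_double = 8

# O(DOFs * (1 + #constraints)) storage, duplicated on every process:
per_process = n_dof * (1 + n_constraints) * bytes_per_double
print(f"{per_process / 1e9:.2f} GB per process")             # ~0.19 GB
print(f"{36 * per_process / 1e9:.2f} GB on a 36-proc node")  # ~6.91 GB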

@stevengj
Collaborator

stevengj commented Nov 3, 2021

As Alec mentioned today, this affects you even if you aren't using MPI at all — even if you just want to launch multiple independent jobs (e.g. for a parameter sweep) on a single multi-core node, the JAX instances in the independent processes will fight with each other over the cores.

To me, it's just not acceptable for any library to assume that it has exclusive use of all the cores on a machine, because it makes the library non-composable with any other form of parallelism. I don't know if @ianwilliamson has any inside track with the Jax developers on how to limit Jax's parallelism here (e.g. to limit its threadpool size)? The linked issue already has a way to limit XLA's threadpool, but it seems that's not sufficient.

What is the downside of just sticking with autograd here?

@ianwilliamson
Contributor

@smartalecH Can we try to document here which approaches for controlling the number of threads were tried and failed? There are a few different issues linked in your OP, each with several different proposals and approaches described.

For example, I see one possible solution proposed here:
jax-ml/jax#2685 (comment)

I also see another solution proposed here:
jax-ml/jax#743 (comment)

@smartalecH
Collaborator Author

smartalecH commented Nov 5, 2021

Can we try to document here which approaches for controlling the number of threads were tried and failed?

Yes, good idea. Here is a list of things I tried:

Flags unique to JAX/XLA

Limit the number of XLA threads (actually doesn't truly limit the number of threads, especially during JIT...)

os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

Limit the number of threads for the other LA libraries:

os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

Other

I tried various CPU-binding flags for multiple flavors of MPI (Open MPI, MPICH, etc.). In hindsight this shouldn't be expected to completely solve the issue, and indeed it still crashed for parallel jobs launched by Slurm.

I also tried various binding options within Slurm itself (and different process managers too) and still saw issues.

FWIW, my current cluster, which uses Moab/TORQUE, doesn't seem to have issues with this.

I realize this isn't incredibly comprehensive -- unfortunately, I no longer have access to all the methods I tried (and probably couldn't share them anyway).
