Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warp-JAX multi-GPU interoperability and added custom launch dimension #310

Merged
merged 5 commits into from
Sep 17, 2024

Conversation

mehdiataei
Copy link
Contributor

Category

  • New feature
  • Bugfix
  • Breaking change
  • Refactoring
  • Documentation
  • Other (please explain)

Description

  1. In warp/jax_experimental.py:
  • Modified the jax_kernel function to accept an optional launch_dims parameter.
  • Updated the abstract evaluation and lowering functions to use the provided launch_dims when available.
  • Changed jax.devices() to jax.local_devices() in the _get_jax_device function (this was a bug, as in multi-GPU settings Warp may select non-addressable devices).
  1. In warp/tests/test_jax.py:
    Added a new test case test_jax_kernel_launch_dims to verify the functionality of custom launch dimensions for both 1D and 2D kernels.
  2. In docs/modules/interoperability.rst:
  • Removed the limitation that output shapes must match launch dimensions given the new feature.
  • Added a new section on using shardmap for distributed multi-GPU computation with Warp and JAX.
  • Added a section on specifying launch dimensions for multi-GPU matrix operations.

Changelog

  • jax_kernel now accepts an optional launch_dims parameter. The launch dim is no longer limited the the shape of the first input.
  • Changed device selection from jax.devices() to jax.local_devices() to address multi-GPU launch issues.
  • Added tutorials on using JAX's shardmap for multi-GPU computations and specifying custom launch dimensions.

Before your PR is "Ready for review"

  • Do you agree to the terms under which contributions are accepted as described in Section 9 the Warp License?
  • Have you read the Contributor Guidelines?
  • Have you written any new necessary tests?
  • Have you added or updated any necessary documentation?
  • Have you added any files modified by compiling Warp and building the documentation to this PR (.e.g. stubs.py, functions.rst)?
  • Does your code pass ruff check and ruff format --check?

warp/tests/test_jax.py Show resolved Hide resolved
warp/jax_experimental.py Show resolved Hide resolved
docs/modules/interoperability.rst Outdated Show resolved Hide resolved
docs/modules/interoperability.rst Outdated Show resolved Hide resolved
Copy link
Contributor

@nvlukasz nvlukasz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great and I really appreciate the documentation updates!

docs/modules/interoperability.rst Show resolved Hide resolved
docs/modules/interoperability.rst Show resolved Hide resolved
@shi-eric shi-eric merged commit 2b3a7c8 into NVIDIA:main Sep 17, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants