Warp-JAX multi-GPU interoperability and added custom launch dimension #310

Merged · 5 commits · Sep 17, 2024
228 changes: 227 additions & 1 deletion docs/modules/interoperability.rst
@@ -418,7 +418,6 @@ Since this is an experimental feature, there are some limitations:
- Kernel launch dimensions are inferred from the shape of the first argument.
- Input arguments are followed by output arguments in the Warp kernel definition.
- There must be at least one input argument and at least one output argument.
- Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
- All arrays must be contiguous.
- Only the CUDA backend is supported.

@@ -462,6 +461,233 @@ Here is an example of an operation with three inputs and two outputs::
    print(x)
    print(y)

Using shard_map for distributed computation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Warp can be used in conjunction with JAX's `shard_map <https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html>`_ to perform distributed multi-GPU computations.

To achieve this, the JAX distributed environment must be initialized (see `Distributed Arrays and Automatic Parallelization <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_ for more details):

.. code-block:: python

    import jax
    jax.distributed.initialize()

This initialization must be called at the beginning of your program, before any other JAX operations.
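
A minimal sketch of an explicit multi-process initialization is shown below; the coordinator address, process count, and process index are hypothetical placeholders that depend on your cluster setup. When the program is launched with a supported launcher (as in the `mpirun` example later in this section), calling `jax.distributed.initialize()` with no arguments is usually sufficient.

.. code-block:: python

    import jax

    # Hypothetical two-process setup; replace these values with ones for your cluster.
    jax.distributed.initialize(
        coordinator_address="192.168.1.10:1234",  # address of process 0
        num_processes=2,                          # total number of processes
        process_id=0,                             # index of this process
    )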

Here's an example of how to use `shard_map` with a Warp kernel:

.. code-block:: python

    import warp as wp
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from jax.experimental.multihost_utils import process_allgather as allgather
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel
    import numpy as np

    # Initialize JAX distributed environment
    jax.distributed.initialize()
    num_gpus = jax.device_count()

    def print_on_process_0(*args, **kwargs):
        if jax.process_index() == 0:
            print(*args, **kwargs)

    print_on_process_0(f"Running on {num_gpus} GPU(s)")

    @wp.kernel
    def multiply_by_two_kernel(
        a_in: wp.array(dtype=wp.float32),
        a_out: wp.array(dtype=wp.float32),
    ):
        index = wp.tid()
        a_out[index] = a_in[index] * 2.0

    jax_warp_multiply = jax_kernel(multiply_by_two_kernel)

    def warp_multiply(x):
        result = jax_warp_multiply(x)
        return result

    # a_in here is the full sharded array with shape (M,)
    # The output will also be a sharded array with shape (M,)
    def warp_distributed_operator(a_in):
        def _sharded_operator(a_in):
            # Inside the sharded operator, a_in is a local shard on each device
            # If we have N devices and input size M, each shard has shape (M/N,)

            # warp_multiply applies the Warp kernel to the local shard
            result = warp_multiply(a_in)[0]

            # result has the same shape as the input shard (M/N,)
            return result

        # shard_map distributes the computation across devices
        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
            in_specs=(P("x"),),  # Input is sharded along the 'x' axis
            out_specs=P("x"),  # Output is also sharded along the 'x' axis
            check_rep=False,
        )(a_in)

    print_on_process_0("Test distributed multiplication using JAX + Warp")

    devices = jax.devices()
    mesh = jax.sharding.Mesh(np.array(devices), "x")
    sharding_spec = jax.sharding.NamedSharding(mesh, P("x"))

    input_size = num_gpus * 5  # 5 elements per device
    single_device_arrays = jnp.arange(input_size, dtype=jnp.float32)

    # Define the shape of the input array based on the total input size
    shape = (input_size,)

    # Create a list of arrays by distributing the single_device_arrays across the available devices
    # Each device will receive a portion of the input data
    arrays = [
        jax.device_put(single_device_arrays[index], d)  # Place each element on the corresponding device
        for d, index in sharding_spec.addressable_devices_indices_map(shape).items()
    ]

    # Combine the individual device arrays into a single sharded array
    sharded_array = jax.make_array_from_single_device_arrays(shape, sharding_spec, arrays)

    # sharded_array has shape (input_size,) but is distributed across devices
    print_on_process_0(f"Input array: {allgather(sharded_array)}")

    # warp_result has the same shape and sharding as sharded_array
    warp_result = warp_distributed_operator(sharded_array)

    # allgather collects results from all devices, resulting in a full array of shape (input_size,)
    print_on_process_0("Warp Output:", allgather(warp_result))

In this example, `shard_map` is used to distribute the computation across available devices. The input array `a_in` is sharded along the 'x' axis, and each device processes its local shard. The Warp kernel `multiply_by_two_kernel` is applied to each shard, and the results are combined to form the final output.

This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously.
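
As a quick sanity check, the gathered Warp output can be compared against the same computation performed directly in JAX. This is a minimal sketch that reuses the variables defined in the example above:

.. code-block:: python

    # Reference result computed directly in JAX (the kernel just multiplies by two).
    expected = sharded_array * 2.0

    # Gather both results and compare them on process 0.
    matches = jnp.allclose(allgather(warp_result), allgather(expected), atol=1e-6)
    print_on_process_0(f"Warp output matches JAX reference: {matches}")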

To run this program on multiple GPUs, you must have Open MPI installed; see the `Open MPI installation guide <https://docs.open-mpi.org/en/v5.0.x/installing-open-mpi/quickstart.html>`_ for instructions. Once Open MPI is installed, launch the program with `mpirun`:

.. code-block:: bash

    mpirun -np <NUM_OF_GPUS> python <filename>.py
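
For example, on a single node with four GPUs, assuming the script above is saved as `warp_shardmap_example.py` (a hypothetical filename):

.. code-block:: bash

    mpirun -np 4 python warp_shardmap_example.py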


Specifying launch dimensions for matrix operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In some cases, particularly for matrix operations, it's necessary to specify the launch dimensions for Warp kernels explicitly. The default behavior of inferring the launch dimensions from the shape of the first argument is not always appropriate: in a matrix multiplication, for example, the output shape (M, N) differs from the shape of the first input (M, K). Here's an example of a distributed matrix multiplication using Warp and JAX:

.. code-block:: python

    import warp as wp
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from jax.experimental.multihost_utils import process_allgather as allgather
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel
    import numpy as np

    jax.distributed.initialize()
    num_gpus = jax.device_count()

    def print_on_process_0(*args, **kwargs):
        if jax.process_index() == 0:
            print(*args, **kwargs)

    print_on_process_0(f"Running on {num_gpus} GPU(s)")

    @wp.kernel
    def matmul_kernel(
        a: wp.array2d(dtype=wp.float32),
        b: wp.array2d(dtype=wp.float32),
        c: wp.array2d(dtype=wp.float32),
    ):
        # a: (M/num_gpus, K), b: (K, N), c: (M/num_gpus, N)
        i, j = wp.tid()
        M = a.shape[0]  # M/num_gpus
        K = a.shape[1]  # K
        N = b.shape[1]  # N
        if i < M and j < N:
            s = wp.float32(0.0)
            for k in range(K):
                s += a[i, k] * b[k, j]
            c[i, j] = s

    # Specify launch dimensions based on the number of GPUs
    def create_jax_warp_matmul(M, N):
        # M: total rows, N: total columns
        block_size_m = M // num_gpus  # Rows per GPU
        block_size_n = N  # All columns
        return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n))

    def warp_distributed_matmul(a, b):
        # a: (M, K) sharded across GPUs, b: (K, N) replicated
        M, K = a.shape
        _, N = b.shape
        jax_warp_matmul = create_jax_warp_matmul(M, N)

        def _sharded_operator(a_shard, b):
            # a_shard: (M/num_gpus, K), b: (K, N)
            return jax_warp_matmul(a_shard, b)[0]  # Result: (M/num_gpus, N)

        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
            in_specs=(P("x", None), P(None, None)),  # a sharded in first dim, b replicated
            out_specs=P("x", None),  # Output sharded in first dim
            check_rep=False,
        )(a, b)

    print_on_process_0("Test distributed matrix multiplication using JAX + Warp")

    # Define matrix dimensions
    M = 8 * num_gpus  # Scale M with the number of devices
    K, N = 4, 6

    # Create input matrices
    a = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K)  # Shape: (M, K)
    b = jnp.arange(K * N, dtype=jnp.float32).reshape(K, N)  # Shape: (K, N)

    devices = jax.devices()
    mesh = jax.sharding.Mesh(np.array(devices), "x")
    sharding_spec_a = jax.sharding.NamedSharding(mesh, P("x", None))
    sharding_spec_b = jax.sharding.NamedSharding(mesh, P(None, None))

    # Shard matrix A and replicate matrix B
    sharded_a = jax.device_put(a, sharding_spec_a)  # Sharded shape: (M/num_gpus, K) per device
    replicated_b = jax.device_put(b, sharding_spec_b)  # Replicated shape: (K, N) on all devices

    print_on_process_0(f"Input matrix A:\n{allgather(sharded_a)}")  # Shape: (M, K)
    print_on_process_0(f"Input matrix B:\n{allgather(replicated_b)}")  # Shape: (K, N)

    warp_result = warp_distributed_matmul(sharded_a, replicated_b)  # Sharded result: (M/num_gpus, N) per device
    print_on_process_0("Warp Output:")
    # Use allgather to collect results from all devices
    print_on_process_0(allgather(warp_result))  # Shape: (M, N)

    jax_result = jnp.matmul(a, b)  # Shape: (M, N)
    print_on_process_0("JAX Output:")
    print_on_process_0(jax_result)

    expected_shape = (M, N)
    print_on_process_0(f"Expected shape: {expected_shape}")
    print_on_process_0(f"Warp output shape: {warp_result.shape}")  # Should be (M/num_gpus, N) on each device
    print_on_process_0(f"JAX output shape: {jax_result.shape}")  # Should be (M, N)

    allclose = jnp.allclose(allgather(warp_result), jax_result, atol=1e-5)
    print_on_process_0(f"Allclose: {allclose}")

In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the global number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The `N` dimension (columns) remains unchanged as we're not sharding in that direction.

Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. The `block_size_m` is calculated by dividing the total number of rows by the number of GPUs, while `block_size_n` is set to the full width of the output matrix.
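
For instance, with four GPUs and a hypothetical global output of shape (32, 6), each device launches an 8 x 6 grid of threads. A minimal sketch of this calculation, reusing `matmul_kernel` and `jax_kernel` from the example above:

.. code-block:: python

    # Hypothetical sizes for illustration: 4 GPUs and a global (32, 6) output.
    num_gpus = 4
    M, N = 32, 6

    # Each device computes an (M // num_gpus, N) block of the output.
    launch_dims = (M // num_gpus, N)  # (8, 6) launch grid per device
    jax_warp_matmul = jax_kernel(matmul_kernel, launch_dims=launch_dims)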

Note that this is a naive matrix multiplication, written for illustration only; many optimizations are possible to improve its performance.

.. _DLPack:

DLPack
43 changes: 28 additions & 15 deletions warp/jax_experimental.py
@@ -21,17 +21,22 @@
_registered_kernel_to_id = {}


def jax_kernel(wp_kernel):
def jax_kernel(wp_kernel, launch_dims=None):
"""Create a Jax primitive from a Warp kernel.

NOTE: This is an experimental feature under development.

Args:
wp_kernel: The Warp kernel to be wrapped.
launch_dims: Optional. Specify the kernel launch dimensions. If None,
dimensions are inferred from the shape of the first argument.
            When set, this option also determines the output dimensions.

    Current limitations:
      - All kernel arguments must be arrays.
      - Kernel launch dimensions are inferred from the shape of the first argument.
      - If launch_dims is not provided, kernel launch dimensions are inferred from the shape of the first argument.
      - Input arguments are followed by output arguments in the Warp kernel definition.
      - There must be at least one input argument and at least one output argument.
      - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
      - All arrays must be contiguous.
      - Only the CUDA backend is supported.
    """
@@ -47,7 +52,7 @@ def jax_kernel(wp_kernel):
    id = _registered_kernel_to_id[wp_kernel]

    def bind(*args):
        return _jax_warp_p.bind(*args, kernel=id)
        return _jax_warp_p.bind(*args, kernel=id, launch_dims=launch_dims)

    return bind

@@ -106,7 +111,7 @@ def _get_jax_device():
    device = jax.config.jax_default_device
    # if default device is not set, use first device
    if device is None:
        device = jax.devices()[0]
        device = jax.local_devices()[0]
    return device


@@ -223,12 +228,17 @@ def base_type_is_compatible(warp_type, jax_ir_type):
        raise TypeError(f"Invalid or unsupported data type: {jax_ir_type}")

# Abstract evaluation.
def jax_warp_abstract(*args, kernel=None):
def jax_warp_abstract(*args, kernel=None, launch_dims=None):
    wp_kernel = _registered_kernels[kernel]
    # All the extra arguments to the warp kernel are outputs.
    warp_outputs = [o.type for o in wp_kernel.adj.args[len(args) :]]
    # TODO. Let's just use the first input dimension to infer the output's dimensions.
    dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape))

    if launch_dims is None:
        # Use the first input dimension to infer the output's dimensions if launch_dims is not provided
        dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape))
    else:
        dims = launch_dims

    jax_outputs = []
    for o in warp_outputs:
        shape = list(dims) + list(get_vecmat_shape(o))
@@ -260,7 +270,7 @@ def jax_warp_abstract(*args, kernel=None):
def default_layout(shape):
    return range(len(shape) - 1, -1, -1)

def warp_call_lowering(ctx, *args, kernel=None):
def warp_call_lowering(ctx, *args, kernel=None, launch_dims=None):
    if not kernel:
        raise Exception("Unknown kernel id " + str(kernel))
    wp_kernel = _registered_kernels[kernel]
@@ -272,12 +282,15 @@ def warp_call_lowering(ctx, *args, kernel=None):
    if not module.load(device):
        raise Exception("Could not load kernel on device")

    # Infer dimensions from the first input.
    warp_arg0 = wp_kernel.adj.args[0]
    actual_shape0 = ir.RankedTensorType(args[0].type).shape
    dims = strip_vecmat_dimensions(warp_arg0, actual_shape0)
    warp_dims = collapse_into_leading_dimension(warp_arg0, dims)

    if launch_dims is None:
        # Infer dimensions from the first input.
        warp_arg0 = wp_kernel.adj.args[0]
        actual_shape0 = ir.RankedTensorType(args[0].type).shape
        dims = strip_vecmat_dimensions(warp_arg0, actual_shape0)
        warp_dims = collapse_into_leading_dimension(warp_arg0, dims)
    else:
        dims = launch_dims
        warp_dims = launch_dims
    # Figure out the types and shapes of the input arrays.
    arg_strings = []
    operand_layouts = []