Skip to content

Commit

Permalink
[pallas_mgpu] Add a test for emit_pipeline with wgmma.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723012611
  • Loading branch information
petebu authored and Google-ML-Automation committed Feb 4, 2025
1 parent 124e123 commit c7d535d
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,74 @@ def kernel_body(x_smem, o_smem):
)
np.testing.assert_array_equal(kernel_fn(x), x + 1.0)

def test_emit_pipeline_with_wgmma(self):
self.skip_unless_sm90a()

m, n, k = 256, 256, 256
dtype = jnp.float16
key = jax.random.key(42)
x = jax.random.uniform(key, shape=(m, k), dtype=dtype)
y = jax.random.uniform(key, shape=(k, n), dtype=dtype)

swizzle = 128
swizzle_elems = swizzle // jnp.dtype(x.dtype).itemsize

tile_m = 64
tile_n = 64
tile_k = swizzle_elems

grid_m = m // tile_m
grid_n = n // tile_n
grid_k = k // tile_k

def kernel(a_gmem, b_gmem, c_smem, acc_reg):
def pipeline_body(a_smem, b_smem):
plgpu.wgmma(acc_reg, a_smem, b_smem)

plgpu.emit_pipeline(
pipeline_body,
in_specs=[
plgpu.GPUBlockSpec(
(tile_m, tile_k),
lambda i: (0, i),
transforms=(
plgpu.TilingTransform((64, swizzle_elems)),
plgpu.SwizzleTransform(swizzle),
),
),
plgpu.GPUBlockSpec(
(tile_k, tile_n),
lambda i: (i, 0),
transforms=(
plgpu.TilingTransform((swizzle_elems, swizzle_elems)),
plgpu.SwizzleTransform(swizzle),
),
),
],
grid=(grid_k,),
max_concurrent_steps=2,
delay_release=1,
)(a_gmem, b_gmem)

c_smem[...] = acc_reg[...]

@jax.jit
def matmul(a: jax.Array, b: jax.Array) -> jax.Array:
return pl.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=plgpu.GMEM),
pl.BlockSpec(memory_space=plgpu.GMEM),
],
out_specs=pl.BlockSpec((tile_m, tile_n), lambda m, n: (m, n)),
grid=(grid_m, grid_n),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
scratch_shapes=[plgpu.ACC((tile_m, tile_n), dtype)],
)(a, b)

res = matmul(x, y)
np.testing.assert_allclose(res, x @ y, rtol=0.4)


class WarpSpecializedPipelineTest(PallasTest):

Expand Down

0 comments on commit c7d535d

Please sign in to comment.