[Pallas] Set a better tiling for gmm (#7119)
Summary:
Set a better tiling to make gmm run faster, according to http://shortn/_e9u7jnMlYK.

Test Plan:
python test/test_gmm.py
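
For context, a minimal sketch of what the new call site looks like. The shapes, dtypes, and device setup below are illustrative assumptions, not part of the commit:

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.custom_kernel import gmm

    device = xm.xla_device()
    lhs = torch.randn(2048, 1024, dtype=torch.bfloat16).to(device)       # [m, k]
    rhs = torch.randn(4, 1024, 4096, dtype=torch.bfloat16).to(device)    # [num_groups, k, n]
    group_sizes = torch.tensor([512, 512, 512, 512], dtype=torch.int32)  # cpu tensor, sums to m

    out = gmm(lhs, rhs, group_sizes)                          # new default tiling (512, 512, 512)
    out = gmm(lhs, rhs, group_sizes, tiling=(256, 512, 512))  # caller-chosen m/k/n tile sizes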
alanwaketan authored May 28, 2024
1 parent 65b5ace commit fd4900c
Showing 1 changed file with 18 additions and 6 deletions.
torch_xla/experimental/custom_kernel.py (18 additions & 6 deletions)
@@ -709,15 +709,20 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
   return res
 
 
-def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
-        group_sizes: torch.Tensor) -> torch.Tensor:
+def gmm(
+    lhs: torch.Tensor,
+    rhs: torch.Tensor,
+    group_sizes: torch.Tensor,
+    tiling: tuple[int, int, int] = (512, 512, 512)
+) -> torch.Tensor:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
 
   Args:
     lhs: A 2d, jnp.ndarray with shape [m, k].
     rhs: A 3d, jnp.ndarray with shape [num_groups, k, n].
     group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
     preferred_element_type: jnp.dtype, the element type for the output matrix.
+    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
 
   Returns:
     A 2d, jnp.ndarray with shape [m, n].
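
For readers new to grouped matmul, the docstring's contract can be written out in plain PyTorch (an illustrative sketch, not part of this commit): rows of lhs are split into contiguous groups by group_sizes, and group i is multiplied by rhs[i]:

    import torch

    def gmm_reference(lhs, rhs, group_sizes):
      # Split lhs into contiguous row groups; multiply group i by rhs[i].
      out, start = [], 0
      for i, size in enumerate(group_sizes.tolist()):
        out.append(lhs[start:start + size] @ rhs[i])
        start += size
      return torch.cat(out, dim=0)  # [m, n]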
@@ -727,17 +732,24 @@ def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
   jax_import_guard()
   from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
 
-  payload, _ = trace_pallas(gmm, lhs, rhs, group_sizes)
+  m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[2]
+  tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
+  payload, _ = trace_pallas(
+      gmm,
+      lhs,
+      rhs,
+      group_sizes,
+      static_argnames=["tiling"],
+      tiling=(tm, tk, tn))
 
-  m, n = lhs.shape[0], rhs.shape[2]
   # Create the metadata we need for computation.
   # TODO (alanwaketan): The following assumes group_sizes is a cpu tensor.
   # That means we need to materialize this input in order to use this gmm
   # kernel, and that will introduce graph breaks in the computation.
   group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata(
       group_sizes=group_sizes,
-      m=lhs.shape[0],
-      tm=128  # TODO (alanwaketan): Tune this later.
+      m=m,
+      tm=tm,
   )
   group_offset_torch = torch.tensor([0], dtype=torch.int32).to("xla")

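Two details in the hunk above are worth calling out: tiling is forwarded through static_argnames, so the Pallas kernel is specialized on the tile sizes at trace time rather than treating them as traced values, and each requested tile is clamped with min() so it never exceeds the dimension it tiles. A standalone sketch of the clamping (values are illustrative):

    m, k, n = 128, 1024, 256  # a problem smaller than the default tiles
    tiling = (512, 512, 512)  # the new default
    tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
    print(tm, tk, tn)         # -> 128 512 256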