From fd4900ceeee805ab92f9ed228b248df3f60e2d97 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Tue, 28 May 2024 10:51:48 -0700 Subject: [PATCH] [Pallas] Set a better tiling for gmm (#7119) Summary: Set a better tiling to make gmm run faster according to http://shortn/_e9u7jnMlYK. Test Plan: python test/test_gmm.py --- torch_xla/experimental/custom_kernel.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 73f889988416..2f528bafae5f 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -709,8 +709,12 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor, return res -def gmm(lhs: torch.Tensor, rhs: torch.Tensor, - group_sizes: torch.Tensor) -> torch.Tensor: +def gmm( + lhs: torch.Tensor, + rhs: torch.Tensor, + group_sizes: torch.Tensor, + tiling: tuple[int, int, int] = (512, 512, 512) +) -> torch.Tensor: """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'. Args: @@ -718,6 +722,7 @@ def gmm(lhs: torch.Tensor, rhs: torch.Tensor, rhs: A 3d, jnp.ndarray with shape [num_groups, k, n]. group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. preferred_element_type: jnp.dtype, the element type for the output matrix. + tiling: 3-tuple of ints. The m, k and n-dimension tile sizes. Returns: A 2d, jnp.ndarray with shape [m, n]. @@ -727,17 +732,24 @@ def gmm(lhs: torch.Tensor, rhs: torch.Tensor, jax_import_guard() from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm - payload, _ = trace_pallas(gmm, lhs, rhs, group_sizes) + m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[2] + tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n) + payload, _ = trace_pallas( + gmm, + lhs, + rhs, + group_sizes, + static_argnames=["tiling"], + tiling=(tm, tk, tn)) - m, n = lhs.shape[0], rhs.shape[2] # Create the metadata we need for computation. # TODO (alanwaketan): The following assuumes groups_sizes is a cpu tensor. # That means we need to materialize this input in order to use this gmm # kernel, and that will introduce graph breaks in the computation. group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata( group_sizes=group_sizes, - m=lhs.shape[0], - tm=128 # TODO (alanwaketan): Tune this later. + m=m, + tm=tm, ) group_offset_torch = torch.tensor([0], dtype=torch.int32).to("xla")