[Pallas] Make gmm functional #7117

Merged
merged 10 commits on May 26, 2024

Changes from all commits
28 changes: 20 additions & 8 deletions test/test_gmm.py
@@ -6,6 +6,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram
from torch_xla import runtime as xr
from torch_xla._internal import tpu
@@ -98,6 +99,8 @@ def _init_test_cases(self):

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm(self):
met.clear_all()

self._init_test_cases()
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
@@ -110,20 +113,24 @@ def test_gmm(self):
lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(
m=m, num_groups=num_groups) # This is a cpu tensor!!!!!!!
m=m, num_groups=num_groups).to('xla')
out = gmm(lhs, rhs, group_sizes)

ref_out = self._reference_gmm(lhs.cpu().float().numpy(),
rhs.cpu().float().numpy(),
group_sizes.numpy())
group_sizes.cpu().numpy())

atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype)
np.testing.assert_allclose(
ref_out, np.array(out[0].cpu()), rtol=rtol, atol=atol)

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_make_group_metadata(self):
from jax.experimental.pallas.ops.tpu.megablox.gmm import make_group_metadata as jax_make_group_metadata
met.clear_all()

test_grids = [
{
@@ -173,15 +180,19 @@ def test_make_group_metadata(self):
)

torch_meta = _make_group_metadata(
group_sizes=torch.tensor(test_grid['group_sizes']),
group_sizes=torch.tensor(test_grid['group_sizes']).to("xla"),
m=test_grid['m'],
tm=test_grid['tm'],
)

for i in range(len(jax_meta)):
self.assertTrue(
torch.all(torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i]))
self.assertEqual(jax_num_tiles, torch_meta[-1].item())
torch.all(
torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i].cpu()))
self.assertEqual(jax_num_tiles, torch_meta[-1].cpu().item())

# Make sure _make_group_metadata doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())

def test_histogram(self):
test_grids = [
@@ -215,13 +226,13 @@ def test_histogram(self):
max=test_grid['max'],
)

chart, _ = _histogram(
chart = _histogram(
torch.tensor(test_grid['input'], dtype=torch.int32).to("xla"),
min=test_grid['min'],
max=test_grid['max'],
)

self.assertTrue(torch.all(torch_chart == chart.cpu()))

def test_histogram_raise(self):
with self.assertRaisesRegex(AssertionError,
@@ -232,7 +243,8 @@ def test_histogram_raise(self):
max=5,
)

with self.assertRaisesRegex(AssertionError, "min must be less than max."):
with self.assertRaisesRegex(AssertionError,
"min must be less than or equal to max."):
_histogram(
torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.int32),
min=4,
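The test changes above all follow the same pattern: move every input onto the XLA device, then use the debug metrics counters to assert that nothing silently fell back to CPU (a fallback shows up as an "aten::" counter). A minimal sketch of that pattern as a reusable helper; the helper name and the explicit mark_step call are illustrative, not part of this PR:

import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def run_without_fallback(fn, *args, **kwargs):
  """Run `fn` on XLA tensors and fail loudly if any op fell back to CPU."""
  met.clear_all()
  out = fn(*args, **kwargs)
  xm.mark_step()  # materialize the graph so the counters are populated
  report = met.short_metrics_report()
  assert "aten::" not in report, f"unexpected aten:: fallback:\n{report}"
  return out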
41 changes: 16 additions & 25 deletions torch_xla/experimental/custom_kernel.py
@@ -501,18 +501,18 @@ def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
Compute the histogram of an int32 tensor. The bin edges are defined by the min and max values, with step = 1.
"""
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
assert min < max, "min must be less than max."
assert min <= max, "min must be less than or equal to max."

def searchsorted(sorted_sequence: torch.Tensor,
values_to_search: torch.Tensor) -> torch.Tensor:
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)

bin_edges = torch.linspace(
min, max, max - min + 1, dtype=input.dtype).to(input.device)
return searchsorted(bin_edges, input), bin_edges
return searchsorted(bin_edges, input)
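With this change `_histogram` returns only the counts, and the counting is a broadcasted equality compare rather than `torch.histc`, so it lowers cleanly to XLA. A self-contained sketch of the same trick in plain PyTorch; the function name and sample values are only illustrative:

import torch


def count_per_bin(values: torch.Tensor, lo: int, hi: int) -> torch.Tensor:
  """Count occurrences of each integer in [lo, hi] with unit-width bins.

  Each bin edge is compared against every value via broadcasting and the
  matches are summed per bin, mirroring the searchsorted helper above.
  """
  edges = torch.arange(lo, hi + 1, dtype=values.dtype, device=values.device)
  return (edges.unsqueeze(1) == values).sum(dim=1)


# count_per_bin(torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.int32), 1, 4)
# -> tensor([2, 1, 1, 2])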


# This can only be ran in cpu now as repeat_interleave is not lowered to xla.
# Reference: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L78
def _make_group_metadata(
*,
group_sizes: torch.Tensor,
@@ -544,6 +544,7 @@ def _make_group_metadata(
num_tiles: The number of m-dimension tiles to execute including overlapping
executions. And don't confuse this with m_tiles which is m // tm.
"""
device = group_sizes.device
num_groups = group_sizes.shape[0]

# Calculate the offset of each group, starting at zero. This metadata is
@@ -555,7 +556,8 @@ def _make_group_metadata(
#
# The row at which group 'i' starts is group_offsets[i].
group_ends = torch.cumsum(group_sizes, dim=0, dtype=torch.int32)
group_offsets = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends])
group_offsets = torch.cat(
[torch.zeros(1, dtype=torch.int32).to(device), group_ends])
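For a concrete feel of this metadata, a small CPU-only illustration; the group sizes below are just example values:

import torch

group_sizes = torch.tensor([3, 0, 5], dtype=torch.int32)
group_ends = torch.cumsum(group_sizes, dim=0, dtype=torch.int32)            # tensor([3, 3, 8])
group_offsets = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends])  # tensor([0, 3, 3, 8])
# Group i occupies rows [group_offsets[i], group_offsets[i + 1]) of the
# m dimension; the empty group 1 starts and ends at row 3.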

# Assign a group id to each grid index.
#
@@ -571,7 +573,8 @@ def _make_group_metadata(
rounded_group_ends = ((group_ends + tm - 1) // tm * tm).to(torch.int32)

# (2) Round the group_starts down to the nearest multiple of 'tm'.
group_starts = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends[:-1]])
group_starts = torch.cat(
[torch.zeros(1, dtype=torch.int32).to(device), group_ends[:-1]])
rounded_group_starts = group_starts // tm * tm
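Continuing the illustrative numbers from above with a tile size of tm = 4, steps (1) and (2) widen each group's row range out to tile boundaries; this is a self-contained example with made-up values, not code from the PR:

import torch

tm = 4
group_ends = torch.tensor([3, 3, 8], dtype=torch.int32)      # cumsum of [3, 0, 5]
rounded_group_ends = (group_ends + tm - 1) // tm * tm        # tensor([4, 4, 8])
group_starts = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends[:-1]])
rounded_group_starts = group_starts // tm * tm               # tensor([0, 0, 0])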

# (3) Calculate the number of rows in each group.
@@ -613,14 +616,9 @@ def _make_group_metadata(
# group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
# such that we only execute the necessary number of tiles.
tiles_m = _calculate_num_tiles(m, tm)
# TODO (alanwaketan): lower jax's version of repeat. This dynamism will force us to compile many times.
group_ids = torch.repeat_interleave(
torch.arange(num_groups, dtype=torch.int32),
group_tiles,
)
group_ids = torch.nn.functional.pad(
group_ids, (0, tiles_m + num_groups - 1 - group_ids.shape[0]),
value=num_groups - 1)
group_ids = repeat_with_fixed_output_size(
torch.arange(num_groups, dtype=torch.int32).to(device), group_tiles,
tiles_m + num_groups - 1)
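repeat_with_fixed_output_size replaces the repeat_interleave + pad pair that used to run here, so the output length stays static and the graph does not recompile whenever group_tiles changes. The contract it has to match, judging from the code it replaces, is shown below with made-up sizes (illustrative only):

import torch

num_groups, tiles_m = 3, 5
group_tiles = torch.tensor([2, 1, 3])  # tiles touched by each group

# Shape-dynamic formulation being replaced:
group_ids = torch.repeat_interleave(
    torch.arange(num_groups, dtype=torch.int32), group_tiles)   # [0, 0, 1, 2, 2, 2]
group_ids = torch.nn.functional.pad(
    group_ids, (0, tiles_m + num_groups - 1 - group_ids.shape[0]),
    value=num_groups - 1)                                        # [0, 0, 1, 2, 2, 2, 2]
# repeat_with_fixed_output_size(torch.arange(num_groups, dtype=torch.int32),
#                               group_tiles, tiles_m + num_groups - 1)
# should produce the same ids with a fixed length of tiles_m + num_groups - 1.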

# Assign an m-dimension tile id to each grid index.
#
@@ -652,20 +650,13 @@ def _make_group_metadata(
partial_tile_ids = torch.where(partial_tile_mask, tiles_m,
group_offsets[:-1] // tm)

tile_visits = (
torch.histc(
partial_tile_ids.float(), bins=tiles_m, min=0, max=tiles_m - 1) + 1)
tile_visits = (_histogram(partial_tile_ids, min=0, max=tiles_m - 1) + 1)

# Create the m-dimension tile ids for each grid index based on the visit
# counts for each tile.
# TODO (alanwaketan): lower jax's version of repeat. This dynamism will force us to compile many times.
m_tile_ids = torch.repeat_interleave(
torch.arange(tiles_m, dtype=torch.int32),
tile_visits.type(torch.int32),
)
m_tile_ids = torch.nn.functional.pad(
m_tile_ids, (0, tiles_m + num_groups - 1 - m_tile_ids.shape[0]),
value=tiles_m - 1)
m_tile_ids = repeat_with_fixed_output_size(
torch.arange(tiles_m, dtype=torch.int32).to(device), tile_visits,
tiles_m + num_groups - 1)

num_tiles = group_tiles.sum(dtype=torch.int32)
return group_offsets, group_ids, m_tile_ids, num_tiles
@@ -706,7 +697,7 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
# tensor([2, 1, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
block_split_indicators = torch.zeros(
total_repeat_length, dtype=torch.int64, device=device)
block_split_indicators.scatter_add_(0, valid_indices,
block_split_indicators.scatter_add_(0, valid_indices.to(torch.int64),
torch.ones_like(block_split_indicators))
# out_of_bound indices also scatter to index 0, need to offset them
block_split_indicators[0] -= out_of_bound_count
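Only the start of repeat_with_fixed_output_size is visible here; the rest of the helper is collapsed in this view. Below is a self-contained sketch of the underlying scatter-and-cumsum idea, under the assumption that out-of-range blocks are dropped and trailing slots repeat the last value, as the padding code it replaces did. Function and variable names are hypothetical:

import torch


def repeat_to_fixed_size(values: torch.Tensor, repeats: torch.Tensor,
                         total_len: int) -> torch.Tensor:
  """Sketch of repeat_interleave with a static output length."""
  device = values.device
  # Start slot of each repeated block in the flattened output, e.g. [0, 2, 3].
  starts = torch.cumsum(repeats, dim=0) - repeats
  out_of_bound = starts >= total_len
  starts = torch.where(out_of_bound, torch.zeros_like(starts), starts)

  # +1 indicator at every block start; out-of-range blocks land on slot 0
  # and are cancelled out afterwards, as in the snippet above.
  indicators = torch.zeros(total_len, dtype=torch.int64, device=device)
  indicators.scatter_add_(
      0, starts.to(torch.int64),
      torch.ones(repeats.numel(), dtype=torch.int64, device=device))
  indicators[0] -= out_of_bound.sum()

  # Running count of block starts gives the block id owning each output slot.
  block_ids = torch.cumsum(indicators, dim=0) - 1
  return values[block_ids]


# repeat_to_fixed_size(torch.tensor([7, 8, 9]), torch.tensor([2, 1, 3]), 7)
# -> tensor([7, 7, 8, 9, 9, 9, 9])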