From 9f2ab7498aa36bd0d2fb6375e40fde7078ab9dbc Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Fri, 7 Jun 2024 10:31:36 -0700 Subject: [PATCH] Add command line argument parsing for reduction dimensions in Triton sum kernel (#2284) Summary: Pull Request resolved: https://github.com/pytorch/benchmark/pull/2284 Add argument parsing for the command line in order to pass in dimension(s) across which the kernel reduces and enable more rigorous testing of different versions of the sum kernel, referencing [torchbenchmark/operators/fb/flash_attention/operator.py](https://www.internalfb.com/code/fbsource/[864a578ce44afdba619d50a352c8ca3b783e05ef]/fbcode/pytorch/benchmark/torchbenchmark/operators/fb/flash_attention/operator.py?lines=84). Inherit the `__init__` function from the parent class `BenchmarkOperator` in order to facilitate command line argument parsing. Change `dim` type to `list` to avoid type issues resulting from `tl.constexpr`. Modify equality checks in kernel and operator to satisfy type requirements for `dim`. Reviewed By: xuzhao9 Differential Revision: D58212366 fbshipit-source-id: 5c88a7c3e8bf2f37408c6c5e3d302b7e9a473bd4 --- torchbenchmark/operators/sum/kernels.py | 16 ++++++-- torchbenchmark/operators/sum/operator.py | 47 +++++++++++++++++++----- 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/torchbenchmark/operators/sum/kernels.py b/torchbenchmark/operators/sum/kernels.py index 7b5a945e88..2eb9a15b1b 100644 --- a/torchbenchmark/operators/sum/kernels.py +++ b/torchbenchmark/operators/sum/kernels.py @@ -14,18 +14,26 @@ def triton_sum_kernel_scalar( block_start = pid * BLOCK_SIZE_M # offsets have shape equal to input shape - offsets = block_start + tl.arange(0, BLOCK_SIZE_M) # create 1D vector (input shape) ranging from beginning to end of this program's block + offsets = block_start + tl.arange( + 0, BLOCK_SIZE_M + ) # create 1D vector (input shape) ranging from beginning to end of this program's block # mask has shape equal to input shape mask = offsets < M # mask out offsets that are out of bounds for input # loaded pointers have shape equal to input shape - x = tl.load(input_ptr + offsets, mask=mask, other=mask) # load input, where the loaded pointers are in the desired input shape + x = tl.load( + input_ptr + offsets, mask=mask, other=mask + ) # load input, where the loaded pointers are in the desired input shape output = tl.sum(x) # output_offsets have shape equal to output shape - output_offsets = tl.arange(0, 1) # create offsets for scalar output pointer (output shape == (1,)) + output_offsets = tl.arange( + 0, 1 + ) # create offsets for scalar output pointer (output shape == (1,)) # stored pointers have shape equal to output shape - tl.store(output_ptr + output_offsets, output) # store output, where the stored pointers are in the desired output shape + tl.store( + output_ptr + output_offsets, output + ) # store output, where the stored pointers are in the desired output shape diff --git a/torchbenchmark/operators/sum/operator.py b/torchbenchmark/operators/sum/operator.py index 475e7b4b3a..7707453142 100644 --- a/torchbenchmark/operators/sum/operator.py +++ b/torchbenchmark/operators/sum/operator.py @@ -14,12 +14,28 @@ from .kernels import triton_sum_kernel_scalar +def parse_op_args(args: List[str]): + parser = argparse.ArgumentParser() + parser.add_argument( + "--reduce-dim", + type=int, + nargs="*", + default=None, + help="[Optional] Dimension(s) on which kernel performs reduction; e.g. --reduce-dim 0, --reduce-dim 0 1", + ) + return parser.parse_args(args) + + class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "accuracy"] - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None): + def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): super().__init__(mode=mode, device=device, extra_args=extra_args) + args = parse_op_args(self.extra_args) + self.reduce_dim = ( + args.reduce_dim if args.reduce_dim else None + ) # for 2D case, guaranteed to be a list with 1 integer self.sizes = range(1, 17) @register_benchmark() @@ -27,13 +43,18 @@ def triton_sum(self, x: torch.Tensor): x_1d = x.view(-1) M = x_1d.shape[0] grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),) - BLOCK_SIZE_M = triton.next_power_of_2(M) # race condition in cases where BLOCK_SIZE < n_elements^2 + BLOCK_SIZE_M = triton.next_power_of_2( + M + ) # race condition in cases where BLOCK_SIZE < n_elements^2 def _inner(): output = torch.zeros(1, device=x.device, dtype=x.dtype) triton_sum_kernel_scalar[grid]( - x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M, + x_1d, + output, + M=M, + BLOCK_SIZE_M=BLOCK_SIZE_M, ) return output @@ -53,24 +74,28 @@ def get_x_vals(self) -> List[int]: x_vals.extend([2**n for n in self.sizes]) x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0]) - + return x_vals def get_input_iter(self) -> Generator: # reduce to a scalar value for size in self.get_x_vals(): # 1D matrix input_1d = torch.randn(size, device=self.device, dtype=self.dtype) - yield (input_1d, ) + yield (input_1d,) for size in self.get_x_vals(): # 2D matrix if size < pow(2, 8): # ensure we don't exceed floating point limitations - input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype) - yield (input_2d, ) + input_2d = torch.randn( + (size, size), device=self.device, dtype=self.dtype + ) + yield (input_2d,) for size in self.get_x_vals(): # 3D matrix if size < pow(2, 4): # ensure we don't exceed floating point limitations - input_2d = torch.randn((size, size, size), device=self.device, dtype=self.dtype) - yield (input_2d, ) + input_2d = torch.randn( + (size, size, size), device=self.device, dtype=self.dtype + ) + yield (input_2d,) def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool: output = fn() @@ -78,7 +103,9 @@ def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool: return torch.allclose(output, baseline_output, atol=1e-4) @register_metric(skip_baseline=True) - def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics): + def input_dims( + self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics + ): return [ex.dim() for ex in example_inputs] @register_metric()