diff --git a/torchbenchmark/operators/sum/kernels.py b/torchbenchmark/operators/sum/kernels.py index 7b5a945e88..2eb9a15b1b 100644 --- a/torchbenchmark/operators/sum/kernels.py +++ b/torchbenchmark/operators/sum/kernels.py @@ -14,18 +14,26 @@ def triton_sum_kernel_scalar( block_start = pid * BLOCK_SIZE_M # offsets have shape equal to input shape - offsets = block_start + tl.arange(0, BLOCK_SIZE_M) # create 1D vector (input shape) ranging from beginning to end of this program's block + offsets = block_start + tl.arange( + 0, BLOCK_SIZE_M + ) # create 1D vector (input shape) ranging from beginning to end of this program's block # mask has shape equal to input shape mask = offsets < M # mask out offsets that are out of bounds for input # loaded pointers have shape equal to input shape - x = tl.load(input_ptr + offsets, mask=mask, other=mask) # load input, where the loaded pointers are in the desired input shape + x = tl.load( + input_ptr + offsets, mask=mask, other=mask + ) # load input, where the loaded pointers are in the desired input shape output = tl.sum(x) # output_offsets have shape equal to output shape - output_offsets = tl.arange(0, 1) # create offsets for scalar output pointer (output shape == (1,)) + output_offsets = tl.arange( + 0, 1 + ) # create offsets for scalar output pointer (output shape == (1,)) # stored pointers have shape equal to output shape - tl.store(output_ptr + output_offsets, output) # store output, where the stored pointers are in the desired output shape + tl.store( + output_ptr + output_offsets, output + ) # store output, where the stored pointers are in the desired output shape diff --git a/torchbenchmark/operators/sum/operator.py b/torchbenchmark/operators/sum/operator.py index 475e7b4b3a..7707453142 100644 --- a/torchbenchmark/operators/sum/operator.py +++ b/torchbenchmark/operators/sum/operator.py @@ -14,12 +14,28 @@ from .kernels import triton_sum_kernel_scalar +def parse_op_args(args: List[str]): + parser = argparse.ArgumentParser() + parser.add_argument( + "--reduce-dim", + type=int, + nargs="*", + default=None, + help="[Optional] Dimension(s) on which kernel performs reduction; e.g. --reduce-dim 0, --reduce-dim 0 1", + ) + return parser.parse_args(args) + + class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "accuracy"] - def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None): + def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): super().__init__(mode=mode, device=device, extra_args=extra_args) + args = parse_op_args(self.extra_args) + self.reduce_dim = ( + args.reduce_dim if args.reduce_dim else None + ) # for 2D case, guaranteed to be a list with 1 integer self.sizes = range(1, 17) @register_benchmark() @@ -27,13 +43,18 @@ def triton_sum(self, x: torch.Tensor): x_1d = x.view(-1) M = x_1d.shape[0] grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),) - BLOCK_SIZE_M = triton.next_power_of_2(M) # race condition in cases where BLOCK_SIZE < n_elements^2 + BLOCK_SIZE_M = triton.next_power_of_2( + M + ) # race condition in cases where BLOCK_SIZE < n_elements^2 def _inner(): output = torch.zeros(1, device=x.device, dtype=x.dtype) triton_sum_kernel_scalar[grid]( - x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M, + x_1d, + output, + M=M, + BLOCK_SIZE_M=BLOCK_SIZE_M, ) return output @@ -53,24 +74,28 @@ def get_x_vals(self) -> List[int]: x_vals.extend([2**n for n in self.sizes]) x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0]) - + return x_vals def get_input_iter(self) -> Generator: # reduce to a scalar value for size in self.get_x_vals(): # 1D matrix input_1d = torch.randn(size, device=self.device, dtype=self.dtype) - yield (input_1d, ) + yield (input_1d,) for size in self.get_x_vals(): # 2D matrix if size < pow(2, 8): # ensure we don't exceed floating point limitations - input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype) - yield (input_2d, ) + input_2d = torch.randn( + (size, size), device=self.device, dtype=self.dtype + ) + yield (input_2d,) for size in self.get_x_vals(): # 3D matrix if size < pow(2, 4): # ensure we don't exceed floating point limitations - input_2d = torch.randn((size, size, size), device=self.device, dtype=self.dtype) - yield (input_2d, ) + input_2d = torch.randn( + (size, size, size), device=self.device, dtype=self.dtype + ) + yield (input_2d,) def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool: output = fn() @@ -78,7 +103,9 @@ def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool: return torch.allclose(output, baseline_output, atol=1e-4) @register_metric(skip_baseline=True) - def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics): + def input_dims( + self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics + ): return [ex.dim() for ex in example_inputs] @register_metric()