Add command line argument parsing for reduction dimensions in Triton sum kernel #2284

Closed · wants to merge 1 commit
16 changes: 12 additions & 4 deletions torchbenchmark/operators/sum/kernels.py
@@ -14,18 +14,26 @@ def triton_sum_kernel_scalar(

     block_start = pid * BLOCK_SIZE_M
     # offsets have shape equal to input shape
-    offsets = block_start + tl.arange(0, BLOCK_SIZE_M)  # create 1D vector (input shape) ranging from beginning to end of this program's block
+    offsets = block_start + tl.arange(
+        0, BLOCK_SIZE_M
+    )  # create 1D vector (input shape) ranging from beginning to end of this program's block

     # mask has shape equal to input shape
     mask = offsets < M  # mask out offsets that are out of bounds for input

     # loaded pointers have shape equal to input shape
-    x = tl.load(input_ptr + offsets, mask=mask, other=mask)  # load input, where the loaded pointers are in the desired input shape
+    x = tl.load(
+        input_ptr + offsets, mask=mask, other=mask
+    )  # load input, where the loaded pointers are in the desired input shape

     output = tl.sum(x)

     # output_offsets have shape equal to output shape
-    output_offsets = tl.arange(0, 1)  # create offsets for scalar output pointer (output shape == (1,))
+    output_offsets = tl.arange(
+        0, 1
+    )  # create offsets for scalar output pointer (output shape == (1,))

     # stored pointers have shape equal to output shape
-    tl.store(output_ptr + output_offsets, output)  # store output, where the stored pointers are in the desired output shape
+    tl.store(
+        output_ptr + output_offsets, output
+    )  # store output, where the stored pointers are in the desired output shape
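For orientation, here is a minimal host-side sketch of how this scalar-sum kernel is launched, mirroring the `triton_sum` benchmark in `operator.py` below. The import path is inferred from this diff's file path, and a CUDA device is assumed:

```python
import torch
import triton

from torchbenchmark.operators.sum.kernels import triton_sum_kernel_scalar

x_1d = torch.randn(1000, device="cuda")  # deliberately not a power of two
M = x_1d.shape[0]
output = torch.zeros(1, device=x_1d.device, dtype=x_1d.dtype)

# next_power_of_2(M) guarantees cdiv(M, BLOCK_SIZE_M) == 1, so a single
# program reduces the whole input and the non-atomic store to output is safe.
BLOCK_SIZE_M = triton.next_power_of_2(M)
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)

triton_sum_kernel_scalar[grid](x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M)
assert torch.allclose(output, x_1d.sum(), atol=1e-4)
```

One side note on the kernel body: `tl.load(..., mask=mask, other=mask)` fills out-of-bounds lanes with the mask's own value there, i.e. `False`, which is zero, so the padding does not perturb the sum; `other=0` would state that intent more directly.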
47 changes: 37 additions & 10 deletions torchbenchmark/operators/sum/operator.py
@@ -14,26 +14,47 @@
 from .kernels import triton_sum_kernel_scalar


+def parse_op_args(args: List[str]):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--reduce-dim",
+        type=int,
+        nargs="*",
+        default=None,
+        help="[Optional] Dimension(s) on which kernel performs reduction; e.g. --reduce-dim 0, --reduce-dim 0 1",
+    )
+    return parser.parse_args(args)
+
+
 class Operator(BenchmarkOperator):

     DEFAULT_METRICS = ["latency", "accuracy"]

-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
+    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
+        args = parse_op_args(self.extra_args)
+        self.reduce_dim = (
+            args.reduce_dim if args.reduce_dim else None
+        )  # for the 2D case, guaranteed to be a list with 1 integer
         self.sizes = range(1, 17)

     @register_benchmark()
     def triton_sum(self, x: torch.Tensor):
         x_1d = x.view(-1)
         M = x_1d.shape[0]
         grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
-        BLOCK_SIZE_M = triton.next_power_of_2(M)  # race condition in cases where BLOCK_SIZE < n_elements^2
+        BLOCK_SIZE_M = triton.next_power_of_2(
+            M
+        )  # BLOCK_SIZE_M must cover all M elements; with more than one program, the non-atomic stores to output would race

         def _inner():
             output = torch.zeros(1, device=x.device, dtype=x.dtype)

             triton_sum_kernel_scalar[grid](
-                x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M,
+                x_1d,
+                output,
+                M=M,
+                BLOCK_SIZE_M=BLOCK_SIZE_M,
             )

             return output
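Since the new argument parser is the heart of this PR, a quick sketch of its behavior in isolation may help review. The import path is assumed from this diff; in normal runs, `extra_args` reaches `parse_op_args` through the `BenchmarkOperator` constructor:

```python
from torchbenchmark.operators.sum.operator import parse_op_args

print(parse_op_args([]).reduce_dim)                          # None -> full reduction
print(parse_op_args(["--reduce-dim", "0"]).reduce_dim)       # [0]
print(parse_op_args(["--reduce-dim", "0", "1"]).reduce_dim)  # [0, 1]
```

Because `nargs="*"` also accepts a bare `--reduce-dim` with no values, `args.reduce_dim` can be an empty list; the `args.reduce_dim if args.reduce_dim else None` guard in `__init__` normalizes that case back to `None`.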
@@ -53,32 +74,38 @@ def get_x_vals(self) -> List[int]:

         x_vals.extend([2**n for n in self.sizes])
         x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0])

         return x_vals

     def get_input_iter(self) -> Generator:
         # reduce to a scalar value
         for size in self.get_x_vals():  # 1D matrix
             input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
-            yield (input_1d, )
+            yield (input_1d,)

         for size in self.get_x_vals():  # 2D matrix
             if size < pow(2, 8):  # ensure we don't exceed floating point limitations
-                input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype)
-                yield (input_2d, )
+                input_2d = torch.randn(
+                    (size, size), device=self.device, dtype=self.dtype
+                )
+                yield (input_2d,)

         for size in self.get_x_vals():  # 3D matrix
             if size < pow(2, 4):  # ensure we don't exceed floating point limitations
-                input_2d = torch.randn((size, size, size), device=self.device, dtype=self.dtype)
-                yield (input_2d, )
+                input_3d = torch.randn(
+                    (size, size, size), device=self.device, dtype=self.dtype
+                )
+                yield (input_3d,)

     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
         baseline_output = baseline_fn()
         return torch.allclose(output, baseline_output, atol=1e-4)

     @register_metric(skip_baseline=True)
-    def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
+    def input_dims(
+        self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
+    ):
         return [ex.dim() for ex in example_inputs]

     @register_metric()
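For reviewers, a small illustrative check of the size math in `get_x_vals` and the caps in `get_input_iter`; the values are computed from the code above, not from a benchmark run:

```python
sizes = range(1, 17)
x_vals = [2**n for n in sizes] + [(n - 1) * (n + 1) for n in sizes if n - 1 > 0]

# Powers of two make M == BLOCK_SIZE_M exactly; (n - 1) * (n + 1) == n**2 - 1
# is generally not a power of two, so those sizes exercise the kernel's
# out-of-bounds mask after the input is flattened.
largest_2d = max(s for s in x_vals if s < 2**8)  # 255 -> 255 * 255 = 65,025 elements
largest_3d = max(s for s in x_vals if s < 2**4)  # 15  -> 15**3    = 3,375 elements
print(largest_2d, largest_3d)
```

The `size < 2**8` and `size < 2**4` guards presumably keep the flattened element counts small enough that floating-point accumulation error stays manageable for the `atol=1e-4` comparison in `_get_accuracy`.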