Skip to content

Commit

Permalink
Add Triton Gemm into TorchBench
Browse files Browse the repository at this point in the history
Summary: As the title

Reviewed By: xuzhao9

Differential Revision: D55346077

fbshipit-source-id: ae863ac020a56bfc0486a0624a63c462d999ce76
  • Loading branch information
sijiac authored and facebook-github-bot committed Apr 2, 2024
1 parent da2a8ba commit fbb768d
Show file tree
Hide file tree
Showing 4 changed files with 341 additions and 11 deletions.
179 changes: 179 additions & 0 deletions torchbenchmark/operators/gemm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import argparse
import os
import statistics
from typing import Callable, Generator, List, Optional

import numpy
import torch
import triton

from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)

from .triton_matmul import matmul as triton_matmul

BUILDIN_SHAPES = [
(256, 256, 256),
(384, 384, 384),
(512, 512, 512),
(640, 640, 640),
(768, 768, 768),
(896, 896, 896),
(1024, 1024, 1024),
(1152, 1152, 1152),
(1280, 1280, 1280),
(1408, 1408, 1408),
(1536, 1536, 1536),
(1664, 1664, 1664),
(1792, 1792, 1792),
(1920, 1920, 1920),
(2048, 2048, 2048),
(2176, 2176, 2176),
(2304, 2304, 2304),
(2432, 2432, 2432),
(2560, 2560, 2560),
(2688, 2688, 2688),
(2816, 2816, 2816),
(2944, 2944, 2944),
(3072, 3072, 3072),
(3200, 3200, 3200),
(3328, 3328, 3328),
(3456, 3456, 3456),
(3584, 3584, 3584),
(3712, 3712, 3712),
(3840, 3840, 3840),
(3968, 3968, 3968),
(4096, 4096, 4096),
]


def parse_args(args: List[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="TorchBench Gemm operator Benchmark")
parser.add_argument("--m", default=8, type=int)
parser.add_argument("--k", default=8, type=int)
parser.add_argument("--n", default=8, type=int)
args = parser.parse_args(args)
return args


class Operator(BenchmarkOperator):
USE_BUILTIN_SHAPES = True

def __init__(self, test: str, device: str, extra_args: List[str] = []):
if not extra_args:
self.USE_BUILTIN_SHAPES = True
self.DEFAULT_NUM_BATCH = len(BUILDIN_SHAPES)
self.extra_builtin_metrics = ["speedup", "accuracy"]
else:
self.USE_BUILTIN_SHAPES = False
self.DEFAULT_NUM_BATCH = 1
self.tbargs = parse_args(self.extra_args)
super().__init__(test=test, device=device, extra_args=extra_args)
self.required_metrics = list(
set(self.required_metrics + self.extra_builtin_metrics)
)

@register_benchmark()
def triton_matmul(self, a, b):
return triton_matmul(a, b), a

@register_benchmark(baseline=True)
def aten_matmul(self, a, b):
return torch.matmul(a, b), a

def get_x_val(self, example_inputs) -> float:
# x-value: computation intensity
a, w = example_inputs
m, k = a.size()
k, n = w.size()
# computation intensity for the shape m, n, k
intensity = 1 / (1 / n + 1 / m + 1 / k)
return intensity

@register_metric()
def gbps(self, example_inputs, metrics: BenchmarkOperatorMetrics) -> float:
a, w = example_inputs
numel = a.numel() + w.numel() + (torch.mm(a, w).numel())
numel = numel * a.element_size() / 1e9
gbps = list(map(lambda x: numel / x * 1e3, metrics.latency))
return statistics.median(gbps)

@register_metric(skip_baseline=True)
def xShape(self, example_inputs, metrics: BenchmarkOperatorMetrics) -> list[int]:
a, w = example_inputs
m, k = a.size()
k, n = w.size()
return [m, k, n]

@register_metric()
def _tflops(self, example_inputs, metrics: BenchmarkOperatorMetrics) -> float:
a, w = example_inputs
m, k = a.size()
k, n = w.size()
flops = m * k * 2 * n
latency = numpy.median(metrics.latency)
return flops / latency / 1e12 * 1e3

def get_input_iter(self) -> Generator:
if self.USE_BUILTIN_SHAPES:
for shape in BUILDIN_SHAPES:
m, k, n = shape
a = torch.randn(
(m, k), device=self.device, dtype=torch.float16
).requires_grad_(False)
w = torch.randn(
(k, n), device=self.device, dtype=torch.float16
).requires_grad_(False)
yield a, w
while True:
yield None
else:
meta_tensor = torch.randn((self.tbargs.m, self.tbargs.k), device="meta")
yield torch.randn_like(meta_tensor, device=self.device).requires_grad(False)

def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
output, loss = fn()
baseline_output, baseline_loss = baseline_fn()
accuracy = True
try:
torch.testing.assert_close(output, baseline_output, rol=1e-5)
# if not (loss == None and baseline_loss == None):
# torch.testing.assert_close(loss.grad, baseline_loss.grad)
except AssertionError:
# either the output tensor or the loss grad tensor does not match
accuracy = False
finally:
return accuracy

def plot(self):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["density"], # argument names to use as an x-axis for the plot
x_vals=self.output.x_vals, # different possible values for `x_name`
line_arg="provider", # argument name whose value corresponds to a different line in the plot
line_vals=[
"triton_matmul",
], # possible values for `line_arg``
line_names=[
"Triton GEMM",
], # label name for the lines
styles=[("blue", "-"), ("green", "-")], # line styles
ylabel="speedup", # label name for the y-axis
plot_name="gemm-performance", # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
)
)
def _plot(density, provider):
speedup = self.output.get_y_vals(density, provider, "speedup")
return speedup

save_path = "/tmp/test_gemm"

if not os.path.exists(save_path):
os.mkdir(save_path)

_plot.run(show_plots=True, print_data=True, save_path="/tmp/test_gemm")
147 changes: 147 additions & 0 deletions torchbenchmark/operators/gemm/triton_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""
Triton Matrix Multiplication is from the Triton tutorial:
- https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py
"""

import torch

import triton
import triton.language as tl


# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACTIVATION: tl.constexpr #
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# You can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
if ACTIVATION == "leaky_relu":
accumulator = leaky_relu(accumulator)
c = accumulator.to(tl.float16)

# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
@triton.jit
def leaky_relu(x):
x = x + 1
return tl.where(x >= 0, x, 0.01 * x)


# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.


def matmul(a, b, activation=""):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
assert b.is_contiguous(), "Matrix B must be contiguous"
M, K = a.shape
K, N = b.shape
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
ACTIVATION=activation #
)
return c
2 changes: 1 addition & 1 deletion torchbenchmark/operators/softmax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_x_val(self, example_inputs) -> float:
shape = example_inputs[0].size()
return float(shape[1])

@register_metric
@register_metric()
def gbps(self, example_inputs, metrics: BenchmarkOperatorMetrics) -> float:
gbps = lambda ms: 2 * example_inputs[0].nelement() * example_inputs[0].element_size() * 1e-9 / (ms * 1e-3)
return list(map(gbps, metrics.latency))
Expand Down
24 changes: 14 additions & 10 deletions torchbenchmark/util/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_y_vals(self, x_val, provider, metric_name: str):

def __str__(self):
headers, table = self._table()
table = tabulate.tabulate(table, headers=headers)
table = tabulate.tabulate(table, headers=headers, stralign="right")
return table

def register_benchmark(baseline: bool=False, enabled: bool=True, preprocess: Optional[Callable]=None):
Expand All @@ -127,14 +127,18 @@ def _inner(self, *args, **kwargs):
return _inner
return decorator

def register_metric(func):
operator_name = func.__module__.split(".")[-1]
if not operator_name in REGISTERED_METRICS:
REGISTERED_METRICS[operator_name] = []
REGISTERED_METRICS[operator_name].append(func.__name__)
def _inner(self, *args, **kwargs):
return func(self, *args, **kwargs)
return _inner
def register_metric(skip_baseline: bool=False):
def decorator(func):
operator_name = func.__module__.split(".")[-1]
if not operator_name in REGISTERED_METRICS:
REGISTERED_METRICS[operator_name] = []
REGISTERED_METRICS[operator_name].append(func.__name__)
if skip_baseline:
BASELINE_SKIP_METRICS.append(func.__name__)
def _inner(self, *args, **kwargs):
return func(self, *args, **kwargs)
return _inner
return decorator

def parse_args(op_name: str, args: List[str]) -> Tuple[argparse.Namespace, List[str]]:
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -338,7 +342,7 @@ def _do_bench(self,
)
return metric

@register_metric
@register_metric()
def tflops(self, latency: List[float], func: Optional[Callable]=None) -> List[float]:
def _get_flops(self, func: Callable) -> float:
"""By default, use the torch.__dispatch__ based flops counter."""
Expand Down

0 comments on commit fbb768d

Please sign in to comment.