Summary: Based on PT2 test case: pytorch/pytorch#121661

Reviewed By: bertmaher

Differential Revision: D56437647

fbshipit-source-id: 0d735c443c9cea419dce971fd6ea796ca444c0ff
commit 4c7ec3a (1 parent: 6f3faa0)

Showing 3 changed files with 190 additions and 0 deletions.
1 change: 1 addition & 0 deletions
torchbenchmark/operators/gather_gemv/__init__.py

@@ -0,0 +1 @@
from .operator import Operator
73 changes: 73 additions & 0 deletions
torchbenchmark/operators/gather_gemv/operator.py

@@ -0,0 +1,73 @@
""" | ||
Based on PT2 test case: https://github.com/pytorch/pytorch/issues/121661 | ||
Motivated by https://www.thonking.ai/p/short-supporting-mixtral-in-gpt-fast, | ||
gather + gemv is the primary kernel driving mixtral perf. | ||
""" | ||
|
||
import csv | ||
import os | ||
import statistics | ||
from typing import Any, Callable, Generator, List, Optional | ||
|
||
import numpy | ||
import torch | ||
import triton | ||
|
||
|
||
from torchbenchmark.util.triton_op import ( | ||
BenchmarkOperator, | ||
BenchmarkOperatorMetrics, | ||
register_benchmark, | ||
register_metric, | ||
) | ||
|
||
from .triton_gather_gemv import triton_gemv_0 as triton_test_0 | ||
from torch._dynamo.testing import rand_strided | ||
|
||
class Operator(BenchmarkOperator):
    @register_metric()
    def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
        arg0_1, arg1_1, arg2_1 = example_inputs
        # Traffic is dominated by the two gathered (S, S) int8 matrices, i.e.
        # roughly 2 * S * S * element_size bytes per call; dividing by the
        # latency in ms and scaling by 1e-6 yields GB/s.
        gbps = (
            lambda ms: 2
            * arg2_1.size(0)
            * arg2_1.size(0)
            * arg0_1.element_size()
            / ms
            * 1e-6
        )
        return list(map(gbps, metrics.latency))

    def __init__(self, mode: str, device: str, extra_args: List[str] = []):
        super().__init__(mode=mode, device=device, extra_args=extra_args)

    @register_benchmark(baseline=True)
    def test_0(self, p1, p2, p3) -> Callable:
        return lambda: triton_test_0(p1, p2, p3)

    @register_benchmark(baseline=True)
    def test_eager(self, w, idx, x):
        return lambda: w[idx].to(x.dtype) @ x

    @register_benchmark()
    def test_inductor(self, w, idx, x):
        @torch.compile
        def gather_gemv(w, idx, x):
            return w[idx].to(x.dtype) @ x

        gather_gemv(w, idx, x)  # warmup
        return lambda: gather_gemv(w, idx, x)

    def get_x_val(self, example_inputs) -> float:
        arg0_1, arg1_1, arg2_1 = example_inputs
        s = arg2_1.size(0)
        return s

    def get_input_iter(self) -> Generator:
        for i in range(11, 15):
            S = 2 ** i
            arg0_1 = rand_strided((8, S, S), (S*S, S, 1), device='cuda:0', dtype=torch.int8)
            arg1_1 = rand_strided((2, ), (1, ), device='cuda:0', dtype=torch.int64)
            arg2_1 = rand_strided((S, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
            yield arg0_1, arg1_1, arg2_1
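For orientation, the pattern the operator exercises is just an indexed matrix-vector product, and the gbps metric above is plain byte counting. A minimal sketch (assuming S = 2048, the smallest size produced by get_input_iter, and a made-up latency of 0.05 ms; the random weights and the two gather indices are arbitrary choices for illustration):

import torch

S = 2048
w = torch.randint(-8, 8, (8, S, S), device="cuda:0", dtype=torch.int8)  # 8 "expert" matrices
idx = torch.tensor([1, 5], device="cuda:0", dtype=torch.int64)          # gather two of them
x = torch.randn(S, device="cuda:0", dtype=torch.bfloat16)

out = w[idx].to(x.dtype) @ x  # gather + gemv, shape (2, S)

# gbps bookkeeping: traffic is dominated by the two gathered (S, S) int8 matrices
bytes_moved = 2 * S * S * w.element_size()  # 8,388,608 bytes
latency_ms = 0.05                           # hypothetical measurement
print(bytes_moved / latency_ms * 1e-6)      # ~167.8 GB/s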
116 changes: 116 additions & 0 deletions
torchbenchmark/operators/gather_gemv/triton_gather_gemv.py
@@ -0,0 +1,116 @@
"""
Based on https://github.com/pytorch/pytorch/issues/121661
"""

import torch

import triton
import triton.language as tl

empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride


@triton.autotune(
    configs=[
        triton.Config(
            {
                "XBLOCK": 1,
                "RBLOCK": 2048,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "XBLOCK": 64,
                "RBLOCK": 8,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "XBLOCK": 64,
                "RBLOCK": 4,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "XBLOCK": 8,
                "RBLOCK": 512,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "XBLOCK": 8,
                "RBLOCK": 256,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "XBLOCK": 64,
                "RBLOCK": 64,
            },
            num_stages=1,
            num_warps=8,
        ),
    ],
    key=["xnumel", "rnumel"],
)
@triton.jit
def triton_red_fused_mv_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :].to(tl.int64)
    x0 = xindex
    # x0 // rnumel is either 0 or 1: it selects which of the two gathered
    # indices this output row belongs to.
    tmp0 = tl.load(in_ptr0 + (x0 // rnumel), None, eviction_policy='evict_last')
    _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex  # size (1, RBLOCK)
        tmp7 = tl.load(in_ptr2 + (r1), None, eviction_policy='evict_last').to(tl.float32)
        tmp1 = tmp0 + 8
        tmp2 = tmp0 < 0
        tmp3 = tl.where(tmp2, tmp1, tmp0)  # wrap negative indices; size (XBLOCK, 1)
        # in_ptr1 has shape (8, S, S); tmp3 indexes its leading dimension, whose stride is S * S.
        tmp4 = tl.load(in_ptr1 + (r1 + (rnumel*(x0 % rnumel)) + (rnumel*rnumel*tmp3)), None, eviction_policy='evict_first')
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp5.to(tl.float32)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp6 * tmp8  # (XBLOCK, RBLOCK) * (1, RBLOCK)
        tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
        tmp12 = _tmp11 + tmp10
        _tmp11 = tmp12
    tmp11 = tl.sum(_tmp11, 1)[:, None]
    tmp13 = tmp11.to(tl.float32)
    tl.store(out_ptr1 + (x0), tmp13, None)


def triton_gemv_0(arg0_1, arg1_1, arg2_1):
    S, = arg2_1.shape
    assert_size_stride(arg0_1, (8, S, S), (S*S, S, 1))
    assert_size_stride(arg1_1, (2, ), (1, ))
    assert_size_stride(arg2_1, (S, ), (1, ))
    xnumel = 2*S
    rnumel = S
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # two gathered rows are produced, so the flat output buffer is 2*S long
        buf1 = empty_strided_cuda((2*S, ), (1, ), torch.bfloat16)

        grid = lambda META: (
            triton.cdiv(2*S, META["XBLOCK"]),
        )
        triton_red_fused_mv_0[grid](arg1_1, arg0_1, arg2_1, buf1, xnumel, rnumel)
    return (reinterpret_tensor(buf1, (2, S), (S, 1), 0), )
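As a quick smoke test outside the harness, the wrapper can be compared against the eager reference from operator.py. A minimal sketch (assuming a CUDA device, the repository root on PYTHONPATH, and arbitrarily chosen gather indices):

import torch
from torchbenchmark.operators.gather_gemv.triton_gather_gemv import triton_gemv_0

S = 2 ** 11
w = torch.randint(-8, 8, (8, S, S), device="cuda:0", dtype=torch.int8)
idx = torch.tensor([1, 5], device="cuda:0", dtype=torch.int64)
x = torch.randn(S, device="cuda:0", dtype=torch.bfloat16)

out, = triton_gemv_0(w, idx, x)   # (2, S) bfloat16
ref = w[idx].to(x.dtype) @ x      # eager reference
print((out.float() - ref.float()).abs().max())  # should be small (bf16 rounding noise)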