TritonBench: add benchmark for welford
Summary: Based on Inductor-generated code, but modified to use Triton's autotuning. See pytorch/pytorch#120184.

Reviewed By: bertmaher

Differential Revision: D56306884

fbshipit-source-id: 5adb963cd8d5339275abdfd24660022174cc0e0f
1 parent: d6b44d2 · commit: 02d3328 · 3 changed files with 282 additions and 0 deletions
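For context: Welford's algorithm produces the mean and variance of a stream in a single pass by maintaining a running mean and a running sum of squared deviations (M2), whereas the naive formulation needs one pass for the mean and a second for the variance. A minimal plain-Python sketch of the update rule (illustration only, not part of this commit):

def welford_update(count, mean, m2, new_value):
    # Fold one sample into the running (count, mean, M2) triple.
    count += 1
    delta = new_value - mean
    mean += delta / count
    m2 += delta * (new_value - mean)  # uses the updated mean
    return count, mean, m2

count, mean, m2 = 0, 0.0, 0.0
for x in [0.5, 2.0, -1.25, 3.0]:
    count, mean, m2 = welford_update(count, mean, m2, x)
variance = m2 / count  # population variance, as layer norm uses
print(mean, variance)  # ~1.0625, ~2.5742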
@@ -0,0 +1 @@
from .operator import Operator
@@ -0,0 +1,70 @@
from typing import Callable, Generator, List, Optional

import torch
from torch._dynamo.testing import rand_strided

from torchbenchmark.util.triton_op import (
    BenchmarkOperator,
    register_benchmark,
)

from .triton_welford import fused_native_layer_norm as triton_welford
from .triton_welford import fused_native_layer_norm_no_welford as triton_no_welford

# (S, D) input shapes: S is fixed at 262144 (2048 * 128) while the hidden
# dimension D sweeps from 1024 to 8192.
BUILTIN_SHAPES = [
    (262144, 1024),
    (262144, 1536),
    (262144, 2048),
    (262144, 2560),
    (262144, 3072),
    (262144, 4096),
    (262144, 5120),
    (262144, 6144),
    (262144, 7168),
    (262144, 8192),
]


class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "speedup", "accuracy"]

    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
        # Avoid the mutable-default-argument pitfall; fall back to an empty list.
        super().__init__(mode=mode, device=device, extra_args=extra_args or [])
        self.shapes = BUILTIN_SHAPES

    @register_benchmark()
    def test_welford(self, p1, p2, p3) -> Callable:
        return lambda: triton_welford(p1, p2, p3)

    @register_benchmark(baseline=True)
    def test_no_welford(self, p1, p2, p3) -> Callable:
        return lambda: triton_no_welford(p1, p2, p3)

    def get_x_val(self, example_inputs) -> float:
        # Report the hidden dimension D as the x-axis value.
        _, _, p3 = example_inputs
        _, d = p3.size()
        return d

    def get_input_iter(self) -> Generator:
        # p1/p2 are the layer norm weight and bias (shape [D]); p3 is the
        # [S, D] input to normalize.
        for s, d in self.shapes:
            p1 = rand_strided((d,), (1,), device="cuda:0", dtype=torch.bfloat16)
            p2 = rand_strided((d,), (1,), device="cuda:0", dtype=torch.bfloat16)
            p3 = rand_strided((s, d), (d, 1), device="cuda:0", dtype=torch.bfloat16)
            yield p1, p2, p3

    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
        # Both entry points return a tuple whose first element is the
        # normalized output; compare that tensor, not the whole tuple.
        output = fn()[0]
        baseline_output = baseline_fn()[0]
        return torch.allclose(output, baseline_output)

@@ -0,0 +1,211 @@
"""
Based on https://github.com/pytorch/pytorch/issues/120184.
Generated by Inductor for forward layer norm, with and without Welford.
"""

import torch
import triton
import triton.language as tl
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import libdevice

empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor


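# Welford variant: two reads of each row. The first loop folds RBLOCK-sized
# chunks into running (mean, M2, count) accumulators with
# triton_helpers.welford_reduce; triton_helpers.welford then combines the
# per-lane accumulators into the row statistics, and a second loop applies
# the normalization.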
@triton.autotune(
    configs=[
        triton.Config(
            {"XBLOCK": 1, "RBLOCK": 1024},
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {"XBLOCK": 1, "RBLOCK": 2048},
            num_stages=1,
            num_warps=8,
        ),
    ],
    key=["xnumel", "rnumel"],
)
@triton.jit
def triton_red_fused_native_layer_norm_0(
    in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1,
    xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr,
):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # Running Welford accumulators: per-lane mean, M2, and element count.
    tmp3_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp3_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp3_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    # First pass: fold each RBLOCK chunk of the row into the accumulators.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy="evict_last").to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
            tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
        )
        tmp3_mean = tl.where(rmask, tmp3_mean_next, tmp3_mean)
        tmp3_m2 = tl.where(rmask, tmp3_m2_next, tmp3_m2)
        tmp3_weight = tl.where(rmask, tmp3_weight_next, tmp3_weight)
    # Combine the per-lane accumulators into the row statistics.
    tmp3_tmp, tmp4_tmp, tmp5_tmp = triton_helpers.welford(
        tmp3_mean, tmp3_m2, tmp3_weight, 1
    )
    tmp3 = tmp3_tmp[:, None]  # mean
    tmp4 = tmp4_tmp[:, None]  # M2 (sum of squared deviations)
    tmp5 = tmp5_tmp[:, None]  # count
    tl.store(out_ptr0 + (x0), tmp3, None)
    tmp6 = rnumel
    tmp7 = tmp4 / tmp6  # variance = M2 / N
    tmp8 = 1e-05
    tmp9 = tmp7 + tmp8
    tmp10 = libdevice.rsqrt(tmp9)  # rstd = 1 / sqrt(var + eps)
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp10, None)
    # Second pass: normalize, then apply weight (in_ptr1) and bias (in_ptr2).
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp11 = tl.load(in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy="evict_first").to(tl.float32)
        tmp15 = tl.load(in_ptr1 + (r1), rmask, eviction_policy="evict_last").to(tl.float32)
        tmp18 = tl.load(in_ptr2 + (r1), rmask, eviction_policy="evict_last").to(tl.float32)
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp12 - tmp3
        tmp14 = tmp13 * tmp10
        tmp16 = tmp15.to(tl.float32)
        tmp17 = tmp14 * tmp16
        tmp19 = tmp18.to(tl.float32)
        tmp20 = tmp17 + tmp19
        tmp21 = tmp20.to(tl.float32)
        tl.store(out_ptr1 + (r1 + (rnumel * x0)), tmp21, rmask)


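# Baseline without Welford: three reads of each row. The first loop
# accumulates the sum (for the mean), the second the squared deviations
# (for the variance), and the third applies the normalization.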
@triton.autotune(
    configs=[
        triton.Config(
            {"XBLOCK": 1, "RBLOCK": 1024},
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {"XBLOCK": 1, "RBLOCK": 2048},
            num_stages=1,
            num_warps=8,
        ),
    ],
    key=["xnumel", "rnumel"],
)
@triton.jit
def triton_red_fused_native_layer_norm_no_welford(
    in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, out_ptr0,
    xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr,
):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # First pass: accumulate the row sum for the mean.
    _tmp3 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy="evict_last").to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4 = _tmp3 + tmp2
        _tmp3 = tl.where(rmask, tmp4, _tmp3)  # keep masked lanes out of the sum
    tmp3 = tl.sum(_tmp3, 1)[:, None]
    tmp5 = rnumel  # row length N (the generated code hard-coded 4096.0 here)
    tmp6 = tmp3 / tmp5  # mean
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp6, None)
    # Second pass: accumulate squared deviations for the variance.
    _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp7 = tl.load(in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy="evict_last").to(tl.float32)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp8 - tmp6
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(rmask, tmp13, _tmp12)  # keep masked lanes out of the sum
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    tmp14 = rnumel
    tmp15 = tmp12 / tmp14  # variance
    tmp16 = 1e-05
    tmp17 = tmp15 + tmp16
    tmp18 = libdevice.rsqrt(tmp17)  # rstd = 1 / sqrt(var + eps)
    tl.debug_barrier()
    tl.store(in_out_ptr1 + (x0), tmp18, None)
    # Third pass: normalize, then apply weight (in_ptr1) and bias (in_ptr2).
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp19 = tl.load(in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy="evict_first").to(tl.float32)
        tmp23 = tl.load(in_ptr1 + (r1), rmask, eviction_policy="evict_last").to(tl.float32)
        tmp26 = tl.load(in_ptr2 + (r1), rmask, eviction_policy="evict_last").to(tl.float32)
        tmp20 = tmp19.to(tl.float32)
        tmp21 = tmp20 - tmp6
        tmp22 = tmp21 * tmp18
        tmp24 = tmp23.to(tl.float32)
        tmp25 = tmp22 * tmp24
        tmp27 = tmp26.to(tl.float32)
        tmp28 = tmp25 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tl.store(out_ptr0 + (r1 + (rnumel * x0)), tmp29, rmask)


# S = 2048 * 128 = 262144; D can sweep range(256, 4096 + 1, 256) or
# range(8192, 8192 * 11 + 1, 8192).
def fused_native_layer_norm_no_welford(primals_1, primals_2, primals_3):
    # primals_1: weight [D]; primals_2: bias [D]; primals_3: input [S, D].
    S, D = primals_3.shape
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((S, 1), (1, S), torch.float32)
        buf1 = buf0; del buf0  # reuse
        buf2 = empty_strided_cuda((S, 1), (1, S), torch.float32)
        buf3 = reinterpret_tensor(buf2, (S, 1), (1, 1), 0); del buf2  # reuse
        buf4 = empty_strided_cuda((S, D), (D, 1), torch.bfloat16)
        # Source Nodes: [fn], Original ATen: [aten.native_layer_norm]
        stream0 = get_raw_stream(0)
        grid = lambda META: (triton.cdiv(S, META["XBLOCK"]),)
        triton_red_fused_native_layer_norm_no_welford[grid](
            buf1, buf3, primals_3, primals_1, primals_2, buf4, S, D
        )
    return (buf4, primals_3, buf1, buf3)

def fused_native_layer_norm(primals_1, primals_2, primals_3):
    # primals_1: weight [D]; primals_2: bias [D]; primals_3: input [S, D].
    S, D = primals_3.shape
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((S, 1), (1, 1), torch.float32)
        buf1 = empty_strided_cuda((S, 1), (1, S), torch.float32)
        buf3 = reinterpret_tensor(buf1, (S, 1), (1, 1), 0); del buf1  # reuse
        buf4 = empty_strided_cuda((S, D), (D, 1), torch.bfloat16)
        # Source Nodes: [fn], Original ATen: [aten.native_layer_norm]
        stream0 = get_raw_stream(0)
        grid = lambda META: (triton.cdiv(S, META["XBLOCK"]),)
        triton_red_fused_native_layer_norm_0[grid](
            buf3, primals_3, primals_1, primals_2, buf0, buf4, S, D
        )
    return (buf4, primals_3, buf0, buf3)
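
As a quick sanity check (not part of the commit), both entry points can be compared against eager-mode layer norm. A minimal sketch, assuming this file is importable as triton_welford and a CUDA device is available; tolerances are loose because the kernels read and write bfloat16:

import torch
from triton_welford import fused_native_layer_norm, fused_native_layer_norm_no_welford

S, D = 4096, 1024
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(D, device="cuda", dtype=torch.bfloat16)
x = torch.randn(S, D, device="cuda", dtype=torch.bfloat16)

# Eager reference computed in float32.
ref = torch.nn.functional.layer_norm(
    x.float(), (D,), weight.float(), bias.float(), eps=1e-5
)

out_welford = fused_native_layer_norm(weight, bias, x)[0]
out_no_welford = fused_native_layer_norm_no_welford(weight, bias, x)[0]

print(torch.allclose(out_welford.float(), ref, atol=1e-1, rtol=1e-2))
print(torch.allclose(out_no_welford.float(), ref, atol=1e-1, rtol=1e-2))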