TritonBench: add benchmark for welford
Summary:
Based on Inductor-generated code, but modified to use Triton's autotuning.
PyTorch GitHub: pytorch/pytorch#120184

Reviewed By: bertmaher

Differential Revision: D56306884

fbshipit-source-id: 5adb963cd8d5339275abdfd24660022174cc0e0f
manman-ren authored and facebook-github-bot committed May 8, 2024
1 parent d6b44d2 commit 02d3328
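To exercise the new operator locally, TritonBench operators in this repository are normally driven through the benchmark runner; the exact flags below are an assumption based on that layout rather than something stated in this commit:

    python run_benchmark.py triton --op welford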
Showing 3 changed files with 282 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/welford/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
70 changes: 70 additions & 0 deletions torchbenchmark/operators/welford/operator.py
@@ -0,0 +1,70 @@

import csv
import os
import statistics
from typing import Any, Callable, Generator, List, Optional

import numpy
import torch
import triton


from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)

from .triton_welford import fused_native_layer_norm as triton_welford
from .triton_welford import fused_native_layer_norm_no_welford as triton_no_welford
from torch._dynamo.testing import rand_strided


BUILDIN_SHAPES = [
(262144, 1024),
(262144, 1536),
(262144, 2048),
(262144, 2560),
(262144, 3072),
(262144, 4096),
(262144, 5120),
(262144, 6144),
(262144, 7168),
(262144, 8192),
]


class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]

def __init__(self, mode: str, device: str, extra_args: List[str] = []):
super().__init__(mode=mode, device=device, extra_args=extra_args)
self.shapes = BUILDIN_SHAPES

@register_benchmark()
def test_welford(self, p1, p2, p3) -> Callable:
return lambda: triton_welford(p1, p2, p3)

@register_benchmark(baseline=True)
def test_no_welford(self, p1, p2, p3) -> Callable:
return lambda: triton_no_welford(p1, p2, p3)

def get_x_val(self, example_inputs) -> float:
p1, p2, p3 = example_inputs
s, d = p3.size()
return d

def get_input_iter(self) -> Generator:
for shape in self.shapes:
s, d = shape
p1 = rand_strided((d, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
p2 = rand_strided((d, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
p3 = rand_strided((s, d), (d, 1), device='cuda:0', dtype=torch.bfloat16)
yield p1, p2, p3

def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
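        # torch.allclose defaults (rtol=1e-05, atol=1e-08) are fairly strict for
        # bfloat16 outputs; loosen the tolerances below if this comparison proves flaky.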
        # Both callables return an (output, input, mean, rstd) tuple; compare the
        # normalized outputs, which are the tensors both kernels produce.
        output = fn()[0]
        baseline_output = baseline_fn()[0]
        return torch.allclose(output, baseline_output)

211 changes: 211 additions & 0 deletions torchbenchmark/operators/welford/triton_welford.py
@@ -0,0 +1,211 @@
"""
Based on https://github.com/pytorch/pytorch/issues/120184.
Generated by Inductor for a forward LayerNorm, with and without the Welford reduction.
"""

import torch

import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._inductor.runtime.triton_helpers import libdevice
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
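# For reference, both kernels below implement a forward LayerNorm over the last
# dimension of an (S, D) bfloat16 input. A minimal eager-mode sketch of the same
# computation (illustrative only; layer_norm_reference is not part of this
# benchmark) would be:
#
#     def layer_norm_reference(weight, bias, x, eps=1e-05):
#         mean = x.float().mean(dim=-1, keepdim=True)
#         var = x.float().var(dim=-1, unbiased=False, keepdim=True)
#         rstd = torch.rsqrt(var + eps)
#         y = (x.float() - mean) * rstd * weight.float() + bias.float()
#         return y.to(x.dtype), mean, rstd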


@triton.autotune(
configs=[
triton.Config(
{
"XBLOCK": 1,
"RBLOCK": 1024,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"XBLOCK": 1,
"RBLOCK": 2048,
},
num_stages=1,
num_warps=8,
),
],
key=["xnumel", "rnumel"],
)
@triton.jit
def triton_red_fused_native_layer_norm_0(
in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
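    # Single-pass (Welford) LayerNorm. The first loop streams each row once and
    # accumulates running mean / M2 / count with welford_reduce; the second loop
    # normalizes the row and applies weight (in_ptr1) and bias (in_ptr2).
    # Per-row mean is written to out_ptr0, rstd to in_out_ptr0, output to out_ptr1.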
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp3_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp3_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp3_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (rnumel*x0)), rmask, eviction_policy='evict_last').to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
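        # welford_reduce folds this block of values into the running per-lane
        # statistics: mean, M2 (sum of squared deviations) and element count.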
tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
)
tmp3_mean = tl.where(rmask, tmp3_mean_next, tmp3_mean)
tmp3_m2 = tl.where(rmask, tmp3_m2_next, tmp3_m2)
tmp3_weight = tl.where(rmask, tmp3_weight_next, tmp3_weight)
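    # welford() merges the per-lane partial (mean, M2, count) triples across the
    # RBLOCK dimension into a single mean (tmp3), M2 (tmp4) and count (tmp5) per row.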
tmp3_tmp, tmp4_tmp, tmp5_tmp = triton_helpers.welford(
tmp3_mean, tmp3_m2, tmp3_weight, 1
)
tmp3 = tmp3_tmp[:, None]
tmp4 = tmp4_tmp[:, None]
tmp5 = tmp5_tmp[:, None]
tl.store(out_ptr0 + (x0), tmp3, None)
tmp6 = rnumel
tmp7 = tmp4 / tmp6
tmp8 = 1e-05
tmp9 = tmp7 + tmp8
tmp10 = libdevice.rsqrt(tmp9)
tl.debug_barrier()
tl.store(in_out_ptr0 + (x0), tmp10, None)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp11 = tl.load(in_ptr0 + (r1 + (rnumel*x0)), rmask, eviction_policy='evict_first').to(tl.float32)
tmp15 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last').to(tl.float32)
tmp18 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last').to(tl.float32)
tmp12 = tmp11.to(tl.float32)
tmp13 = tmp12 - tmp3
tmp14 = tmp13 * tmp10
tmp16 = tmp15.to(tl.float32)
tmp17 = tmp14 * tmp16
tmp19 = tmp18.to(tl.float32)
tmp20 = tmp17 + tmp19
tmp21 = tmp20.to(tl.float32)
tl.store(out_ptr1 + (r1 + (rnumel*x0)), tmp21, rmask)


@triton.autotune(
configs=[
triton.Config(
{
"XBLOCK": 1,
"RBLOCK": 1024,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"XBLOCK": 1,
"RBLOCK": 2048,
},
num_stages=1,
num_warps=8,
),
],
key=["xnumel", "rnumel"],
)
@triton.jit
def triton_red_fused_native_layer_norm_no_welford(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
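    # Baseline (non-Welford) LayerNorm: three loops over each row. Loop 1 sums the
    # row to get the mean (stored via in_out_ptr0), loop 2 sums squared deviations
    # to get rstd (stored via in_out_ptr1), loop 3 normalizes and applies weight
    # (in_ptr1) and bias (in_ptr2), writing the result to out_ptr0.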
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
_tmp3 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (rnumel*x0)), rmask, eviction_policy='evict_last').to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
tmp4 = _tmp3 + tmp2
_tmp3 = tmp4
tmp3 = tl.sum(_tmp3, 1)[:, None]
    tmp5 = rnumel  # row length used as the divisor (e.g. 4096.0 for D = 4096)
tmp6 = tmp3 / tmp5
tl.debug_barrier()
tl.store(in_out_ptr0 + (x0), tmp6, None)
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp7 = tl.load(in_ptr0 + (r1 + (rnumel*x0)), rmask, eviction_policy='evict_last').to(tl.float32)
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp8 - tmp6
tmp10 = tmp9 * tmp9
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
tmp13 = _tmp12 + tmp11
_tmp12 = tmp13
tmp12 = tl.sum(_tmp12, 1)[:, None]
    tmp14 = rnumel  # row length used as the divisor (e.g. 4096.0 for D = 4096)
tmp15 = tmp12 / tmp14
tmp16 = 1e-05
tmp17 = tmp15 + tmp16
tmp18 = libdevice.rsqrt(tmp17)
tl.debug_barrier()
tl.store(in_out_ptr1 + (x0), tmp18, None)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp19 = tl.load(in_ptr0 + (r1 + (rnumel*x0)), rmask, eviction_policy='evict_first').to(tl.float32)
tmp23 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last').to(tl.float32)
tmp26 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last').to(tl.float32)
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp20 - tmp6
tmp22 = tmp21 * tmp18
tmp24 = tmp23.to(tl.float32)
tmp25 = tmp22 * tmp24
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp25 + tmp27
tmp29 = tmp28.to(tl.float32)
tl.store(out_ptr0 + (r1 + (rnumel*x0)), tmp29, rmask)


# S = 2048*128 = 262144; D can be in range(256, 4096+1, 256) or range(8192, 8192*11+1, 8192)
def fused_native_layer_norm_no_welford(primals_1, primals_2, primals_3):
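    # primals_1 / primals_2 are the (D,) weight and bias, primals_3 the (S, D)
    # input; buf1 receives the per-row mean, buf3 the per-row rstd, buf4 the output.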
S, D = primals_3.shape
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf0 = empty_strided_cuda((S, 1), (1, S), torch.float32)
buf1 = buf0; del buf0 # reuse
buf2 = empty_strided_cuda((S, 1), (1, S), torch.float32)
buf3 = reinterpret_tensor(buf2, (S, 1), (1, 1), 0); del buf2 # reuse
buf4 = empty_strided_cuda((S, D), (D, 1), torch.bfloat16)
# Source Nodes: [fn], Original ATen: [aten.native_layer_norm]
stream0 = get_raw_stream(0)
grid = lambda META: (
triton.cdiv(S, META["XBLOCK"]),
)
triton_red_fused_native_layer_norm_no_welford[grid](buf1, buf3, primals_3, primals_1, primals_2, buf4, S, D)
#del primals_1
#del primals_2
return (buf4, primals_3, buf1, buf3, )

def fused_native_layer_norm(primals_1, primals_2, primals_3):
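    # Same calling convention as above: primals_1 / primals_2 are the (D,) weight
    # and bias, primals_3 the (S, D) input; buf0 receives the per-row mean, buf3
    # the rstd, buf4 the output.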
S, D = primals_3.shape
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf0 = empty_strided_cuda((S, 1), (1, 1), torch.float32)
buf1 = empty_strided_cuda((S, 1), (1, S), torch.float32)
buf3 = reinterpret_tensor(buf1, (S, 1), (1, 1), 0); del buf1 # reuse
buf4 = empty_strided_cuda((S, D), (D, 1), torch.bfloat16)
# Source Nodes: [fn], Original ATen: [aten.native_layer_norm]
stream0 = get_raw_stream(0)
grid = lambda META: (
triton.cdiv(S, META["XBLOCK"]),
)
triton_red_fused_native_layer_norm_0[grid](buf3, primals_3, primals_1, primals_2, buf0, buf4, S, D)
#del primals_1
#del primals_2
return (buf4, primals_3, buf0, buf3, )
