From 9ff1725cd827bc6a786aa2f48a0bfeb2b448dafa Mon Sep 17 00:00:00 2001
From: Bert Maher
Date: Sun, 14 Apr 2024 07:32:21 -0700
Subject: [PATCH] Add launch latency benchmarks for triton.CompiledKernel and inductor

Summary:
There are a number of more detailed views into launch latency that we can get, in addition to the path we get from `triton.JitFunction`:
- `triton.compiler.CompiledKernel`, which is the lowest-level interface used by triton
- Inductor's `CachingAutotuner.run`, which is the lowest-level launch interface used by inductor
- launching a mostly-nop inductor kernel (can't be truly nop because inductor won't generate a kernel with nothing in it)

Reviewed By: xuzhao9, chenyang78

Differential Revision: D56073036

fbshipit-source-id: c72b80eb016a5c2ea27717664e8a1ff0f35c705a
---
 .../operators/launch_latency/__init__.py | 87 +++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/torchbenchmark/operators/launch_latency/__init__.py b/torchbenchmark/operators/launch_latency/__init__.py
index 76f140efa6..81130b335f 100644
--- a/torchbenchmark/operators/launch_latency/__init__.py
+++ b/torchbenchmark/operators/launch_latency/__init__.py
@@ -1,6 +1,9 @@
 import torch
 import triton
 import triton.language as tl
+from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+from torch._inductor import triton_heuristics
+from torch._inductor.codecache import AsyncCompile
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
@@ -40,6 +43,55 @@ def nop_with_args_kernel(
     pass
 
 
+@torch.compile
+def trivial_add_kernel(*args):
+    return sum([torch.tensor(1.0, device="cuda"), *args])
+
+
+async_compile = AsyncCompile()
+
+inductor_nop = async_compile.triton(
+    "inductor_nop",
+    """
+import triton
+import triton.language as tl
+from triton.compiler.compiler import AttrsDescriptor
+
+from torch._inductor import triton_heuristics
+
+@triton_heuristics.pointwise(
+    size_hints=[1],
+    triton_meta={'signature': {0: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(), equal_to_1=())]},
+)
+@triton.jit
+def inductor_nop(x):
+    pass
+""",
+    device_str="cuda",
+)
+
+
+inductor_nop_args = async_compile.triton(
+    "inductor_nop_args",
+    """
+import triton
+import triton.language as tl
+from triton.compiler.compiler import AttrsDescriptor
+
+from torch._inductor import triton_heuristics
+
+@triton_heuristics.pointwise(
+    size_hints=[1],
+    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32', 7: 'i32', 8: 'i32', 9: 'i32', 10: 'i32', 11: 'i32', 12: 'i32', 13: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=(5, 6, 7, 8, 9, 10, 11, 12, 13))]},
+)
+@triton.jit
+def inductor_nop_args(t1, t2, t3, t4, t5, i1, i2, i3, i4, i5, i6, i7, i8, i9):
+    pass
+""",
+    device_str="cuda",
+)
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["walltime"]
 
@@ -59,6 +111,41 @@ def nop_triton_kernel(self, *args):
             return lambda: nop_kernel[1,]()
         return lambda: nop_with_args_kernel[1,](*args)
 
+    @register_benchmark()
+    def nop_triton_compiled_kernel_run(self, *args):
+        if len(args) == 0:
+            bin = nop_kernel[1,]()
+
+        else:
+            bin = nop_with_args_kernel[1,](*args)
+            args = args[:-5]  # remove tl.constexpr args
+        function = bin.function
+        metadata = (
+            bin.packed_metadata if hasattr(bin, "packed_metadata") else bin.metadata
+        )
+        if hasattr(triton.compiler.CompiledKernel, "launch_metadata"):
+            return lambda: bin.run(
+                1, 1, 1, 0, function, metadata, None, None, None, *args
+            )
+        else:
+            return lambda: bin.run(
+                1, 1, 1, 1, 1, 1, 1, 1, 0, 0, function, None, None, metadata, *args
+            )
+
+    @register_benchmark()
+    def nop_inductor_kernel_run(self, *args):
+        stream = get_raw_stream(0)
+        grid = triton_heuristics.grid(1)
+
+        if len(args) == 0:
+            return lambda: inductor_nop.run(1, grid=grid, stream=stream)
+        args = args[:-5]
+        return lambda: inductor_nop_args.run(*args, grid=grid, stream=stream)
+
+    @register_benchmark()
+    def nop_inductor_kernel(self, *args):
+        return lambda: trivial_add_kernel(*args)
+
     @register_benchmark(baseline=True)
     def nop_python_function(self, *args):
         def nop():
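
As context for the numbers these benchmarks produce, below is a minimal standalone sketch of the kind of measurement the operator automates: timing repeated host-side launches of a no-op Triton kernel and reporting microseconds per launch. The helper bench_launch_us and the iteration count are illustrative assumptions, not code from this diff; the operator itself reports walltime through the BenchmarkOperator harness.

import time

import torch
import triton


@triton.jit
def nop_kernel():
    # Empty body: per-call cost is almost entirely launch overhead, not GPU work.
    pass


def bench_launch_us(fn, iters=10_000):
    # Warm up once so Triton JIT compilation is excluded from the timing.
    fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()  # host-side enqueue only; no synchronization inside the loop
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e6


if __name__ == "__main__":
    print(f"{bench_launch_us(lambda: nop_kernel[(1,)]()):.2f} us per launch")

Pointing the same loop at triton.compiler.CompiledKernel.run or inductor's CachingAutotuner.run, rather than the JITFunction launcher, isolates the lower-level paths that the new nop_triton_compiled_kernel_run and nop_inductor_kernel_run benchmarks cover.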