Summary: Based on Inductor-generated code, but modified to use Triton's autotuning. PyTorch GitHub PR: pytorch/pytorch#124369. The base variant is from before OSS pytorch/pytorch#124356, which improves performance for template attention; the second variant is from after that PR.

Reviewed By: bertmaher

Differential Revision: D56372010

fbshipit-source-id: 4439113a92fd41b81269af1227deaf5ec52c65dc
Commit: 1708afa. 1 parent: 02d3328. Showing 3 changed files with 439 additions and 0 deletions.
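For context on "Triton's autotuning": triton.autotune benchmarks a list of launch configurations on first use and caches the fastest one per key. Below is a minimal sketch of the mechanism; the kernel, configs, and the names double_kernel and double are illustrative assumptions, not code from this commit.

import torch
import triton
import triton.language as tl

# Illustrative only: triton.autotune times each config the first time the
# kernel is launched for a given key value and reuses the winner thereafter.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 128}, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def double_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)

def double(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    # BLOCK_SIZE is injected by the autotuner, so it is not passed explicitly.
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
    double_kernel[grid](x, out, n)
    return out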
@@ -0,0 +1 @@
from .operator import Operator
@@ -0,0 +1,61 @@
import csv
import os
import statistics
from typing import Any, Callable, Generator, List, Optional

import numpy
import torch
import triton

from torchbenchmark.util.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    register_benchmark,
    register_metric,
)

from .triton_attention import triton_attention_no_exp2 as triton_test_no_exp2
from .triton_attention import triton_attention_with_exp2 as triton_test_with_exp2
from torch._dynamo.testing import rand_strided


# (batch_size, num_heads, num_queries, head_dim) shapes exercised by the benchmark.
BUILDIN_SHAPES = [
    (16, 16, 4096, 64),
]


class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "speedup", "accuracy"]

    def __init__(self, mode: str, device: str, extra_args: List[str] = []):
        super().__init__(mode=mode, device=device, extra_args=extra_args)
        self.shapes = BUILDIN_SHAPES

    # Baseline: template attention without the exp2 rewrite.
    @register_benchmark(baseline=True)
    def test_no_exp2(self, p1, p2, p3) -> Callable:
        return lambda: triton_test_no_exp2(p1, p2, p3)

    # Variant: template attention using exp2.
    @register_benchmark()
    def test_with_exp2(self, p1, p2, p3) -> Callable:
        return lambda: triton_test_with_exp2(p1, p2, p3)

    def get_x_val(self, example_inputs) -> float:
        p1, p2, p3 = example_inputs
        batch_size, num_heads, num_queries, m = p3.size()
        return num_queries

    def get_input_iter(self) -> Generator:
        # Yield one (q, k, v) triple per shape; rand_strided builds fp16 CUDA
        # tensors with the given shape and strides.
        for shape in self.shapes:
            batch_size, num_heads, num_queries, m = shape
            arg0_1 = rand_strided((16, 16, 4096, 64), (4194304, 262144, 64, 1), device='cuda:0', dtype=torch.float16)
            arg1_1 = rand_strided((16, 16, 4096, 64), (4194304, 262144, 64, 1), device='cuda:0', dtype=torch.float16)
            arg2_1 = rand_strided((16, 16, 4096, 64), (4194304, 262144, 64, 1), device='cuda:0', dtype=torch.float16)
            yield arg0_1, arg1_1, arg2_1

    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
        # Run both variants and compare outputs elementwise.
        output = fn()
        baseline_output = baseline_fn()
        return torch.allclose(output, baseline_output)
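A note on the accuracy check above: torch.allclose defaults to rtol=1e-05 and atol=1e-08, which is strict for float16 outputs. A minimal sketch of a looser comparison one might substitute, with illustrative tolerances that are an assumption rather than part of the commit:

import torch

# Hypothetical helper, not part of the commit: compare fp16 outputs with
# tolerances looser than torch.allclose's defaults (rtol=1e-05, atol=1e-08).
def allclose_fp16(out: torch.Tensor, ref: torch.Tensor) -> bool:
    # 1e-3 is an illustrative tolerance often workable for half precision.
    return torch.allclose(out, ref, rtol=1e-3, atol=1e-3)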
(The third changed file in this commit did not render on this page and is not shown.)