template attention from PT2
Summary:
Based on Inductor-generated code, but modified to use Triton's autotuning.
PyTorch GitHub reference: pytorch/pytorch#124369
The base variant (test_no_exp2) is the kernel prior to OSS pytorch/pytorch#124356, a PR that improves performance for template attention; the second variant (test_with_exp2) is the kernel after that PR.

Reviewed By: bertmaher

Differential Revision: D56372010

fbshipit-source-id: 4439113a92fd41b81269af1227deaf5ec52c65dc
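
For context, the two benchmarked variants differ (as their names suggest) in whether the softmax inside the attention kernel uses exp or exp2. Since exp(x) = exp2(x * log2(e)), both should produce the same output up to floating-point tolerance, which is what the operator's accuracy metric checks with torch.allclose. A minimal eager-mode sketch of that equivalence (illustrative only; the actual kernels live in triton_attention.py and are written in Triton):

import math
import torch

# Illustrative eager-mode attention with exp- vs. exp2-based softmax.
def attention_exp(q, k, v):
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    weights = torch.exp(scores - scores.amax(dim=-1, keepdim=True))
    return (weights / weights.sum(dim=-1, keepdim=True)) @ v

def attention_exp2(q, k, v):
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # exp(x) == 2 ** (x * log2(e)): fold log2(e) into the scores and use exp2.
    scaled = (scores - scores.amax(dim=-1, keepdim=True)) * math.log2(math.e)
    weights = torch.exp2(scaled)
    return (weights / weights.sum(dim=-1, keepdim=True)) @ v

q = k = v = torch.randn(2, 4, 128, 64)
assert torch.allclose(attention_exp(q, k, v), attention_exp2(q, k, v), atol=1e-6)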
manman-ren authored and facebook-github-bot committed May 8, 2024
1 parent 02d3328 commit 1708afa
Showing 3 changed files with 439 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/template_attention/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
61 changes: 61 additions & 0 deletions torchbenchmark/operators/template_attention/operator.py
@@ -0,0 +1,61 @@

import csv
import os
import statistics
from typing import Any, Callable, Generator, List, Optional

import numpy
import torch
import triton


from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)

from .triton_attention import triton_attention_no_exp2 as triton_test_no_exp2
from .triton_attention import triton_attention_with_exp2 as triton_test_with_exp2
from torch._dynamo.testing import rand_strided


BUILDIN_SHAPES = [
(16, 16, 4096, 64),
]


class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]

def __init__(self, mode: str, device: str, extra_args: List[str] = []):
super().__init__(mode=mode, device=device, extra_args=extra_args)
self.shapes = BUILDIN_SHAPES

@register_benchmark(baseline=True)
def test_no_exp2(self, p1, p2, p3) -> Callable:
return lambda: triton_test_no_exp2(p1, p2, p3)

@register_benchmark()
def test_with_exp2(self, p1, p2, p3) -> Callable:
return lambda: triton_test_with_exp2(p1, p2, p3)

def get_x_val(self, example_inputs) -> float:
p1, p2, p3 = example_inputs
batch_size, num_heads, num_queries, m = p3.size()
return num_queries

def get_input_iter(self) -> Generator:
for shape in self.shapes:
batch_size, num_heads, num_queries, m = shape
arg0_1 = rand_strided((16, 16, 4096, 64), (4194304, 262144, 64, 1), device='cuda:0', dtype=torch.float16)
arg1_1 = rand_strided((16, 16, 4096, 64), (4194304, 262144, 64, 1), device='cuda:0', dtype=torch.float16)
arg2_1 = rand_strided((16, 16, 4096, 64), (4194304, 262144, 64, 1), device='cuda:0', dtype=torch.float16)
yield arg0_1, arg1_1, arg2_1

def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
output = fn()
baseline_output = baseline_fn()
return torch.allclose(output, baseline_output)
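
A note on the generated inputs above: the strides (4194304, 262144, 64, 1) passed to rand_strided are exactly the contiguous strides for shape (16, 16, 4096, 64), so each argument is an ordinary contiguous fp16 CUDA tensor of that shape. A small sketch (not part of the commit) demonstrating the equivalence:

import torch
from torch._dynamo.testing import rand_strided

shape = (16, 16, 4096, 64)
# Contiguous strides for this shape: (16*4096*64, 4096*64, 64, 1).
strided = rand_strided(shape, (4194304, 262144, 64, 1),
                       device="cuda:0", dtype=torch.float16)
plain = torch.randn(shape, device="cuda:0", dtype=torch.float16)

assert strided.stride() == plain.stride() == (4194304, 262144, 64, 1)
assert strided.is_contiguous()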

377 changes: 377 additions & 0 deletions torchbenchmark/operators/template_attention/triton_attention.py
(file contents not shown)
