FA Kernel Update for Accuracy and Performance (#45)
Major changes:

1. Fix numerical errors caused by pre-scaling the input tensors with
log_2(e) as a preprocessing step; fudge factors are adjusted accordingly
(the precision issue is sketched right after this list).
2. Adopt the techniques from the forward kernel to specialize the inner
loops of the backward kernel as well (a toy specialization sketch also
follows this list).
3. Update the tuning database for MI200/MI300 accordingly.
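
The precision issue in item 1, sketched in plain PyTorch (illustrative only; the shapes, names, and sm_scale value below are made up, not taken from the kernel): exp(x) equals exp2(x * log2(e)), so the kernel exponentiates with exp2, but folding the log2(e) factor into the fp16 inputs as a preprocessing step rounds the scaled values to fp16, while applying the factor to the fp32 scores does not.

import torch

LOG2_E = 1.4426950408889634
sm_scale = 0.125
q = torch.randn(128, 64, dtype=torch.float16)
k = torch.randn(128, 64, dtype=torch.float16)

# Reference: scale the fp32 score matrix, then exponentiate.
ref = torch.exp((q.float() @ k.float().T) * sm_scale)

# Pre-scaling: fold sm_scale * log2(e) into the fp16 input up front.
# The multiply rounds back to fp16, so the scores drift.
q_pre = q * (sm_scale * LOG2_E)
pre = torch.exp2(q_pre.float() @ k.float().T)

# Applying the factor to the fp32 scores instead avoids that rounding.
fused = torch.exp2((q.float() @ k.float().T) * (sm_scale * LOG2_E))

print((pre - ref).abs().max().item(), (fused - ref).abs().max().item())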

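Item 2 refers to compile-time specialization of the inner loop. A toy sketch of the general technique (not the backward kernel itself; the kernel and flag names here are invented): a tl.constexpr flag picks between a masked and an unmasked loop body, so Triton compiles a boundary-check-free variant when every block is known to be full.

import torch
import triton
import triton.language as tl

@triton.jit
def col_sum_kernel(x_ptr, out_ptr, n_rows, n_cols,
                   BLOCK: tl.constexpr, FULL_BLOCKS: tl.constexpr):
    # One program per column; FULL_BLOCKS is resolved at compile time,
    # so only the chosen branch survives specialization.
    col = tl.program_id(0)
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    for start in range(0, n_rows, BLOCK):
        offs = start + tl.arange(0, BLOCK)
        if FULL_BLOCKS:
            # Hot path: no mask, no boundary arithmetic.
            vals = tl.load(x_ptr + offs * n_cols + col)
        else:
            vals = tl.load(x_ptr + offs * n_cols + col,
                           mask=offs < n_rows, other=0.0)
        acc += vals
    tl.store(out_ptr + col, tl.sum(acc, axis=0))

x = torch.randn(4096, 8, device='cuda', dtype=torch.float32)
out = torch.empty(8, device='cuda', dtype=torch.float32)
BLOCK = 256
col_sum_kernel[(x.shape[1],)](x, out, x.shape[0], x.shape[1],
                              BLOCK=BLOCK,
                              FULL_BLOCKS=(x.shape[0] % BLOCK == 0))
assert torch.allclose(out, x.sum(dim=0), atol=1e-2)
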
Minor changes:

1. `pyaotriton` now includes `$ORIGIN` in its `DT_RUNPATH`
2. `install` target now installs `pyaotriton` to
`$CMAKE_INSTALL_PREFIX/lib`
3. `mptune` now stores the batch size alongside its test results, making
the timing results more informative.
4. `performance_*.py` scripts now read the `USE_TFLOPS`, `D_HEADS`, and
`N_CTX` environment variables, allowing the testing sizes to be changed
without editing the code (see the parsing sketch after this list).
5. `test/test_backward.py` now displays the target fudge factors to guide
fudge-factor adjustment (a sketch of the underlying check follows the
Notes section).
6. `tune_flash.py` now shrinks the batch size to 2 when both sequence
lengths exceed 4096, so tuning does not run out of VRAM.
7. Fix a problem in `sancheck_lut_tensor` of `class
FlashKernel(KernelDescription)`, which did not handle single-element LUT
tensors correctly.
8. `v2python/table_tool.py` now ignores the `inputs$BATCH` column.
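
For item 4, the scripts parse the variables roughly as below (mirroring the diff to test/performance_forward.py further down); N_CTX entries are exponents of 2, so 12 means a sequence length of 4096.

import os

USE_TFLOPS = bool(int(os.getenv('USE_TFLOPS', default='1')))  # 0 => report ms
d_heads = [int(x) for x in os.getenv('D_HEADS', default='64,128').split(',')]
n_ctx = os.getenv('N_CTX', default=list(range(10, 14)))
if isinstance(n_ctx, str):
    n_ctx = [int(x) for x in n_ctx.split(',')]
X_VALS = [2 ** n for n in n_ctx]

# Example invocation (values here are arbitrary):
#   USE_TFLOPS=0 D_HEADS=64 N_CTX=10,12 python test/performance_forward.py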

Notes:

1. The fudge factors in use assume PyTorch <= 2.4. See
pytorch/pytorch#135590 for a detailed discussion of why PyTorch 2.5
cannot be used for testing. PyTorch 2.6 will include a new interface
that fixes the problem.
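
A condensed sketch of the fudge-factor check referenced in minor change 5 and in the note above (simplified from the _validate logic in tritonsrc/_common_test.py; the atol default here is illustrative):

import torch

def validate(test, ref, lp_ref, fudge_factor, atol=1e-5):
    # test: kernel output; ref: high-precision reference; lp_ref: reference
    # computed in the same low precision as the kernel.
    test_error = (test.float() - ref.float()).abs().max().item()
    ref_error = (lp_ref.float() - ref.float()).abs().max().item()
    valid = test_error <= max(atol, ref_error * fudge_factor)
    # Target fudge factor: the multiplier that would have just passed;
    # test/test_backward.py now prints it to guide adjustments.
    tft = test_error / ref_error if ref_error * fudge_factor > atol else 1.0
    return valid, test_error, tft
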
xinyazhang authored Oct 4, 2024
1 parent e43acd9 commit f6b28a9
Showing 19 changed files with 469 additions and 420 deletions.
4 changes: 4 additions & 0 deletions bindings/CMakeLists.txt
@@ -17,3 +17,7 @@ if(AOTRITON_BUILD_FOR_TUNING)
else(AOTRITON_BUILD_FOR_TUNING)
target_compile_definitions(pyaotriton PRIVATE -DAOTRITON_BUILD_FOR_TUNING=0)
endif(AOTRITON_BUILD_FOR_TUNING)

set_target_properties(pyaotriton PROPERTIES INSTALL_RPATH "$ORIGIN")
include(GNUInstallDirs)
install(TARGETS pyaotriton LIBRARY DESTINATION lib)
1 change: 1 addition & 0 deletions test/mptune/flash/db_accessor.py
@@ -25,6 +25,7 @@ def constrcut_inputs(self, request):
head_dim_rounded = max(16, head_dim_rounded)
inputs = {
'Q_dtype': str(dtype),
'BATCH' : BATCH,
'N_HEADS': N_HEADS,
'D_HEAD': D_HEAD,
'max_seqlen_q': seqlen_q,
30 changes: 24 additions & 6 deletions test/performance_backward.py
@@ -2,10 +2,12 @@
# Copyright © 2023-2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

import os
import pytest
import torch

import triton
from collections import defaultdict
from attn_torch_function import attention, AttentionExtraArgs

try:
@@ -19,23 +21,36 @@
except BaseException:
FLASH_VER = None
HAS_FLASH = FLASH_VER is not None
USE_TFLOPS = bool(int(os.getenv('USE_TFLOPS', default='1')))
print(f'{USE_TFLOPS=}')

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 64, 4096, 64
d_heads = os.getenv('D_HEADS', default='64,128')
d_heads = list(map(lambda x: int(x), d_heads.split(',')))

n_ctx = os.getenv('N_CTX', default=list(range(10, 14)))
if isinstance(n_ctx, str):
n_ctx = map(lambda x: int(x), n_ctx.split(','))
X_VALS = list(map(lambda x: 2 ** x, n_ctx))
print(f'{X_VALS=}')

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# BATCH, N_HEADS, N_CTX, D_HEAD = 512, 32, 512, 64
# vary seq length for fixed head and batch=4
configs = []
for mode in ['bwd']:
# for causal in [False, True]:
for causal in [False]:
for D_HEAD in [64, 128]:
for D_HEAD in d_heads:
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=list(X_VALS),
# x_vals=[2**i for i in range(10, 15)],
x_vals=[2**13],
# x_vals=[2**13],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
line_names=['Triton(TFLOPS)' if USE_TFLOPS else 'Triton(ms)'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
ylabel='TFLOPS' if USE_TFLOPS else 'ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
@@ -97,7 +112,10 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
total_flops *= 0.5
if mode == 'bwd':
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
return total_flops / ms * 1e-9
if USE_TFLOPS:
return total_flops / ms * 1e-9
else:
return ms


# only works on post-Ampere GPUs right now
29 changes: 23 additions & 6 deletions test/performance_forward.py
@@ -5,6 +5,7 @@
import pytest
import torch

import os
import triton
from attn_torch_function import attention, AttentionExtraArgs

@@ -19,23 +20,35 @@
except BaseException:
FLASH_VER = None
HAS_FLASH = FLASH_VER is not None
USE_TFLOPS = bool(int(os.getenv('USE_TFLOPS', default='1')))
print(f'{USE_TFLOPS=}')

BATCH, N_HEADS, N_CTX, D_HEAD = 8, 64, 4096, 64
d_heads = os.getenv('D_HEADS', default='64,128')
d_heads = list(map(lambda x: int(x), d_heads.split(',')))

n_ctx = os.getenv('N_CTX', default=list(range(10, 14)))
if isinstance(n_ctx, str):
n_ctx = map(lambda x: int(x), n_ctx.split(','))
X_VALS = list(map(lambda x: 2 ** x, n_ctx))
print(f'{X_VALS=}')

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd']:
# for causal in [False, True]:
for causal in [False]:
for D_HEAD in [64, 128]:
for D_HEAD in d_heads:
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=list(X_VALS),
# x_vals=[2**i for i in range(10, 15)],
x_vals=[2**13],
# x_vals=[2**13],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
line_names=['Triton(TFLOPS)' if USE_TFLOPS else 'Triton(ms)'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
ylabel='TFLOPS' if USE_TFLOPS else 'ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
@@ -47,6 +60,7 @@
})
)


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
print(f"{N_CTX=}")
@@ -97,7 +111,10 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
total_flops *= 0.5
if mode == 'bwd':
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
return total_flops / ms * 1e-9
if USE_TFLOPS:
return total_flops / ms * 1e-9
else:
return ms


# only works on post-Ampere GPUs right now
2 changes: 1 addition & 1 deletion test/test_backward.py
@@ -76,7 +76,7 @@ def _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale
is_allclose, adiff, grads_allclose, grads_adiff, tfts = ctx.validate_with_reference(tri_out, ctx.dout_tensors, return_target_fudge_factors=True)
ctx.display_validation_results(tri_out, is_allclose, adiff, grads_allclose, grads_adiff)

assert is_allclose, f'Forward pass {is_allclose=}'
assert is_allclose, f'Forward pass {is_allclose=} {tfts=}'
dq_allclose, dk_allclose, dv_allclose, db_allclose = grads_allclose
tri_dq, tri_dk, tri_dv, tri_db = ctx.dout_tensors
ref_dq, ref_dk, ref_dv, ref_db = ctx.dref_tensors
8 changes: 6 additions & 2 deletions test/tune_flash.py
@@ -67,7 +67,11 @@ def next_index(self, kig: KernelIndexProress) -> bool:
class FlashTunerSource(MonadService):
def gen(self):
a = self._args
yield from itertools.product(a.batch, a.n_heads, a.d_head, a.seqlen_q, a.seqlen_k, a.causal, a.sm_scale, a.dropout_p, a.return_encoded_softmax, a.dtype, a.bias_type)
for tup in itertools.product(a.batch, a.n_heads, a.d_head, a.seqlen_q, a.seqlen_k, a.causal, a.sm_scale, a.dropout_p, a.return_encoded_softmax, a.dtype, a.bias_type):
BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, return_encoded_softmax, dtype, bias_type = tup
if seqlen_q > 4096 and seqlen_k > 4096:
BATCH = 2
yield (BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, return_encoded_softmax, dtype, bias_type)

def init(self, _):
pass
@@ -227,7 +231,7 @@ def make_ui(manager : TunerManager):

def parse():
p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
p.add_argument('--batch', type=int, nargs=1, default=[1], help='(Not a functional) Batch size.')
p.add_argument('--batch', type=int, nargs=1, default=[8], help='(Not a functional) Batch size.')
p.add_argument('--n_heads', type=int, nargs=1, default=[12], help='(Not a functional) Number of heads')
p.add_argument('--sm_scale', type=float, nargs=1, default=[1.2], help='(Not a functional) Softmax Scale')
p.add_argument('--return_encoded_softmax', type=bool, default=[False],
24 changes: 13 additions & 11 deletions tritonsrc/_common_test.py
@@ -127,7 +127,9 @@ def __init__(self, BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype,

# Maximal value from tune_flash.py and table_tool.py --fudge_factor_tolerance 5.0
# Note: Navi 3x is experimental and YMMV
self.OUT_FUDGE_FACTOR = 6.0 if dtype != torch.float32 else 10.0
self.OUT_FUDGE_FACTOR = 3.0
if dtype == torch.float32:
self.OUT_FUDGE_FACTOR = 12.0

'''
Create Tensors that will be kept b/w forward and backward pass
@@ -245,18 +247,17 @@ def _compute_fudge_factors(self, p : SdpaParams):

# Maximal value from tune_flash.py and table_tool.py --fudge_factor_tolerance 5.0
# Note: Navi 3x is experimental and YMMV
query_fudge_factor = 180.0
key_fudge_factor = 16.0
value_fudge_factor = 32.0
query_fudge_factor = 32.0
key_fudge_factor = 48.0
value_fudge_factor = 16.0
bias_fudge_factor = 16.0
print(f'{torch.cuda.get_device_properties(0).gcnArchName=}')
# print(f'{torch.cuda.get_device_properties(0).gcnArchName=}')
if torch.version.hip:
if 'gfx90a' in torch.cuda.get_device_properties(0).gcnArchName:
key_fudge_factor = max(8.0, (seqlen_k + seqlen_q) / 16.0) # TODO: Check why
bias_fudge_factor = 32.0
query_fudge_factor = 80.0
if dtype == torch.float32:
key_fudge_factor = 180.0
bias_fudge_factor = 32.0
bias_fudge_factor = 24.0
return (query_fudge_factor, key_fudge_factor, value_fudge_factor, bias_fudge_factor)

@staticmethod
@@ -321,9 +322,10 @@ def lmax(x) -> float:
atol = default_atol[torch.float32]
threshold = max(atol, ref_error * fudge_factor)
valid = test_error <= threshold
tft = test_error / ref_error if ref_error > atol else 1.0
tft = test_error / ref_error if ref_error * fudge_factor > atol else 1.0
if not valid:
print(f'For {tname}, Consider bump fudge_factor to {tft} = {test_error=} / {ref_error=}. So that {test_error=} < max({atol=}, {ref_error=} * {tft=})')
pass
# print(f'For {tname}, Consider bump fudge_factor to {tft} = {test_error=} / {ref_error=}. So that {test_error=} < max({atol=}, {ref_error=} * {tft=})')
if return_target_fudge_factors:
return valid, max_adiff, tft
else:
@@ -351,7 +353,7 @@ def validate_with_reference(self, out, grads,
return out_allclose, out_adiff, [], []
grads_allclose = []
grads_adiff = []
print(f'using {self.fudge_factors=}')
# print(f'using {self.fudge_factors=}')
for grad, ref, lp_ref, fudge_factor, tname in zip(grads, self.dref_tensors, self.lp_dref_tensors, self.fudge_factors, self.TENSOR_NAMES):
allclose, adiff, tft = self._validate(grad,
ref,
1 change: 0 additions & 1 deletion tritonsrc/attn_torch_function.py
@@ -13,7 +13,6 @@
bwd_preprocess as bare_bwd_preprocess,
bwd_kernel_dk_dv as bare_bwd_kernel_dk_dv,
bwd_kernel_dq as bare_bwd_kernel_dq,
attn_bwd as bare_attn_bwd,
)
from tuned_bwd import (
tuned_bwd_kernel_dk_dv,
84 changes: 0 additions & 84 deletions tritonsrc/bwd_inner_dkdv.py

This file was deleted.

64 changes: 0 additions & 64 deletions tritonsrc/bwd_inner_dq.py

This file was deleted.
