An autotuner for the flash version of attention and RetNet

This project provides an autotuner for the flash versions of attention and RetNet. The tuned kernels can be called as ordinary PyTorch functions.
```python
# Attention on A100
import torch

from arch import A100
from ops.attention_interface import flash_attn_func

device_type = A100()
dtype = torch.float16
device = torch.device("cuda")

# example problem sizes (illustrative values)
batch, heads, seqlen_q, seqlen_kv, dim_qk, dim_v = 4, 8, 2048, 2048, 128, 128

q = torch.randn(batch, heads, seqlen_q, dim_qk, device=device, dtype=dtype)
k = torch.randn(batch, heads, seqlen_kv, dim_qk, device=device, dtype=dtype)
v = torch.randn(batch, heads, seqlen_kv, dim_v, device=device, dtype=dtype)
o = flash_attn_func(q, k, v, device_type)
```
```python
# RetNet on RTX4090
import torch

from arch import RTX4090  # assumed to live in the same arch module as A100 above
from ops.retnet_interface import RetNetAttnFunc

device_type = RTX4090()
dtype = torch.float16
device = torch.device("cuda")

# example problem sizes (illustrative values)
batch, heads, seqlen_q, seqlen_kv, dim_qk, dim_v = 4, 8, 2048, 2048, 128, 128

# requires_grad=True so the backward call below can propagate gradients
q = torch.randn(batch, heads, seqlen_q, dim_qk, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch, heads, seqlen_kv, dim_qk, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch, heads, seqlen_kv, dim_v, device=device, dtype=dtype, requires_grad=True)
mask = torch.randn(heads, seqlen_q, seqlen_kv, device=device, dtype=dtype)
o = RetNetAttnFunc(q, k, v, mask, device_type)

do = torch.randn(batch, heads, seqlen_q, dim_v, device=device, dtype=dtype)
o.backward(do)
```
Requirements:
- CUDA 12.3
- CMake 3.24

Installation:
- Clone this repo and its submodule cutlass:

```bash
git clone --recursive https://github.com/smallscientist1/attention_autotuner.git
```

- Add the Python package to PYTHONPATH:

```bash
export PYTHONPATH=$PYTHONPATH:/path/to/attention_autotuner/python
```

- Build the C++ benchmarks for an NVIDIA Ampere GPU (e.g. A100):

```bash
cd benchmarks
mkdir build
cd build
cmake -DPROJECT_CUDA_ARCH="80" ..
```
Flash attention forward pass, per key/value tile (online softmax); a reference sketch follows the list:
- q @ k
- reduce_max(qk)
- scale = exp(m_old - m_new)
- lse * scale
- acco * scale
- accs = exp(accs - m_new)
- lse += reduce_sum(accs)
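For reference, here is a minimal PyTorch sketch of this online-softmax recurrence over key/value tiles. It is plain tensor code, not the tuned CUDA kernel; the function name flash_attn_reference, the tile parameter, and the use of float32 accumulators are illustrative assumptions, and the usual 1/sqrt(dim_qk) softmax scaling is omitted to mirror the steps above.

```python
import torch

def flash_attn_reference(q, k, v, tile=128):
    """Online-softmax attention over key/value tiles, mirroring the steps above."""
    batch, heads, seqlen_q, _ = q.shape
    dim_v = v.shape[-1]
    acco = torch.zeros(batch, heads, seqlen_q, dim_v, dtype=torch.float32, device=q.device)
    lse = torch.zeros(batch, heads, seqlen_q, dtype=torch.float32, device=q.device)  # running row sum
    m = torch.full((batch, heads, seqlen_q), float("-inf"), dtype=torch.float32, device=q.device)  # running row max
    for start in range(0, k.shape[2], tile):
        k_t = k[:, :, start:start + tile].float()
        v_t = v[:, :, start:start + tile].float()
        accs = torch.einsum("bhqd,bhkd->bhqk", q.float(), k_t)    # q @ k
        m_new = torch.maximum(m, accs.amax(dim=-1))               # reduce_max(qk)
        scale = torch.exp(m - m_new)                              # scale = exp(m_old - m_new)
        lse = lse * scale                                         # lse * scale
        acco = acco * scale.unsqueeze(-1)                         # acco * scale
        accs = torch.exp(accs - m_new.unsqueeze(-1))              # accs = exp(accs - m_new)
        lse = lse + accs.sum(dim=-1)                              # lse += reduce_sum(accs)
        acco = acco + torch.einsum("bhqk,bhkd->bhqd", accs, v_t)  # accumulate accs @ v
        m = m_new
    return (acco / lse.unsqueeze(-1)).to(q.dtype)
```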
RetNet forward pass, per key/value tile (running abs-sum normalization); a reference sketch follows the list:
- q @ k
- qk * mask
- reduce_abs(qk)
- clamp(r)
- scale = r_old / r_new
- acco * scale
- accs / r_new
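Analogously, a minimal PyTorch sketch of the tiled retention recurrence. Again this is illustrative only, not the tuned kernel; retnet_attn_reference, tile, and the running variables r / r_wo_clamp are assumed names.

```python
import torch

def retnet_attn_reference(q, k, v, mask, tile=128):
    """Tiled RetNet retention with running abs-sum normalization, mirroring the steps above."""
    batch, heads, seqlen_q, _ = q.shape
    dim_v = v.shape[-1]
    acco = torch.zeros(batch, heads, seqlen_q, dim_v, dtype=torch.float32, device=q.device)
    r_wo_clamp = torch.zeros(batch, heads, seqlen_q, dtype=torch.float32, device=q.device)  # running abs-sum
    r = torch.ones_like(r_wo_clamp)                                                          # clamped normalizer
    for start in range(0, k.shape[2], tile):
        k_t = k[:, :, start:start + tile].float()
        v_t = v[:, :, start:start + tile].float()
        mask_t = mask[:, :, start:start + tile].float()            # decay mask slice (heads, seqlen_q, tile)
        accs = torch.einsum("bhqd,bhkd->bhqk", q.float(), k_t)     # q @ k
        accs = accs * mask_t                                       # qk * mask
        r_wo_clamp = r_wo_clamp + accs.abs().sum(dim=-1)           # reduce_abs(qk)
        r_new = r_wo_clamp.clamp(min=1.0)                          # clamp(r)
        scale = r / r_new                                          # scale = r_old / r_new
        acco = acco * scale.unsqueeze(-1)                          # acco * scale
        accs = accs / r_new.unsqueeze(-1)                          # accs / r_new
        acco = acco + torch.einsum("bhqk,bhkd->bhqd", accs, v_t)   # accumulate accs @ v
        r = r_new
    return acco.to(q.dtype)
```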
TODO:
- chunkwise retnet
- cost model
- autotuner (more general policy for retnet)
- elementwise op
- attention backward
- retnet performance issue (added load q once, mask stage 2?)
- causal config
- retnet parallel scan version for seqlen_q != seqlen_kv
- retnet parallel scan template
- retnet bwd load_q_once, causal
- performance of the python interface?
- retnet backward num_stage_qk=2 bug