minor changes to unpad_input test util func
ipiszy committed Sep 16, 2024
1 parent db80387 commit cdbbe84
Showing 7 changed files with 325 additions and 14 deletions.
7 changes: 2 additions & 5 deletions flash_attn/bert_padding.py
@@ -119,16 +119,13 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None):
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
     # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
     # so we write custom forward and backward to make it a bit faster.
-    res = (
+    return (
         index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
         indices,
         cu_seqlens,
         max_seqlen_in_batch,
+        used_seqlens_in_batch,
     )
-    if unused_mask is not None:
-        return res + (used_seqlens_in_batch, )
-    else:
-        return res


 def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
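With this change, unpad_input always returns a fixed five-element tuple, with used_seqlens_in_batch as the last element, instead of appending it only when unused_mask is passed. A minimal sketch of the updated calling convention (variable names on the left are illustrative); callers that do not need the per-sequence token counts simply discard the last element, as the call-site updates below do:

    output, indices, cu_seqlens, max_seqlen_in_batch, used_seqlens_in_batch = unpad_input(hidden_states, attention_mask)
    # callers that only need the first four values:
    output, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input(hidden_states, attention_mask)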
2 changes: 1 addition & 1 deletion flash_attn/flash_blocksparse_attention.py
@@ -99,7 +99,7 @@ def forward(
             key_padding_mask_bool = key_padding_mask.bool_matrix
             nheads = qkv.shape[-2]
             x = rearrange(qkv, "b s three h d -> b s (three h d)")
-            x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
+            x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool)
             x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
             output_unpad = flash_blocksparse_attn_func(
                 x_unpad,
2 changes: 1 addition & 1 deletion flash_attn/models/bert.py
@@ -172,7 +172,7 @@ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
314 changes: 314 additions & 0 deletions hopper/benchmark_attn.py.bak
@@ -0,0 +1,314 @@
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import time

try:
    import cudnn
except ImportError:
    cudnn = None


from einops import rearrange, repeat

# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3

# Need to install triton nightly:
# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly

try:
    from triton_fused_attention import attention as triton_attention
except ImportError:
    triton_attention = None

def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen_q * seqlen_k * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
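# Worked example of the formula above: for batch=2, nheads=16, seqlen_q=seqlen_k=8192,
# headdim=128, non-causal forward, f = 4 * 2 * 8192 * 8192 * 16 * 128 ~= 1.1e12 FLOPs,
# so a 1 ms forward pass corresponds to roughly 1100 TFLOPS of throughput.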


def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")


def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_kv, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_kv, seqlen_k, headdim)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu = q, k, v
    o_gpu, stats_gpu = o, stats
    graph_forward = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q_forward = graph_forward.tensor_like(q_gpu.detach())
    k_forward = graph_forward.tensor_like(k_gpu.detach())
    v_forward = graph_forward.tensor_like(v_gpu.detach())

    seqlens_reshaped = seqlens if varlen else None
    seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
    seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None

    o_forward, stats_forward = graph_forward.sdpa(
        name="sdpa",
        q=q_forward,
        k=k_forward,
        v=v_forward,
        is_inference=False,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
        use_padding_mask=varlen,
        seq_len_q=seq_len_q,
        seq_len_kv=seq_len_kv,
    )

    o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
    stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    graph_forward.validate()
    graph_forward.build_operation_graph()
    graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph_forward.check_support()
    graph_forward.build_plans()

    variant_pack_forward = {
        q_forward: q_gpu,
        k_forward: k_gpu,
        v_forward: v_gpu,
        o_forward: o_gpu,
        stats_forward: stats_gpu,
        seq_len_q: seqlens_reshaped,
        seq_len_kv: seqlens_reshaped,
    }

    dQ_gpu = torch.empty_like(q_gpu)
    dK_gpu = torch.empty_like(k_gpu)
    dV_gpu = torch.empty_like(v_gpu)
    dO_gpu = grad

    graph_backward = cudnn.pygraph(
        io_data_type=cudnn.data_type.HALF,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    q_backward = graph_backward.tensor_like(q_gpu.detach())
    k_backward = graph_backward.tensor_like(k_gpu.detach())
    v_backward = graph_backward.tensor_like(v_gpu.detach())
    o_backward = graph_backward.tensor_like(o_gpu.detach())
    dO_backward = graph_backward.tensor_like(dO_gpu.detach())
    stats_backward = graph_backward.tensor_like(stats_gpu.detach())
    seq_len_q = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
    seq_len_kv = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None

    dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
        name="sdpa_backward",
        q=q_backward,
        k=k_backward,
        v=v_backward,
        o=o_backward,
        dO=dO_backward,
        stats=stats_backward,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
        use_padding_mask=varlen,
        seq_len_q=seq_len_q,
        seq_len_kv=seq_len_kv,
    )

    dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
    dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
    dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())

    graph_backward.validate()
    graph_backward.build_operation_graph()
    graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph_backward.check_support()
    graph_backward.build_plans()

    variant_pack_backward = {
        q_backward: q_gpu,
        k_backward: k_gpu,
        v_backward: v_gpu,
        o_backward: o_gpu,
        dO_backward: dO_gpu,
        stats_backward: stats_gpu,
        dQ_backward: dQ_gpu,
        dK_backward: dK_gpu,
        dV_backward: dV_gpu,
        seq_len_q: seqlens_reshaped,
        seq_len_kv: seqlens_reshaped,
    }

    workspace = torch.empty(
        max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()),
        device="cuda", dtype=torch.uint8
    )

    def run_fwd(*args, **kwargs):
        graph_forward.execute(variant_pack_forward, workspace)
        return o_gpu, stats_gpu

    def run_bwd(*args, **kwargs):
        graph_backward.execute(variant_pack_backward, workspace)
        return dQ_gpu, dK_gpu, dV_gpu

    return run_fwd, run_bwd
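# Usage sketch for the two closures returned by cudnn_sdpa_setup (this mirrors how the
# benchmark loop below drives them):
#     cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(
#         q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2),
#         o_t, stats, causal=True,
#     )
#     out, softmax_stats = cudnn_sdpa_fwd()   # executes the forward graph, writing into o_t / stats
#     dq, dk, dv = cudnn_sdpa_bwd()           # executes the backward graph against the captured grad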


torch.manual_seed(0)
repeats = 100
dropout_p = 0.0
causal = False
dtype = torch.float16
device = 'cuda'
verbose = False
batch_size = 2
# seqlen = 2048
seqlen = 8192
# seqlen = 4096
# seqlen = 2047
dim = 2048
# headdim = 128
# headdim = 64
headdim = 256

for mode in ['fwd', 'bwd']:
# for mode in ['bwd']:
    for headdim in [64, 128, 256]:
    # for headdim in [128]:
        for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
        # for seqlen in [8192]:
            nheads = dim // headdim
            # nheads = 24
            # headdim = 64
            # batch_size = 64
            # seqlen = 512
            # nheads = 8
            # headdim = 128
            # nheads = 16
            # headdim = 128
            nheads_kv = nheads
            # nheads_kv = 1

            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)
            q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
            v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
            q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
            k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
            v_t = v.transpose(1, 2).contiguous().detach().requires_grad_()
            grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
            grad_t = grad.transpose(1, 2).contiguous()
            o_t = torch.empty_like(q.transpose(1, 2))
            stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)

            bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)

            for causal in [False, True]:
            # for causal in [True]:
                print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
                # For var-seq-len
                lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
                seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
                cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
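                # cu_seqlens is the cumulative-length vector the varlen kernels expect: a leading 0
                # followed by running totals, e.g. lens = [3, 5] gives [0, 3, 8]. Here every sequence
                # has length seqlen, so it is simply [0, seqlen, 2 * seqlen, ...].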
                if headdim <= 128 and cudnn is not None:
                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
                f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
                ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
                _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
                if mode == 'bwd':
                    ref_dv, v.grad = v.grad.clone(), None
                    ref_dk, k.grad = k.grad.clone(), None
                    ref_dq, q.grad = q.grad.clone(), None
                # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
                if headdim <= 128:
                    if triton_attention is not None and nheads_kv == nheads:
                        if mode == 'fwd':
                            time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
                            _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
                        # TODO: fix Triton numeric errors.
                        # if mode == 'bwd':
                        # dv, v_t.grad = v_t.grad.clone(), None
                        # dk, k_t.grad = k_t.grad.clone(), None
                        # dq, q_t.grad = q_t.grad.clone(), None
                        # torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                        # torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                        # torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                    if cudnn is not None:
                        time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
                        if mode == 'fwd':
                            _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                            _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
                            cudnn_sdpa_fwd()
                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                            cudnn_sdpa_fwd_varlen()
                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                        else:
                            cudnn_sdpa_fwd()
                            _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                            _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
                            dq, dk, dv = cudnn_sdpa_bwd()
                            torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                            torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                            torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                            dq, dk, dv = cudnn_sdpa_bwd_varlen()
                            torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                            torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                            torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                        # pytorch_profiler(cudnn_sdpa, backward=False)

                if headdim <= 128 or mode == 'fwd':
                    time.sleep(1)
                    _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
                    q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
                    k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
                    v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
                    time.sleep(1)
                    if mode == 'bwd':
                        dv, v.grad = v.grad.clone(), None
                        dk, k.grad = k.grad.clone(), None
                        dq, q.grad = q.grad.clone(), None
                        torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)

                    bench_var_fn = bench_fn
                    if mode == 'bwd':
                        grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
                        bench_var_fn = partial(benchmark_backward, grad=grad_var)
                    _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')

                # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
                print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
                if headdim <= 128:
                    if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
                        print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
                    if cudnn is not None:
                        print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
                        print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
                if headdim <= 128 or mode == 'fwd':
                    print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
                    print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')

6 changes: 3 additions & 3 deletions tests/test_flash_attn.py
@@ -89,7 +89,7 @@ def generate_qkv(
     assert v.shape == (batch_size, seqlen_k, nheads_k, d)

     if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)
         output_pad_fn = lambda output_unpad: pad_input(
             output_unpad, indices_q, batch_size, seqlen_q
         )
@@ -104,8 +104,8 @@ def generate_qkv(
         )

     if key_padding_mask is not None:
-        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
+        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")
2 changes: 1 addition & 1 deletion tests/test_rotary.py
@@ -231,7 +231,7 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of
     x_pt = x.detach().clone().requires_grad_()
     lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
     padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
-    x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
+    x_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(x, padding_mask)
     x_unpad_clone = x_unpad.clone()
     x_unpad = x_unpad.requires_grad_()
     cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
