# test_sdpa_cudnn.py
from collections import namedtuple
from functools import partial
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
import triton.profiler as proton
from hopper.flash_attn_interface import flash_attn_func as flash_attn_func_hopper
# Set random seed for reproducibility
torch.manual_seed(42)
torch.set_default_device("cuda")
SdpaShape = namedtuple("Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"])
Tolerances = namedtuple("Tolerances", ["atol", "rtol"])


def compiled_sdpa():
    return torch.compile(
        scaled_dot_product_attention,
        fullgraph=True,
        backend="inductor",
        mode="max-autotune",
    )


def warmup(iters: int, backend, f, *args, **kwargs) -> None:
    # Run f a few times under the requested SDPA backend so kernels are selected/cached before profiling.
    for _ in range(iters):
        with sdpa_kernel(backends=[backend]):
            _ = f(*args, **kwargs)
"""
Flux q, k, v shapes:
q = [1, 24, 4608, 128]
k = [1, 24, 4608, 128]
v = [1, 24, 4608, 128]
FA -> (batch, seqlen, nheads, headdim)
Torch sdpa expects -> (batch, nheads, seqlen, headdim)
ref:
torch.functional: https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/nn/functional.py#L5617
"""
device = torch.device("cuda")
dtype = torch.bfloat16
is_causal = False
make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False)
batch = 1
num_heads = 24
head_dim = 128
seq_len = 4608
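# Forward attention cost: two matmuls (Q @ K^T and P @ V), each 2 * batch * num_heads * seq_len^2 * head_dim
# FLOPs, so 4 * batch * num_heads * seq_len^2 * head_dim in total; causal masking roughly halves the work.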
flops = 4 * batch * seq_len**2 * num_heads * head_dim // (2 if is_causal else 1)
warmup_iter = 10
q_shape = SdpaShape(batch, num_heads, seq_len, head_dim)
k_shape = SdpaShape(batch, num_heads, seq_len, head_dim)
v_shape = SdpaShape(batch, num_heads, seq_len, head_dim)
query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
assert query.shape == q_shape
assert key.shape == k_shape
assert value.shape == v_shape

# warmup for sdpa_kernel
warmup(
    warmup_iter,
    SDPBackend.CUDNN_ATTENTION,
    torch.nn.functional.scaled_dot_product_attention,
    query,
    key,
    value,
    is_causal=is_causal,
)

proton.start()

# cuDNN attention
with proton.scope("torch_scaled_dot_product_cudnn_attention", metrics={"flops": flops}):
    with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
        attn_out_sdpa_cudnn = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=is_causal
        )

# torch aten explicit cuDNN op call
for _ in range(warmup_iter):
    _ = torch.ops.aten._scaled_dot_product_cudnn_attention(
        query,
        key,
        value,
        attn_bias=None,
        compute_log_sumexp=False,
        dropout_p=0.0,
        is_causal=is_causal,
    )
with proton.scope(
    "torch.ops.aten._scaled_dot_product_cudnn_attention", metrics={"flops": flops}
):
    attn_out_aten_sdpa_cudnn = torch.ops.aten._scaled_dot_product_cudnn_attention(
        query,
        key,
        value,
        attn_bias=None,
        compute_log_sumexp=False,
        dropout_p=0.0,
        is_causal=is_causal,
    )

# torch sdpa native
for _ in range(warmup_iter):
    _ = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=is_causal
    )
with proton.scope("torch_scaled_dot_product_attention", metrics={"flops": flops}):
    attn_out_sdpa = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=is_causal
    )

# torch compiled sdpa
# Compile once so the warmup iterations and the profiled call exercise the same compiled callable.
flash_attention_compiled_op = compiled_sdpa()
for _ in range(warmup_iter):
    _ = flash_attention_compiled_op(query, key, value, is_causal=is_causal)
with proton.scope("torch_scaled_dot_product_attention_compiled", metrics={"flops": flops}):
    attn_out_sdpa_compiled = flash_attention_compiled_op(query, key, value, is_causal=is_causal)

# FlashAttention-3 Hopper expects (batch, seqlen, nheads, headdim)
query = query.permute(0, 2, 1, 3)  # (B, H, S, D) -> (B, S, H, D)
key = key.permute(0, 2, 1, 3)  # (B, H, S, D) -> (B, S, H, D)
value = value.permute(0, 2, 1, 3)  # (B, H, S, D) -> (B, S, H, D)
for _ in range(warmup_iter):
    _, _ = flash_attn_func_hopper(query, key, value, causal=is_causal)
with proton.scope("flash_attention_hopper", metrics={"flops": flops}):
    attn_output_fa_hopper, _ = flash_attn_func_hopper(query, key, value, causal=is_causal)

proton.finalize()
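
# Optional sanity check (a minimal sketch, not part of the original benchmark): compare each backend's
# output against the native torch SDPA result using the Tolerances tuple defined above. Assumptions:
# the aten op's first return value is the attention output, and the FA3 Hopper output is laid out as
# (B, S, H, D), so it is permuted back before comparison.
tol = Tolerances(atol=1e-2, rtol=1e-2)
torch.testing.assert_close(attn_out_sdpa_cudnn, attn_out_sdpa, atol=tol.atol, rtol=tol.rtol)
torch.testing.assert_close(attn_out_aten_sdpa_cudnn[0], attn_out_sdpa, atol=tol.atol, rtol=tol.rtol)
torch.testing.assert_close(attn_out_sdpa_compiled, attn_out_sdpa, atol=tol.atol, rtol=tol.rtol)
torch.testing.assert_close(
    attn_output_fa_hopper.permute(0, 2, 1, 3),  # back to (B, H, S, D) to match the SDPA outputs
    attn_out_sdpa,
    atol=tol.atol,
    rtol=tol.rtol,
)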