[feat] cutlass FlashAttention bias+dropout support (#587)
* [feat] cutlass FlashAttention bias+dropout support

adds attention bias (including the bias gradient) and dropout support to the CUTLASS
FlashAttention implementation

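The rows in the two tables below appear to be parametrized as (batch, seq_len, num_heads, head_dim, dtype, attn-bias shape, bias requires_grad, dropout p); the columns compare the PyTorch reference against the CUTLASS kernel.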
[-------------------------------------------- attn --------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     12.7    |     7.5
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     15.5    |     9.1
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     12.7    |     7.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     15.6    |     9.1
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     10.1    |     6.0
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     12.7    |     7.5
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     44.3    |    29.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     55.0    |    35.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     45.1    |    29.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     55.6    |    35.3
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     37.0    |    22.6
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     46.8    |    29.0

Times are in milliseconds (ms).

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     19.3    |    24.1
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     19.4    |    24.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     22.3    |    28.7
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     22.4    |    29.0
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     19.5    |    22.7
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     19.5    |    23.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     62.7    |    91.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     63.4    |    93.7
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     74.8    |   109.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     75.1    |   111.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     63.2    |    85.5
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     64.0    |    90.1

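For orientation, here is a minimal usage sketch (not part of this commit) of the new bias + dropout path, mirroring the calls in tests/test_mem_eff_attention.py further down; the shapes and the explicit CUTLASS op selection are illustrative assumptions.

```python
# Sketch (assumed shapes/usage): memory-efficient attention with a learnable
# additive bias and dropout, explicitly selecting the CUTLASS fwd/bwd ops.
import torch
import xformers.ops as xops
from xformers.ops import fmha

B, M, H, K = 8, 512, 64, 128  # batch, seq len, heads, head dim (assumed layout)
q = torch.randn(B * H, M, K, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(B * H, M, K, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(B * H, M, K, device="cuda", dtype=torch.float16, requires_grad=True)

# Additive attention bias; with this change its gradient is also computed.
bias = torch.randn(B * H, M, M, device="cuda", dtype=torch.float16, requires_grad=True)

out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=bias,
    p=0.5,  # dropout probability
    op=(fmha.cutlass.FwOp, fmha.cutlass.BwOp),  # force the CUTLASS kernels
)
out.sum().backward()
assert bias.grad is not None  # bias gradient produced by the CUTLASS backward
```

Consistent with the tables above, making the bias trainable mostly affects the backward pass, while dropout's overhead shows up mainly in the forward pass here.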
* benchmark fixes

* add more conditions under which the dOi @ Vj product is reduced to 2 stages (before/after benchmarks below; see the note after the AFTER table for which backward-pass matmul this is)

BEFORE

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0)      |      2.8    |     2.4
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5)      |      2.8    |     3.3
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0)       |      3.4    |     3.2
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5)       |      3.4    |     4.2
      (8, 512, 64, 64, torch.float16, None, False, 0.0)                 |      2.8    |     2.0
      (8, 512, 64, 64, torch.float16, None, False, 0.5)                 |      2.8    |     2.9
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |      3.6    |     3.9
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |      3.6    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |      4.2    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |      4.2    |     5.6
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |      3.6    |     3.4
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |      3.6    |     4.4
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0)   |      9.7    |     8.8
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5)   |      9.7    |    12.6
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0)    |     12.0    |    12.1
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5)    |     12.1    |    16.1
      (8, 1024, 64, 64, torch.float16, None, False, 0.0)                |      9.7    |     7.4
      (8, 1024, 64, 64, torch.float16, None, False, 0.5)                |      9.7    |    10.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     11.3    |    14.0
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     11.3    |    17.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     13.6    |    17.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     13.6    |    20.9
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     11.3    |    12.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     11.3    |    15.8

AFTER

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0)      |      2.8    |     2.4
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5)      |      2.8    |     3.0
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0)       |      3.4    |     3.2
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5)       |      3.4    |     3.8
      (8, 512, 64, 64, torch.float16, None, False, 0.0)                 |      2.8    |     2.0
      (8, 512, 64, 64, torch.float16, None, False, 0.5)                 |      2.8    |     2.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |      3.6    |     3.9
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |      3.6    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |      4.2    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |      4.2    |     5.6
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |      3.6    |     3.4
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |      3.6    |     4.4
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0)   |      9.7    |     8.8
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5)   |      9.7    |    11.4
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0)    |     12.0    |    12.1
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5)    |     12.1    |    14.6
      (8, 1024, 64, 64, torch.float16, None, False, 0.0)                |      9.7    |     7.4
      (8, 1024, 64, 64, torch.float16, None, False, 0.5)                |      9.7    |     9.6
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     11.3    |    14.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     11.3    |    17.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     13.6    |    17.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     13.6    |    20.9
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     11.3    |    12.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     11.3    |    15.8

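For context on the dOi @ Vj term above: in the standard FlashAttention backward recurrence, dO_i V_j^T is the matmul producing the gradient of the attention probabilities, and "stages" presumably refers to the shared-memory pipeline depth CUTLASS uses for that GEMM. A sketch of the usual tile-level backward formulas (softmax scaling and dropout masking omitted), for orientation only:

```latex
% Standard FlashAttention backward tile recurrence (orientation only; this
% commit only tunes when the dO_i V_j^T GEMM runs with a 2-stage pipeline).
\begin{aligned}
dP_{ij} &= dO_i\, V_j^{\top}, \\
D_i     &= \operatorname{rowsum}(dO_i \circ O_i), \\
dS_{ij} &= P_{ij} \circ (dP_{ij} - D_i), \\
dV_j \mathrel{{+}{=}} P_{ij}^{\top} dO_i, \qquad
dK_j &\mathrel{{+}{=}} dS_{ij}^{\top} Q_i, \qquad
dQ_i \mathrel{{+}{=}} dS_{ij} K_j .
\end{aligned}
```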
* fix mypy error

* fix windows build

* rename the cutlass rand-uniform file

* black reformat
jfc4050 authored Jan 18, 2023
1 parent 6f3c20f commit 814314d
Showing 54 changed files with 1,768 additions and 255 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -11,6 +11,9 @@
build/
dist/

# for autocomplete
compile_commands.json

# Pytest verbose output
test-results/

8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -42,6 +42,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- fMHA: Added CUTLASS-based kernel for `xformers.ops.memory_efficient_attention`. This kernel is used automatically depending on the inputs, and works on any GPU after P100 [facebookresearch/xformers#362]

## [0.0.15] - 2022-12-13
### Fixed

### Added
- Added tensor attn bias support to CUTLASS FlashAttention
- Added tensor attn bias grad support to CUTLASS FlashAttention
- Added dropout support to CUTLASS FlashAttention

## [0.0.12] - 2022-08-08
### Fixed
- Removed duplicated biases in the FusedMLP layers [facebookresearch/xformers#317]
90 changes: 70 additions & 20 deletions tests/test_mem_eff_attention.py
@@ -607,7 +607,7 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):
)

_out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
query, key, value
query, key, value, op=op
)
ref_lse = ((query.float() / k**0.5) @ key.float().transpose(-2, -1)).logsumexp(-1)

@@ -616,7 +616,13 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):

@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
@pytest.mark.parametrize(
"attn_bias_type", [None, xformers.ops.LowerTriangularMask, torch.Tensor]
"attn_bias_cfg", # (type(bias), bias.requires_grad)
[
(None, False),
(xformers.ops.LowerTriangularMask, False),
(torch.Tensor, True),
(torch.Tensor, False),
],
)
@pytest.mark.parametrize("grad_out_contiguous", [False, True])
@pytest.mark.parametrize(
@@ -627,9 +633,10 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):
def test_backward(
op_device_dtype_B_Mq_Mkv_H_K_Kv,
grad_out_contiguous,
attn_bias_type,
attn_bias_cfg,
fmt,
):
attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
(
op_bw,
device,
@@ -646,9 +653,13 @@ def test_backward(
attn_bias_type=attn_bias_type,
fmt=fmt,
)
op_fw = sample_random_supported_fw(
fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias),
seed=q_len * kv + kv_len * k,
op_fw = (
sample_random_supported_fw(
fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias),
seed=q_len * kv + kv_len * k,
)
if op_bw != fmha.cutlass.BwOp
else fmha.cutlass.FwOp
)
qkv = None

@@ -666,6 +677,11 @@
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
if isinstance(attn_bias, torch.Tensor):
attn_bias.requires_grad_(attn_bias_requires_grad)

if not op_bw.supports(fmha.Inputs(query, key, value, attn_bias)):
pytest.skip("inputs not supported")

out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, op=(op_fw, op_bw)
@@ -692,6 +708,9 @@
else:
grads = [qkv.grad]
qkv.grad = None
if attn_bias_requires_grad:
grads.append(attn_bias.grad)
attn_bias.grad = None

ref = ref_attention(query, key, value, attn_bias)
ref.backward(grad_out)
@@ -713,6 +732,12 @@
assert isinstance(qkv.grad, torch.Tensor)
grads_ref = [qkv.grad]
grads_name = ["qkv"]

if attn_bias_requires_grad:
assert isinstance(attn_bias.grad, torch.Tensor)
grads_ref.append(attn_bias.grad)
grads_name.append("bias")

del query
del key
del value
@@ -755,49 +780,64 @@ def _vec_binom_test(x, n, p):
return pval


def _get_drop_mask(op, batch_size, q_len, kv_len, p, device):
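    # Rebuild the dropout keep-mask on the host so the test can compare the kernel
    # output against a reference: for the CUTLASS forward op, draw the uniform noise
    # via torch.ops.xformers._cutlass_rand_uniform (matching the kernel's RNG stream)
    # and keep entries where noise > p; other ops use the _temp_dropout test helper.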
if op == fmha.cutlass.FwOp:
mask = torch.empty((batch_size, 1, q_len, kv_len), device=device)
rand_uniform = torch.ops.xformers._cutlass_rand_uniform(p, mask)
mask = (rand_uniform > p).to(torch.float32)
mask = mask.reshape(batch_size, q_len, kv_len)
else:
mask = torch.empty((batch_size, q_len, kv_len), device=device)
mask = torch.ops.xformers._temp_dropout(mask, p)

return mask


@cuda_only
@pytest.mark.parametrize("seed", [42, 124])
@pytest.mark.parametrize("p", [0.3, 0.7])
@pytest.mark.parametrize("k_len", [32])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 33])
@pytest.mark.parametrize("device", ["cuda"])
def test_dropout(device, q_len, kv_len, batch_size, k_len, p, seed):
@pytest.mark.parametrize("op", ALL_FW_OPS, ids=list(map(lambda t: t.NAME, ALL_FW_OPS)))
def test_dropout(op, q_len, kv_len, batch_size, k_len, p, seed):
device = "cuda"
scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

attn_bias = None
op = (fmha.small_k.FwOp, None)

inputs_for_support_check = fmha.Inputs(query, key, value, attn_bias, p, None)
if not op.supports(inputs_for_support_check):
del query, key, value, attn_bias
pytest.skip(f"{op.NAME}: unsupported input")

torch.manual_seed(seed)
out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, p, op=op
query, key, value, attn_bias, p, op=(op, None)
)

torch.manual_seed(seed)
out2 = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, p, op=op
query, key, value, attn_bias, p, op=(op, None)
)

assert_allclose(out, out2)

mask = torch.empty((batch_size, q_len, kv_len), device=device)

torch.manual_seed(seed)
mask = torch.ops.xformers._temp_dropout(mask, p)

mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)
ref = ref_attention(query, key, value, attn_bias, mask, p)
assert_allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}"

num_trials = 1000
p_val_tol = 0.0001
p_val_tol = 1e-6
keep_prob = 1 - p
masks = []
for i in range(num_trials):
mask = torch.ops.xformers._temp_dropout(mask, p)
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)
masks.append(mask.clone().cpu())
masks = torch.stack(masks, dim=0)
p_value = binom_test(masks.sum(), masks.numel(), p=keep_prob)
@@ -840,10 +880,8 @@ def _test_dropout_backward(q_len, kv_len, batch_size, k_len, p, op, dtype):
key.grad = None
value.grad = None

mask = torch.empty((batch_size, q_len, kv_len), device=device)

torch.manual_seed(seed)
mask = torch.ops.xformers._temp_dropout(mask, p)
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)

ref = ref_attention(query, key, value, None, mask, p)
ref.backward(grad_out)
@@ -881,6 +919,18 @@ def test_dropout_backward_flash(q_len, kv_len, batch_size, k_len, p):
)


@cuda_only
@pytest.mark.parametrize("p", [0.3, 0.7])
@pytest.mark.parametrize("k_len", [16, 32])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 33])
def test_dropout_backward_cutlass(q_len, kv_len, batch_size, k_len, p):
_test_dropout_backward(
q_len, kv_len, batch_size, k_len, p, op=fmha.cutlass.FwOp, dtype=torch.float16
)


@pytest.mark.parametrize("k_len", [32])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("kv_len", [3 * 32])