diff --git a/.gitignore b/.gitignore index d2674f638b..38b453363b 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ build/ dist/ +# for autocomplete +compile_commands.json + # Pytest verbose output test-results/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 14ba6cd9cf..975eca67ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - fMHA: Added CUTLASS-based kernel for `xformers.ops.memory_efficient_attention`. This kernel is automatically depending on the inputs, and works on any GPU after P100 [facebookresearch/xformers#362] +## [0.0.15] - 2022-12-13 +### Fixed + +### Added +- Added tensor attn bias support to CUTLASS FlashAttention +- Added tensor attn bias grad support to CUTLASS FlashAttention +- Added dropout support to CUTLASS FlashAttention + ## [0.0.12] - 2022-08-08 ### Fixed - Removed duplicated biases in the FusedMLP layers [facebookresearch/xformers#317] diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 0fb869bf62..1ebcbd9fc1 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -607,7 +607,7 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv): ) _out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad( - query, key, value + query, key, value, op=op ) ref_lse = ((query.float() / k**0.5) @ key.float().transpose(-2, -1)).logsumexp(-1) @@ -616,7 +616,13 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv): @pytest.mark.parametrize("fmt", ["BMK", "BMHK"]) @pytest.mark.parametrize( - "attn_bias_type", [None, xformers.ops.LowerTriangularMask, torch.Tensor] + "attn_bias_cfg", # (type(bias), bias.requires_grad) + [ + (None, False), + (xformers.ops.LowerTriangularMask, False), + (torch.Tensor, True), + (torch.Tensor, False), + ], ) @pytest.mark.parametrize("grad_out_contiguous", [False, True]) @pytest.mark.parametrize( @@ -627,9 +633,10 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv): def test_backward( op_device_dtype_B_Mq_Mkv_H_K_Kv, grad_out_contiguous, - attn_bias_type, + attn_bias_cfg, fmt, ): + attn_bias_type, attn_bias_requires_grad = attn_bias_cfg ( op_bw, device, @@ -646,9 +653,13 @@ def test_backward( attn_bias_type=attn_bias_type, fmt=fmt, ) - op_fw = sample_random_supported_fw( - fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias), - seed=q_len * kv + kv_len * k, + op_fw = ( + sample_random_supported_fw( + fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias), + seed=q_len * kv + kv_len * k, + ) + if op_bw != fmha.cutlass.BwOp + else fmha.cutlass.FwOp ) qkv = None @@ -666,6 +677,11 @@ def test_backward( query.requires_grad_(True) key.requires_grad_(True) value.requires_grad_(True) + if isinstance(attn_bias, torch.Tensor): + attn_bias.requires_grad_(attn_bias_requires_grad) + + if not op_bw.supports(fmha.Inputs(query, key, value, attn_bias)): + pytest.skip("inputs not supported") out = xformers.ops.memory_efficient_attention( query, key, value, attn_bias, op=(op_fw, op_bw) @@ -692,6 +708,9 @@ def test_backward( else: grads = [qkv.grad] qkv.grad = None + if attn_bias_requires_grad: + grads.append(attn_bias.grad) + attn_bias.grad = None ref = ref_attention(query, key, value, attn_bias) ref.backward(grad_out) @@ -713,6 +732,12 @@ def test_backward( assert isinstance(qkv.grad, torch.Tensor) grads_ref = [qkv.grad] grads_name = ["qkv"] + + if attn_bias_requires_grad: + assert isinstance(attn_bias.grad, torch.Tensor) + 
grads_ref.append(attn_bias.grad) + grads_name.append("bias") + del query del key del value @@ -755,6 +780,19 @@ def _vec_binom_test(x, n, p): return pval +def _get_drop_mask(op, batch_size, q_len, kv_len, p, device): + if op == fmha.cutlass.FwOp: + mask = torch.empty((batch_size, 1, q_len, kv_len), device=device) + rand_uniform = torch.ops.xformers._cutlass_rand_uniform(p, mask) + mask = (rand_uniform > p).to(torch.float32) + mask = mask.reshape(batch_size, q_len, kv_len) + else: + mask = torch.empty((batch_size, q_len, kv_len), device=device) + mask = torch.ops.xformers._temp_dropout(mask, p) + + return mask + + @cuda_only @pytest.mark.parametrize("seed", [42, 124]) @pytest.mark.parametrize("p", [0.3, 0.7]) @@ -762,42 +800,44 @@ def _vec_binom_test(x, n, p): @pytest.mark.parametrize("batch_size", [1, 2]) @pytest.mark.parametrize("kv_len", [3, 15, 32, 33]) @pytest.mark.parametrize("q_len", [2, 33]) -@pytest.mark.parametrize("device", ["cuda"]) -def test_dropout(device, q_len, kv_len, batch_size, k_len, p, seed): +@pytest.mark.parametrize("op", ALL_FW_OPS, ids=list(map(lambda t: t.NAME, ALL_FW_OPS))) +def test_dropout(op, q_len, kv_len, batch_size, k_len, p, seed): + device = "cuda" scale = 3 query = torch.randn((batch_size, q_len, k_len), device=device) * scale key = torch.randn((batch_size, kv_len, k_len), device=device) * scale value = torch.randn((batch_size, kv_len, k_len), device=device) * scale attn_bias = None - op = (fmha.small_k.FwOp, None) + + inputs_for_support_check = fmha.Inputs(query, key, value, attn_bias, p, None) + if not op.supports(inputs_for_support_check): + del query, key, value, attn_bias + pytest.skip(f"{op.NAME}: unsupported input") torch.manual_seed(seed) out = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias, p, op=op + query, key, value, attn_bias, p, op=(op, None) ) torch.manual_seed(seed) out2 = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias, p, op=op + query, key, value, attn_bias, p, op=(op, None) ) assert_allclose(out, out2) - mask = torch.empty((batch_size, q_len, kv_len), device=device) - torch.manual_seed(seed) - mask = torch.ops.xformers._temp_dropout(mask, p) - + mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device) ref = ref_attention(query, key, value, attn_bias, mask, p) assert_allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}" num_trials = 1000 - p_val_tol = 0.0001 + p_val_tol = 1e-6 keep_prob = 1 - p masks = [] for i in range(num_trials): - mask = torch.ops.xformers._temp_dropout(mask, p) + mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device) masks.append(mask.clone().cpu()) masks = torch.stack(masks, dim=0) p_value = binom_test(masks.sum(), masks.numel(), p=keep_prob) @@ -840,10 +880,8 @@ def _test_dropout_backward(q_len, kv_len, batch_size, k_len, p, op, dtype): key.grad = None value.grad = None - mask = torch.empty((batch_size, q_len, kv_len), device=device) - torch.manual_seed(seed) - mask = torch.ops.xformers._temp_dropout(mask, p) + mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device) ref = ref_attention(query, key, value, None, mask, p) ref.backward(grad_out) @@ -881,6 +919,18 @@ def test_dropout_backward_flash(q_len, kv_len, batch_size, k_len, p): ) +@cuda_only +@pytest.mark.parametrize("p", [0.3, 0.7]) +@pytest.mark.parametrize("k_len", [16, 32]) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("kv_len", [3, 15, 32, 33]) +@pytest.mark.parametrize("q_len", [2, 33]) +def test_dropout_backward_cutlass(q_len, kv_len, batch_size, 
k_len, p): + _test_dropout_backward( + q_len, kv_len, batch_size, k_len, p, op=fmha.cutlass.FwOp, dtype=torch.float16 + ) + + @pytest.mark.parametrize("k_len", [32]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("kv_len", [3 * 32]) diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py index ed27e25a97..608578efc4 100644 --- a/xformers/benchmarks/benchmark_mem_eff_attention.py +++ b/xformers/benchmarks/benchmark_mem_eff_attention.py @@ -7,7 +7,7 @@ import itertools import math from functools import partial -from typing import cast +from typing import Any, cast import torch from torch.utils import benchmark @@ -21,7 +21,14 @@ def create_attn_bias( - bias_type, batch_size: int, num_heads: int, q_len: int, kv_len: int, device, dtype + bias_type, + batch_size: int, + num_heads: int, + q_len: int, + kv_len: int, + device, + dtype, + bias_requires_grad: bool = False, ): NoneType = type(None) if bias_type is NoneType: @@ -110,7 +117,7 @@ def T(t): ] -p = 0.0 +seed = 0 FORCE_OP = None # FORCE_OP = xformers.ops.MemoryEfficientAttentionOp # FORCE_OP = xformers.ops.MemoryEfficientAttentionCutlassOp @@ -131,7 +138,13 @@ def product_dict(**kwargs): product_dict( shape=SHAPES, num_threads=NUM_THREADS, - attn_bias_type=[type(None), torch.Tensor, xformers.ops.LowerTriangularMask], + dropout_p=[0.0, 0.3], + attn_bias_cfg=[ + (type(None), False), + (torch.Tensor, False), + (torch.Tensor, True), + (xformers.ops.LowerTriangularMask, False), + ], dtype=[torch.half, torch.bfloat16, torch.float], ) ) @@ -146,19 +159,11 @@ def create_tensors(shape, dtype, requires_grad=False): return qkv, q, k, v -def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype): +def benchmark_forward(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype): B, M, H, K = shape _, q, k, v = create_tensors(shape, dtype) - - inp = fmha.Inputs(query=q, key=k, value=v) - try: - op = (fmha._dispatch_fw(inp), None) if FORCE_OP is None else FORCE_OP - except NotImplementedError: - return - if not op[0].supports(inp): - return - - inp.attn_bias = create_attn_bias( + attn_bias_type, attn_bias_requires_grad = attn_bias_cfg + bias = create_attn_bias( attn_bias_type, batch_size=B, num_heads=H, @@ -166,7 +171,15 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype): kv_len=M, device=device, dtype=dtype, + bias_requires_grad=attn_bias_requires_grad, ) + inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p) + + try: + op = (fmha._dispatch_fw(inp), None) if FORCE_OP is None else FORCE_OP + except NotImplementedError: + return + if not op[0].supports(inp): return @@ -175,7 +188,10 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype): torch.half: "f16", torch.float: "f32", }[dtype] - sub_label = f"{dtype_str} B={B}, M={M}, H={H}, K={K}" + sub_label = ( + f"{dtype_str} B={B}, M={M}, H={H}, K={K}, p={dropout_p}, " + f" BiasT={attn_bias_type.__name__}, BiasGrad={attn_bias_requires_grad}" + ) try: r = xformers.ops.memory_efficient_attention( @@ -186,7 +202,12 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype): k.float(), v.float(), inp.attn_bias, + inp.p, ) + + assert not ( + inp.p > 0 and CHECK_CORRECTNESS + ), "correctness checking not yet implemented for dropout" assert not CHECK_CORRECTNESS or (r - rr).abs().max() < 4e-3, ( (r - rr).abs().max() ) @@ -201,7 +222,7 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype): "k": k, "v": v, "attn_bias": inp.attn_bias, - 
"p": p, + "p": dropout_p, "fn": partial(xformers.ops.memory_efficient_attention, op=op), }, label=f"attention (attn_bias={attn_bias_type})", @@ -216,7 +237,7 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype): "k": k, "v": v, "attn_bias": inp.attn_bias, - "p": p, + "p": dropout_p, "fn": ref_attention, }, label=f"attention (attn_bias={attn_bias_type})", @@ -226,23 +247,12 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype): ) -def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype): +def benchmark_backward(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype): B, M, H, K = shape qkv, q, k, v = create_tensors(shape, dtype, requires_grad=True) - inp = fmha.Inputs(query=q, key=k, value=v) - try: - op = ( - (fmha._dispatch_fw(inp), fmha._dispatch_bw(inp)) - if FORCE_OP is None - else FORCE_OP - ) - except NotImplementedError: - return - if not op[0].supports(inp) or not op[1].supports(inp): - return - - inp.attn_bias = create_attn_bias( + attn_bias_type, attn_bias_requires_grad = attn_bias_cfg + bias = create_attn_bias( attn_bias_type, batch_size=B, num_heads=H, @@ -250,8 +260,25 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype): kv_len=M, device=device, dtype=dtype, + bias_requires_grad=attn_bias_requires_grad, ) - if not op[0].supports(inp) or not op[1].supports(inp): + inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p) + try: + if FORCE_OP: + op = FORCE_OP + else: + op_fw: Any = None + op_bw = fmha._dispatch_bw(inp) + if op_bw == fmha.flash.BwOp: + op_fw = fmha.flash.FwOp + elif op_bw == fmha.cutlass.BwOp: + op_fw = fmha.cutlass.FwOp + else: + op_fw = fmha._dispatch_fw(inp) + op = (op_fw, op_bw) + except NotImplementedError: + return + if not (op[0].supports(inp) and op[1].supports(inp)): return dtype_str = { @@ -259,9 +286,14 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype): torch.half: "f16", torch.float: "f32", }[dtype] - sub_label = f"{dtype_str} B={B}, M={M}, H={H}, K={K}" + sub_label = ( + f"{dtype_str} B={B}, M={M}, H={H}, K={K}, p={dropout_p}," + f" BiasT={attn_bias_type.__name__}, BiasGrad={attn_bias_requires_grad}" + ) - out = xformers.ops.memory_efficient_attention(q, k, v, inp.attn_bias, p, op=op) + out = xformers.ops.memory_efficient_attention( + inp.query, inp.key, inp.value, inp.attn_bias, inp.p, op=op + ) grad_benchmark = torch.ones_like(q) yield benchmark.Timer( @@ -279,15 +311,20 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype): try: qkv.grad = None - r = xformers.ops.memory_efficient_attention(q, k, v, inp.attn_bias, op=op) + r = xformers.ops.memory_efficient_attention( + q, k, v, inp.attn_bias, dropout_p, op=op + ) r.backward(torch.ones_like(q)) grad = cast(torch.Tensor, qkv.grad) qkv.grad = None - rr = ref_attention(q, k, v, inp.attn_bias) + rr = ref_attention(q, k, v, inp.attn_bias, dropout_p) rr.backward(torch.ones_like(q)) atol = 2e-4 + 2e-6 * K * M * math.sqrt(B) * math.sqrt(M) + assert not ( + dropout_p > 0 and CHECK_CORRECTNESS + ), "correctness checking not yet implemented for dropout" # type: ignore assert ( not CHECK_CORRECTNESS or (grad - qkv.grad).abs().max() < atol @@ -298,7 +335,7 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype): yield benchmark.Timer( stmt="out.backward(grad, retain_graph=True)", globals={ - "out": ref_attention(q, k, v, inp.attn_bias), + "out": ref_attention(q, k, v, inp.attn_bias, dropout_p), "grad": grad_benchmark, }, label=f"attention backward 
(attn_bias={attn_bias_type})", diff --git a/xformers/csrc/attention/attention.cpp b/xformers/csrc/attention/attention.cpp index 90ca6df5d7..8b01485e73 100644 --- a/xformers/csrc/attention/attention.cpp +++ b/xformers/csrc/attention/attention.cpp @@ -17,11 +17,13 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) { m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)")); m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_forward_cutlass(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_logsumexp, bool causal, float? scale) -> (Tensor, Tensor)")); + "xformers::efficient_attention_forward_cutlass(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, bool compute_logsumexp, bool causal, float? scale) -> (Tensor, Tensor, int, int)")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_backward_cutlass(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, bool causal, float? scale) -> (Tensor, Tensor, Tensor)")); + "xformers::efficient_attention_backward_cutlass(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor logsumexp, Tensor output, float dropout_p, int rng_seed, int rng_offset, bool causal, float? scale) -> (Tensor, Tensor, Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::_temp_dropout(Tensor out, float p) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::_cutlass_rand_uniform(float p, Tensor out) -> Tensor")); } diff --git a/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu b/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu index 9478fea46b..7c356945da 100644 --- a/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu +++ b/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu @@ -1,12 +1,18 @@ +#include + #include #include #include #include #include +#include #include #include +#include "ATen/ops/empty_like.h" +#include "gemm_kernel_utils.h" #include "kernel_backward.h" +#include "pytorch_utils.h" #define DISPATCH_MAXK(func) \ { \ @@ -23,44 +29,62 @@ } \ } -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_MAXK(([&] { \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = \ - AttentionBackwardKernel; \ - bool isAligned = \ - (QUERY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ - KEY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ - VALUE.stride(2) % AlignedAK::kOptimalAlignement == 0); \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionBackwardKernel< \ - ArchTag, \ - scalar_t, \ - kIsAligned, \ - kMaxK>; \ - FUNC(); \ - })) \ - })) \ - })) \ - })); \ +#define DISPATCH_KERNEL(QUERY, KEY, VALUE, USE_DROPOUT, FUNC) \ + { \ + cudaDeviceProp* properties = \ + at::cuda::getDeviceProperties(QUERY.device().index()); \ + const int computeCapability = 
properties->major * 10 + properties->minor; \ + DISPATCH_MAXK(([&] { \ + DISPATCH_TYPES( \ + QUERY, ([&]() { \ + DISPATCH_BOOL( \ + USE_DROPOUT, kApplyDropout, ([&]() { \ + DISPATCH_ARCHTAG( \ + computeCapability, ([&]() { \ + using AlignedAK = AttentionBackwardKernel< \ + ArchTag, \ + scalar_t, \ + true, \ + kApplyDropout, \ + kMaxK>; \ + bool isAligned = \ + (QUERY.stride(2) % \ + AlignedAK::kOptimalAlignement == \ + 0 && \ + KEY.stride(2) % AlignedAK::kOptimalAlignement == \ + 0 && \ + VALUE.stride(2) % \ + AlignedAK::kOptimalAlignement == \ + 0); \ + DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ + using Kernel = \ + AttentionBackwardKernel< \ + ArchTag, \ + scalar_t, \ + kIsAligned, \ + kApplyDropout, \ + kMaxK>; \ + FUNC(); \ + })) \ + })) \ + })) \ + })) \ + })); \ } namespace { -std::tuple +std::tuple mem_efficient_attention_backward_cutlass( const at::Tensor& grad_out_, const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, + const c10::optional& bias, // additive attention bias const at::Tensor& logsumexp, const at::Tensor& out, + double dropout_p, // dropout probability + int64_t rng_seed, // seed using for generating random numbers for dropout + int64_t rng_offset, // offset into random number sequence bool causal, const c10::optional scale) { #ifdef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD @@ -120,7 +144,9 @@ mem_efficient_attention_backward_cutlass( // keys with no query associated, so they are not // initialized bool grad_kv_needs_init = causal && N > M; - at::Tensor grad_q, grad_k, grad_v; + const bool bias_requires_grad = bias.has_value() && bias->requires_grad(); + + at::Tensor grad_q, grad_k, grad_v, grad_bias; if (!grad_kv_needs_init && query.size(1) == key.size(1) && query.size(3) == value.size(3) && query.storage().is_alias_of(key.storage()) && @@ -141,8 +167,14 @@ mem_efficient_attention_backward_cutlass( grad_v = grad_kv_needs_init ? at::zeros(value.sizes(), value.options()) : at::empty(value.sizes(), value.options()); } + if (bias_requires_grad) { + grad_bias = at::empty(bias->sizes(), bias->options()); + } at::Tensor workspace; + const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; + at::PhiloxCudaState rng_engine_inputs(rng_seed, rng_offset); + auto launchKernel = [&](auto _k, int computeCapability) { using Kernel = decltype(_k); using scalar_t = typename Kernel::scalar_t; @@ -215,6 +247,41 @@ mem_efficient_attention_backward_cutlass( ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); + if (bias.has_value()) { + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias)); + + p.bias_ptr = (scalar_t*)bias->data_ptr(); + + // assign strides for bias, viewed as: + // (batch_sz, n_heads, n_queries, n_keys) + const at::Tensor bias_4d_view = get_bias_4d_view(*bias, B, nH, M, N); + ASSIGN_CHECK_OVERFLOW(p.bias_strideB, bias_4d_view.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias_4d_view.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias_4d_view.stride(2)); + + if (bias_requires_grad) { + p.grad_bias_ptr = (scalar_t*)grad_bias.data_ptr(); + + // assign strides for gB, viewed as + // (batch_sz, n_heads, n_queries, n_keys). 
might have different strides + // than B, for example if bias tensor was created with + // torch.tensor((B * nH, 1, nK)).expand((B * nH, nQ, nK)), + // different values of Q will point to the same memory + // locations, meaning bias.stride(1) == 0, while we'd want + // grad_bias.stride(1) == nK + const at::Tensor grad_bias_4d_view = + get_bias_4d_view(grad_bias, B, nH, M, N); + ASSIGN_CHECK_OVERFLOW(p.gB_strideB, grad_bias_4d_view.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gB_strideH, grad_bias_4d_view.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.gB_strideM, grad_bias_4d_view.stride(2)); + } + } + + if (use_dropout) { + p.rng_engine_inputs = rng_engine_inputs; + p.dropout_prob = dropout_p; + } + int64_t size_bytes = p.workspace_size(); if (size_bytes) { workspace = @@ -256,10 +323,11 @@ mem_efficient_attention_backward_cutlass( kernel_fn<<>>(p); }; - DISPATCH_KERNEL( - query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); + DISPATCH_KERNEL(query, key, value, use_dropout, ([&] { + launchKernel(Kernel{}, computeCapability); + })); AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(grad_q, grad_k, grad_v); + return std::make_tuple(grad_q, grad_k, grad_v, grad_bias); #endif } // namespace diff --git a/xformers/csrc/attention/cuda/fmha/attention_cutlass_rand_uniform.cu b/xformers/csrc/attention/cuda/fmha/attention_cutlass_rand_uniform.cu new file mode 100644 index 0000000000..b13769d67c --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/attention_cutlass_rand_uniform.cu @@ -0,0 +1,99 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace { + +/** + * simple kernel that populates a tensor with rand uniform values. + * currently only used for testing purposes, not much attention + * is paid to performance. + * + * problem is partitioned as follows: + * - (batch, head) is given by block coordinates + * - each thread handles a row for a given (batch, head) + */ +template +__global__ void rand_uniform_kernel( + int64_t n_heads, + int64_t n_queries, + int64_t n_keys, + float dropout_prob, + at::PhiloxCudaState rng_engine_inputs, + mask_t* mask_out, + int64_t mask_numel) { + const int64_t batch_id = blockIdx.x; + const int64_t head_id = blockIdx.y; + const int64_t query_idx = threadIdx.x; + + const auto seeds = at::cuda::philox::unpack(rng_engine_inputs); + + const int dropout_seq_start = batch_id * (n_heads * n_queries * n_keys) + + head_id * (n_queries * n_keys); + + curandStatePhilox4_32_10_t curand_state; + curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + dropout_seq_start + query_idx * n_keys, + &curand_state); + + for (int key_start_idx = 0; key_start_idx < n_keys; key_start_idx += 4) { + float4 rand_quad = curand_uniform4(&curand_state); + +#pragma unroll + for (int i = 0; i < 4; ++i) { + const int64_t linear_idx = batch_id * (n_heads * n_queries * n_keys) + + head_id * (n_queries * n_keys) + query_idx * n_keys + key_start_idx + + i; + + if (linear_idx < mask_numel) { + mask_out[linear_idx] = (&rand_quad.x)[i]; + } + } + } +} + +/** + * fill tensor with random uniform values. 
only used for testing, not much + * attention is paid to performance + */ +at::Tensor rand_uniform(double p, at::Tensor out) { + const int64_t batch_sz = out.size(0); + const int64_t n_heads = out.size(1); + const int64_t n_queries = out.size(2); + const int64_t n_keys = out.size(3); + + at::CUDAGeneratorImpl* gen = + at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState rng_engine_inputs; + { + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = + gen->philox_cuda_state(batch_sz * n_heads * n_queries * n_keys); + } + + rand_uniform_kernel<<>>( + n_heads, + n_queries, + n_keys, + p, + rng_engine_inputs, + reinterpret_cast(out.data_ptr()), + out.numel()); + + return out; +} + +} // namespace + +TORCH_LIBRARY_IMPL(xformers, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME("xformers::_cutlass_rand_uniform"), + TORCH_FN(rand_uniform)); +} diff --git a/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu b/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu index 6e25a11d62..7fecc4e47e 100644 --- a/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu +++ b/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu @@ -1,11 +1,19 @@ +#include +#include + #include #include #include +#include #include +#include #include +#include #include +#include #include "kernel_forward.h" +#include "pytorch_utils.h" #define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \ { \ @@ -125,10 +133,12 @@ struct TypeTraits { (Mode BMHK) With all the heads having the same seqlen (Mode 1MHK) `batch=1` with all tokens across batches concatenated */ -std::tuple efficient_attention_forward_cutlass( +std::tuple +efficient_attention_forward_cutlass( const at::Tensor& query, // [b, seqlen, num_heads, K] const at::Tensor& key, // [b, seqlen, num_heads, K] const at::Tensor& value, // [b, seqlen, num_heads, Kv] + const c10::optional& bias, // [b, num_heads, seqlen, seqlen] // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the // position of the first query token for batch $b const c10::optional& cu_seqlens_q, @@ -137,6 +147,7 @@ std::tuple efficient_attention_forward_cutlass( const c10::optional& cu_seqlens_k, // (Mode 1MHK only) Maximum sequence length across batches const c10::optional max_seqlen_q_, + double dropout_p, // attention matrix dropout probability bool compute_logsumexp, bool causal, c10::optional scale) { @@ -201,6 +212,19 @@ std::tuple efficient_attention_forward_cutlass( at::Tensor res; at::Tensor logsumexp; + const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; + at::PhiloxCudaState rng_engine_inputs; + if (use_dropout) { + at::CUDAGeneratorImpl* gen = + at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + std::lock_guard lock(gen->mutex_); + // if using dropout, we produce 1 random number for each element of the + // attention tensor + rng_engine_inputs = gen->philox_cuda_state(B * num_heads * M * N); + } + auto launchKernel = [&](auto _k, int computeCapability) { using Kernel = decltype(_k); using scalar_t = typename Kernel::scalar_t; @@ -268,6 +292,25 @@ std::tuple efficient_attention_forward_cutlass( ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); + if (bias.has_value()) { + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias)); + p.attn_bias_ptr = (scalar_t*)bias->data_ptr(); + + // assign strides for bias, viewed as + // (batch_sz, n_heads, n_queries, n_keys) + const at::Tensor bias_4d_view = + get_bias_4d_view(*bias, 
B, num_heads, M, N); + ASSIGN_CHECK_OVERFLOW(p.bias_strideB, bias_4d_view.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias_4d_view.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias_4d_view.stride(2)); + } + + p.use_dropout = use_dropout; + if (p.use_dropout) { + p.rng_engine_inputs = rng_engine_inputs; + p.dropout_prob = dropout_p; + } + constexpr auto kernel_fn = attention_kernel_batched; size_t smem_bytes = sizeof(typename Kernel::SharedStorage); if (smem_bytes > 0xc000) { @@ -286,7 +329,16 @@ std::tuple efficient_attention_forward_cutlass( })); AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(res, logsumexp); + + // uint64_t -> int64_t bitwise casting as PyTorch don't support uint64_t + // so just fake it as a int64_t + int64_t seed, offset; + if (use_dropout) { + std::memcpy(&seed, &rng_engine_inputs.seed_, sizeof(seed)); + std::memcpy(&offset, &rng_engine_inputs.offset_.val, sizeof(offset)); + } + + return std::make_tuple(res, logsumexp, seed, offset); #endif } } // namespace diff --git a/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h b/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h index 9265b52b3c..b439ca100a 100644 --- a/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h +++ b/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h @@ -49,8 +49,7 @@ struct RegisterOps { int8_t thread_id, int8_t warp_id, int16_t max_col, - typename T::TensorCoord const& tile_offset, - float scaling) { + typename T::TensorCoord const& tile_offset) { // Convert to `accum_t` (rather than double) constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E if (!kIsFirst) { @@ -78,10 +77,10 @@ struct RegisterOps { [&](int accum_m) { // Having 4x atomicMax seems faster than reduce within warp // first... 
- atomicMaxFloat(&mi[accum_m], max * scaling); + atomicMaxFloat(&mi[accum_m], max); }); } - frag = cutlass::multiplies()(scaling * kLog2e, frag); + frag = cutlass::multiplies()(kLog2e, frag); // Make sure we all share the update values for `mi` __syncthreads(); diff --git a/xformers/csrc/attention/cuda/fmha/kernel_backward.h b/xformers/csrc/attention/cuda/fmha/kernel_backward.h index bff217c498..d6690497c7 100644 --- a/xformers/csrc/attention/cuda/fmha/kernel_backward.h +++ b/xformers/csrc/attention/cuda/fmha/kernel_backward.h @@ -1,14 +1,28 @@ #pragma once #include +#include #include #include - +#include + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/fast_math.h" +#include "cutlass/functional.h" #include "cutlass/gemm/gemm.h" #include "cutlass/layout/matrix.h" #include "cutlass/layout/vector.h" +#include "cutlass/numeric_conversion.h" #include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" #include "debug_utils.h" #include "gemm_kernel_utils.h" @@ -34,6 +48,7 @@ #include "find_default_mma.h" #include "gemm/custom_mma.h" #include "mma_from_smem.h" +#include "transform/tile_smem_loader.h" #include @@ -141,6 +156,8 @@ template < typename scalar_t_, // run optimized kernel because memory accesses will be aligned bool kIsAligned_, + // use dropout if enabled + bool kApplyDropout, // upperbound on `max(value.shape[-1], query.shape[-1])` int kMaxK = std::numeric_limits::max()> struct AttentionBackwardKernel { @@ -157,6 +174,7 @@ struct AttentionBackwardKernel { scalar_t* query_ptr; // [Mq, nH, K] scalar_t* key_ptr; // [Mk, nH, K] scalar_t* value_ptr; // [Mk, nH, Kv] + scalar_t* bias_ptr = nullptr; lse_scalar_t* logsumexp_ptr; // [nH, Mq] scalar_t* output_ptr; // [Mq, nH, Kv] scalar_t* grad_output_ptr; // [Mq, nH, Kv] @@ -166,6 +184,8 @@ struct AttentionBackwardKernel { output_t* grad_query_ptr; // [Mq, nH, K] output_t* grad_key_ptr; // [Mk, nH, K] output_t* grad_value_ptr; // [Mk, nH, Kv] + output_t* grad_bias_ptr = nullptr; + // Accumulators union { output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] @@ -188,9 +208,17 @@ struct AttentionBackwardKernel { int32_t q_strideM; int32_t k_strideM; int32_t v_strideM; + int32_t bias_strideM; int32_t gO_strideM; + int32_t gB_strideM; int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise + // dropout + at::PhiloxCudaState rng_engine_inputs; + // RNG sequence offset based on batch_id and head_id + unsigned long long dropout_batch_head_rng_offset; + float dropout_prob; + CUTLASS_HOST_DEVICE int32_t o_strideM() const { return head_dim_value * num_heads; } @@ -210,10 +238,12 @@ struct AttentionBackwardKernel { int32_t q_strideH; int32_t k_strideH; int32_t v_strideH; + int32_t bias_strideH; int64_t o_strideB; int64_t q_strideB; int64_t k_strideB; int64_t v_strideB; + int64_t bias_strideB; int64_t lse_strideM; int32_t num_batches; @@ -221,10 +251,12 @@ struct AttentionBackwardKernel { int64_t gQ_strideB; int64_t gK_strideB; int64_t gV_strideB; + int64_t gB_strideB; int64_t gO_strideH; int64_t gQ_strideH; int64_t gK_strideH; int64_t gV_strideH; + int64_t gB_strideH; CUTLASS_DEVICE void advance_to_block() { int64_t batch_id = blockIdx.z; @@ -234,6 +266,9 @@ struct AttentionBackwardKernel { key_ptr += batch_id * k_strideB + head_id * k_strideH; value_ptr += batch_id * v_strideB + head_id * v_strideH; logsumexp_ptr += (batch_id * num_heads + head_id) * lse_strideM; + if 
(bias_ptr != nullptr) { + bias_ptr += batch_id * bias_strideB + head_id * bias_strideH; + } output_ptr += batch_id * o_strideB + head_id * o_strideH; grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; delta_ptr += (batch_id * num_heads + head_id) * num_queries; @@ -241,6 +276,13 @@ struct AttentionBackwardKernel { grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + if (grad_bias_ptr != nullptr) { + grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH; + } + + dropout_batch_head_rng_offset = + batch_id * (num_heads * num_queries * num_keys) + + head_id * (num_queries * num_keys); head_dim = warp_uniform(head_dim); head_dim_value = warp_uniform(head_dim_value); @@ -257,6 +299,7 @@ struct AttentionBackwardKernel { query_ptr = warp_uniform(query_ptr); key_ptr = warp_uniform(key_ptr); value_ptr = warp_uniform(value_ptr); + bias_ptr = warp_uniform(bias_ptr); logsumexp_ptr = warp_uniform(logsumexp_ptr); output_ptr = warp_uniform(output_ptr); grad_output_ptr = warp_uniform(grad_output_ptr); @@ -265,6 +308,7 @@ struct AttentionBackwardKernel { grad_query_ptr = warp_uniform(grad_query_ptr); grad_key_ptr = warp_uniform(grad_key_ptr); grad_value_ptr = warp_uniform(grad_value_ptr); + grad_bias_ptr = warp_uniform(grad_bias_ptr); if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { assert(workspace_size() == 0 || workspace != nullptr); @@ -416,6 +460,18 @@ struct AttentionBackwardKernel { using Mma = typename MakeCustomMma::Mma; + // used for efficient load of bias tile (Bij) from global memory to shared + // memory + using BiasLoader = TileSmemLoader< + scalar_t, + // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded + // row-major but needs to have transposed shape so we get the same + // elements. + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + // Epilogue to store to shared-memory in a format that we can use later for // the second matmul using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< @@ -464,11 +520,24 @@ struct AttentionBackwardKernel { false, // SplitKSerial typename GemmType::Operator>; + // if dropout: + // for computing dVj += (Pij.T * Zij) @ dOi + // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of + // Pij.T are loaded in. The reason we do it this way is because Pij.T and + // Zij are reused in later steps, while Pij_dropped.T is only needed in + // this step. computing Pij_dropped.T on the fly allows us to avoid + // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the + // same time. 
+ // if no dropout: + // for computing dVj += Pij.T @ dOi using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MatmulQK::AccumulatorSharedStorage>; + typename MatmulQK::AccumulatorSharedStorage, + kApplyDropout>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using WarpIteratorA = typename DefaultMmaFromSmem::WarpIteratorA; using IteratorB = typename Mma::IteratorB; using WarpCount = typename Mma::WarpCount; @@ -489,27 +558,50 @@ struct AttentionBackwardKernel { using ThreadblockShape = cutlass::gemm::GemmShape; using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; - using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< + + using ElementC = output_t; + using ElementAccum = accum_t; + + // no-op output op - epilogue just stores result to global memory + using BiasGradEpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + ElementC, + DefaultConfig::EpilogueOutputOp::kCount, + typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, + typename DefaultConfig::EpilogueOutputOp::ElementCompute, + cutlass::epilogue::thread::ScaleType::Nothing>; + + using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< scalar_t, // ElementA cutlass::layout::RowMajor, // LayoutA kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, scalar_t, // ElementB cutlass::layout::ColumnMajor, // LayoutB kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, - accum_t, // ElementC + ElementC, // ElementC cutlass::layout::RowMajor, // LayoutC + ElementAccum, // ElementAccumulator typename GemmType::OpClass, ArchTag, ThreadblockShape, WarpShape, typename GemmType::InstructionShape, - DefaultConfig::kStages, + BiasGradEpilogueOutputOp, // EpilogueOutputOp + void, // ThreadblockSwizzle (not used) + // multiple preloads, dropout Zij tile, and 3 stages push us over shared + // memory capacity on A100. set a ceiling on number of stages to save + // shared memory if dropout is in use. + kPreloadMmas && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) + ? 
cutlass::const_min(2, DefaultConfig::kStages) + : DefaultConfig::kStages, // Stages + false, // SplitKSerial typename GemmType::Operator, - false, // AccumulatorsInRowMajor = false, cutlass::gemm::SharedMemoryClearOption::kNone>; - using MmaCore = typename DefaultMma::MmaCore; - using Mma = - typename MakeCustomMma::Mma; + using Mma = typename MakeCustomMma::Mma; + + // epilogue used to write bias gradient, which is just the output of this + // matmul with some operations applied to the fragment + using BiasGradEpilogue = typename DefaultGemm::Epilogue; // Epilogue to store to shared-memory in a format that we can use later for // the second matmul @@ -553,7 +645,8 @@ struct AttentionBackwardKernel { using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MatmulDOIVJ::AccumulatorSharedStorage>; + typename MatmulDOIVJ::AccumulatorSharedStorage, + false>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; using IteratorB = typename Mma::IteratorB; using WarpCount = typename Mma::WarpCount; @@ -597,12 +690,14 @@ struct AttentionBackwardKernel { using DefaultMmaFromSmemN = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MatmulQK::AccumulatorSharedStorage>; + typename MatmulQK::AccumulatorSharedStorage, + false>; // kScaleOperandA using DefaultMmaFromSmemT = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, typename MatmulDOIVJ::AccumulatorSharedStorage, - kPreloadMmas>; + false, // kScaleOperandA + kPreloadMmas>; // kTransposeA using DefaultMmaFromSmem = typename cutlass::platform::conditional< DefaultMmaFromSmemT::kIsTransposedA, DefaultMmaFromSmemT, @@ -620,6 +715,26 @@ struct AttentionBackwardKernel { using AccumTileGmem = GmemTile; }; + // shared storage for keeping Zij matrix. not needed if we aren't using + // dropout, in which case we use an empty array to save shared memory + using ZijSharedStorage = typename cutlass::platform::conditional< + kApplyDropout, + typename MatmulQK::AccumulatorSharedStorage, + // dummy shared storage object that takes up no space. + typename cutlass::gemm::threadblock::AccumulatorSharedStorage< +#ifdef _WIN32 + // windows builds throw the error: + // "type containing an unknown-size array is not allowed" + // if we try to make Zij shared storage zero-sized. + // To get around this just make it sized 1 on windows. + typename cutlass::gemm::GemmShape<1, 1, 0>, +#else + typename cutlass::gemm::GemmShape<0, 0, 0>, +#endif + typename MatmulQK::AccumulatorSharedStorage::Element, + typename MatmulQK::AccumulatorSharedStorage::Layout, + typename cutlass::MatrixShape<0, 0>>>::type; + // See https://fburl.com/gsheet/l5bltspl // for an illustration of how smem is used struct SharedStoragePrologue { @@ -630,13 +745,34 @@ struct AttentionBackwardKernel { union { struct { // p1 - after Q.K / dV / dO.V - typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 4. store Pij. it is needed: + // - in dVj += (Pij.T * Zij) @ dOi + // - in dSij = Pij * (dPij - Di) + // 6. dVj += (Pij.T * Zij) @ dOi + // 10. write to fragment + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 5. store Zij. it is needed: + // - to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij + // are loaded for the computation of dVj. 
+ // - to compute dPij = (dOi @ Vj.T) * Zij + // 6. used in dVj += (Pij.T * Zij) @ dOi + // 9. used in dPij = dPij_dropped * Zij + ZijSharedStorage zij; union { + // 2. prologue for dVj + // 6. workspace for dVj += (Pij.T * Zij) @ dOi typename MatmulGradV::Mma::SharedStorage mm_gradV; + // 7. dVj epilogue typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; }; + // 3. prologue for dPij_dropped + // 8. used in dPij_dropped = dOi @ Vj.T typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; } p1; @@ -649,7 +785,11 @@ struct AttentionBackwardKernel { }; typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload) - typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + union { + // store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + }; } p2; @@ -686,15 +826,19 @@ struct AttentionBackwardKernel { printf(" persistent: %db\n", FSZ(persistent)); printf(" mm_qk_k: %db\n", FSZ(persistent.mm_qk_k)); printf(" p1: %db\n", FSZ(p1)); + printf(" bias: %db\n", FSZ(p1.bias)); printf(" attn_shared_storage: %db\n", FSZ(p1.attn_shared_storage)); + printf(" zij: %db\n", FSZ(p1.zij)); printf(" mm_gradV: %db\n", FSZ(p1.mm_gradV)); printf(" gradV_epilogue: %db\n", FSZ(p1.gradV_epilogue)); printf(" mm_doivj: %db\n", FSZ(p1.mm_doivj)); printf(" p2: %db\n", FSZ(p2)); + printf(" tmpT_shared_storage: %db\n", FSZ(p2.tmpT_shared_storage)); + printf(" tmp_shared_storage: %db\n", FSZ(p2.tmp_shared_storage)); printf(" mm_gradK: %db\n", FSZ(p2.mm_gradK)); printf(" mm_gradQ: %db\n", FSZ(p2.mm_gradQ)); + printf(" gradB_epilogue: %db\n", FSZ(p2.gradB_epilogue)); printf(" gradQ_epilogue: %db\n", FSZ(p2.gradQ_epilogue)); - printf(" tmp_shared_storage: %db\n", FSZ(p2.tmp_shared_storage)); printf(" p3: %db\n", FSZ(p3)); printf(" tmpT_shared_storage: %db\n", FSZ(p3.tmpT_shared_storage)); printf(" p4: %db\n", FSZ(p4)); @@ -710,12 +854,15 @@ struct AttentionBackwardKernel { FIELD(persistent, di) FIELD(persistent, mm_qk_k) + FIELD(p1, bias) FIELD(p1, attn_shared_storage) + FIELD(p1, zij) FIELD(p1, mm_gradV) FIELD(p1, gradV_epilogue) FIELD(p1, mm_doivj) FIELD(p2, mm_gradK) FIELD(p2, mm_gradQ) + FIELD(p2, gradB_epilogue) FIELD(p2, gradQ_epilogue) FIELD(p2, tmp_shared_storage) FIELD(p3, tmpT_shared_storage) @@ -739,7 +886,21 @@ struct AttentionBackwardKernel { struct { // p2 - compute gradV - typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 2. store Pij to shared memory. it is needed: + // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi + // - in next step where it is used in dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 3. store Zij. it is needed: + // - in this step, where it is used to compute Pij_dropped = Pij * Zij + // on the + // fly as fragments of Pij are loaded for the computation of dVj. 
+ // - later to compute dPij = (dOi @ Vj.T) * Zij + ZijSharedStorage zij; + union { typename MatmulGradV::Mma::SharedStorage mm_gradV; typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; @@ -748,9 +909,20 @@ struct AttentionBackwardKernel { struct { // p3 - DO.V matmul - typename MatmulQK::AccumulatorSharedStorage - attn_shared_storage; // (from p2) - typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + union { + // first compute dPij = (dOi @ Vj.T) * Zij + // and dSij = Pij * (dPij - Di) + struct { + // (from p2) - Pij for computing dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + // (from p2) - Zij for computing dPij = dPij_dropped * Zij + ZijSharedStorage zij; + // matmul to compute dOiVj + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + }; + // then store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + }; } p3; struct { @@ -805,10 +977,13 @@ struct AttentionBackwardKernel { FIELD(persistent, di) FIELD(p1, mm_qk_k) FIELD(p1, mm_qk_q) + FIELD(p2, bias) FIELD(p2, attn_shared_storage) + FIELD(p2, zij) FIELD(p2, mm_gradV) FIELD(p2, gradV_epilogue) FIELD(p3, mm_doivj) + FIELD(p3, gradB_epilogue) FIELD(p4, tmpT_shared_storage) FIELD(p4, tmp_shared_storage) FIELD(p4, mm_gradQ) @@ -878,6 +1053,25 @@ struct AttentionBackwardKernel { } OutputFragments output_frags; + + curandStatePhilox4_32_10_t rng_state_init; + if (kApplyDropout) { + auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. 
+ curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &rng_state_init); + } + int32_t key_start = 0; int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; for (; key_start < key_end; key_start += kBlockSizeJ) { @@ -887,12 +1081,22 @@ struct AttentionBackwardKernel { (p.num_queries - query_start) / kBlockSizeI * kBlockSizeI; for (; query_start < query_end; query_start += kBlockSizeI) { processBlockIJ( - shared_storage, output_frags, p, query_start, key_start); + shared_storage, + output_frags, + p, + query_start, + key_start, + rng_state_init); } // last (partial) query if (query_start < p.num_queries) { processBlockIJ( - shared_storage, output_frags, p, query_start, key_start); + shared_storage, + output_frags, + p, + query_start, + key_start, + rng_state_init); } if (kOutputInRF) { writeFragsToGmem(shared_storage, output_frags, p, key_start); @@ -906,7 +1110,12 @@ struct AttentionBackwardKernel { query_start < p.num_queries; query_start += kBlockSizeI) { processBlockIJ( - shared_storage, output_frags, p, query_start, key_start); + shared_storage, + output_frags, + p, + query_start, + key_start, + rng_state_init); } if (kOutputInRF) { writeFragsToGmem(shared_storage, output_frags, p, key_start); @@ -934,7 +1143,8 @@ struct AttentionBackwardKernel { OutputFragments& output_frags, Params const& p, int32_t query_start, - int32_t key_start) { + int32_t key_start, + const curandStatePhilox4_32_10_t& curand_state_init) { cutlass::MatrixCoord no_offset{0, 0}; accum_t scale = p.scale; int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; @@ -1070,6 +1280,39 @@ struct AttentionBackwardKernel { auto output_tile_coords = cutlass::MatrixCoord{ warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + + // apply bias if applicable + if (p.bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MatmulQK::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + p.bias_ptr + query_start * p.bias_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + cutlass::TensorRef bias_tensor_ref( + shared_storage.bias().data(), + cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); + typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id); + MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, where Pij is in register fragment and Bij is in shmem + auto lane_offset = MatmulQK::ScalingCoefsUpdater::get_lane_offset( + lane_id, warp_id, output_tile_coords); + MatmulQK::ScalingCoefsUpdater::iterateRows( + lane_offset, + [&](int accum_n) {}, + [&](int accum_m, int accum_n, int idx) { + // remember we are transposed + if (skipBoundsChecks || + (accum_n < num_queries_in_block && + accum_m < num_keys_in_block)) { + accum[idx] += bias_tensor_ref.at({accum_n, accum_m}); + } + }, + [&](int accum_n) {}); + } + // Apply mask if (p.causal) { auto lane_offset = MatmulQK::ScalingCoefsUpdater::get_lane_offset( @@ -1102,6 +1345,59 @@ struct AttentionBackwardKernel { warp_id, lane_id, output_tile_coords); + + // if we are using dropout, compute Zij, writing it to shared memory. + // each element of Zij is: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + if (kApplyDropout) { + auto zij = shared_storage.zij().accum_ref(); + // each thread generates a contiguous sequence of elements in Zij, all + // in the same row. 
the reason they have to come from the same row is + // that sampling random numbers from a contiguous random number sequence + // is much more efficient than jumping around, and the linear offset of + // each element of Z (the global matrix) maps to an offset in a random + // number sequence. for Z, the end of a row and the beginning of the + // next have adjacent offsets, but for Zij (tile of global matrix), this + // is not necessarily the case. + const int num_threads = blockDim.x * blockDim.y * blockDim.z; + const int threads_per_row = cutlass::fast_min( + num_threads / num_queries_in_block, num_keys_in_block); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(num_keys_in_block, threads_per_row), 4); + + const int thread_i = thread_id / threads_per_row; + const int thread_start_j = + (thread_id % threads_per_row) * elts_per_thread; + + if (thread_i < num_queries_in_block && + thread_start_j < num_keys_in_block) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + (query_start + thread_i) * p.num_keys + + (key_start + thread_start_j), + &curand_state); + const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); + + // generate elements of Zij, 4 elements at a time + for (int zij_start_col_idx = thread_start_j; zij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + num_keys_in_block); + zij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + // we'll write Zij transposed since attention is also transposed + // during the matmul to compute dV. + zij.at({zij_start_col_idx + quad_idx, thread_i}) = + static_cast( + dropout_scale * + ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + } + } + } + } __syncthreads(); } @@ -1131,13 +1427,20 @@ struct AttentionBackwardKernel { thread_id, no_offset); + // if dropout: dVj += (Pij.T * Zij) @ dOi + // otherwise: dVj += Pij.T @ dOi Mma mma( shared_storage.mm_gradV(), - shared_storage.attn_shared_storage(), + // operand A: Pij + typename MatmulGradV::WarpIteratorA( + shared_storage.attn_shared_storage().accum_ref(), lane_id), + // if we're using dropout, operand A is Pij_dropped = Pij * Zij + // which is computed on the fly as fragments of Pij are loaded in + typename Mma::WarpIteratorAScale( + shared_storage.zij().accum_ref(), lane_id), thread_id, warp_id, - lane_id, - problem_size.k()); + lane_id); int storage_id = col / MatmulGradV::ThreadblockShape::kN; AccumTileGmem gmem_tile{ @@ -1234,10 +1537,33 @@ struct AttentionBackwardKernel { { using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< typename Mma::Operator::IteratorC, - typename MatmulDOIVJ::DefaultMma::MmaCore::ElementC, + typename MatmulDOIVJ::ElementAccum, kWarpSize>::Updater; auto lane_offset = RegistersIter::get_lane_offset( lane_id, warp_id, output_tile_coords); + + // if dropout was used, compute dPij = dPij_dropped * Zij + // Zij was written to shared memory earlier, and the elementwise + // multiplication occurs on a fragment of dPij_dropped + if (kApplyDropout) { + const auto zij = shared_storage.zij().accum_ref(); + + RegistersIter::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + const int global_query_idx = query_start + accum_m; + const int global_key_idx = key_start + accum_n; + + if (skipBoundsChecks || + (global_query_idx < p.num_queries && + global_key_idx < p.num_keys)) { + accum[idx] *= zij.at({accum_n, accum_m}); + } + }, 
+ [&](int accum_m) {}); + } + auto attn_T = shared_storage.attn_shared_storage().accum_ref(); accum_t current_di; typename Mma::FragmentC fragment_attn, fragment_di; @@ -1259,7 +1585,35 @@ struct AttentionBackwardKernel { [&](int accum_m) { }); - accum = (accum - fragment_di) * fragment_attn * scale; + // dSij = (dPij - Di) * Pij + accum = (accum - fragment_di) * fragment_attn; + + // store bias gradient tile dBij to global memory, + // where dBij = dSij = Pij * (dPij - Di) + if (p.grad_bias_ptr != nullptr) { + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator + output_iter( + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator:: + Params{p.gB_strideM}, + // grad_bias_ptr is offset to point at beginning of + // matrix of shape (queries, keys) for a given + // (batch_id, head_id) the pointer arithmetic here produces + // a pointer to the start of the current tile within that + // matrix + p.grad_bias_ptr + query_start * p.gB_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + + // no-op epilogue operator - just casting and storing contents of + // accum to global memory + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1}); + typename MatmulDOIVJ::BiasGradEpilogue epilogue( + shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id); + epilogue(output_op, output_iter, accum, output_iter); + } + + accum = accum * scale; + __syncthreads(); if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); diff --git a/xformers/csrc/attention/cuda/fmha/kernel_forward.h b/xformers/csrc/attention/cuda/fmha/kernel_forward.h index 2023924cb2..8847655b79 100644 --- a/xformers/csrc/attention/cuda/fmha/kernel_forward.h +++ b/xformers/csrc/attention/cuda/fmha/kernel_forward.h @@ -1,11 +1,24 @@ +#ifdef HAS_PYTORCH +#include +#include +#include +#include +#include +#include +#endif + +#include #include #include #include "cutlass/bfloat16.h" +#include "cutlass/fast_math.h" #include "cutlass/gemm/gemm.h" #include "cutlass/layout/matrix.h" #include "cutlass/layout/vector.h" +#include "cutlass/matrix.h" #include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" #include "attention_scaling_coefs_updater.h" #include "cutlass/epilogue/threadblock/default_epilogue_simt.h" @@ -28,6 +41,7 @@ #include "find_default_mma.h" #include "gemm_kernel_utils.h" #include "mma_from_smem.h" +#include "transform/tile_smem_loader.h" #include @@ -88,6 +102,7 @@ struct AttentionKernel { scalar_t* query_ptr; // [num_queries, num_heads, head_dim] scalar_t* key_ptr; // [num_keys, num_heads, head_dim] scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] + scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] int32_t* cu_seqlens_q_ptr = nullptr; int32_t* cu_seqlens_k_ptr = nullptr; @@ -111,18 +126,29 @@ struct AttentionKernel { int32_t q_strideM; int32_t k_strideM; int32_t v_strideM; + int32_t bias_strideM; // Everything below is only used in `advance_to_block` // and shouldn't use registers int32_t q_strideH; int32_t k_strideH; int32_t v_strideH; + int32_t bias_strideH; + int64_t q_strideB; int64_t k_strideB; int64_t v_strideB; + int32_t bias_strideB; + int32_t num_batches; int32_t num_heads; + // dropout + bool use_dropout; + at::PhiloxCudaState rng_engine_inputs; + unsigned long long dropout_batch_head_rng_offset; + float dropout_prob; + CUTLASS_HOST_DEVICE int32_t o_strideM() const { return head_dim_value * num_heads; } @@ -170,6 +196,9 @@ struct AttentionKernel { output_ptr += 
int64_t(q_start + query_start) * o_strideM() + head_id * head_dim_value; + if (attn_bias_ptr != nullptr) { + attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); + } if (output_accum_ptr != nullptr) { output_accum_ptr += int64_t(q_start + query_start) * o_strideM() + head_id * head_dim_value; @@ -183,6 +212,10 @@ struct AttentionKernel { batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; } + dropout_batch_head_rng_offset = + batch_id * num_heads * num_queries * num_keys + + head_id * num_queries * num_keys; + num_queries -= query_start; if (causal) { num_keys = cutlass::fast_min( @@ -195,6 +228,7 @@ struct AttentionKernel { query_ptr = warp_uniform(query_ptr); key_ptr = warp_uniform(key_ptr); value_ptr = warp_uniform(value_ptr); + attn_bias_ptr = warp_uniform(attn_bias_ptr); output_ptr = warp_uniform(output_ptr); output_accum_ptr = warp_uniform(output_accum_ptr); logsumexp_ptr = warp_uniform(logsumexp_ptr); @@ -275,6 +309,14 @@ struct AttentionKernel { kNumWarpsPerBlock, ""); + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + // Epilogue to store to shared-memory in a format that we can use later for // the second matmul using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< @@ -336,7 +378,8 @@ struct AttentionKernel { using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MM0::AccumulatorSharedStorage>; + typename MM0::AccumulatorSharedStorage, + false>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; using IteratorB = typename Mma::IteratorB; using WarpCount = typename Mma::WarpCount; @@ -373,7 +416,10 @@ struct AttentionKernel { struct SharedStorageEpilogueAtEnd : ScalingCoefs { struct SharedStorageAfterMM0 { // Everything here might be overwritten during MM0 - typename MM0::AccumulatorSharedStorage si; + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; typename MM1::SharedStorageMM1 mm1; }; @@ -392,7 +438,10 @@ struct AttentionKernel { struct SharedStorageEpilogueInLoop : ScalingCoefs { struct SharedStorageAfterMM0 { // Everything here might be overwritten during MM0 - typename MM0::AccumulatorSharedStorage si; + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; typename MM1::SharedStorageMM1 mm1; typename MM1::DefaultEpilogue::SharedStorage epilogue; }; @@ -443,6 +492,7 @@ struct AttentionKernel { auto& s_prime = shared_storage.s_prime; auto& si = shared_storage.after_mm0.si; auto& mi = shared_storage.mi; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); if (thread_id() < kQueriesPerBlock) { @@ -477,6 +527,25 @@ struct AttentionKernel { {0, col}); }; + curandStatePhilox4_32_10_t curand_state_init; + if (p.use_dropout) { + const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. 
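A small sketch (plain host C++, illustrative function names) of how that offset is put together: ignoring the base offset carried by the PhiloxCudaState, the value computed once per kernel in advance_to_block plus the value skipped per element is just the element's row-major linear index in the full (num_batches, num_heads, num_queries, num_keys) attention tensor, which is what lets forward and backward reproduce the same draw.

#include <cstdint>

// offset applied once per kernel launch for this (batch_id, head_id)
inline uint64_t batch_head_rng_offset(
    int batch_id, int head_id, int num_heads, int num_queries, int num_keys) {
  return uint64_t(batch_id) * num_heads * num_queries * num_keys +
      uint64_t(head_id) * num_queries * num_keys;
}

// offset skipped per element (i, j) of this (num_queries, num_keys) slice
inline uint64_t element_rng_offset(int i, int j, int num_keys) {
  return uint64_t(i) * num_keys + uint64_t(j);
}

// their sum is the row-major linear index of element (b, h, i, j) in the
// (num_batches, num_heads, num_queries, num_keys) attention tensor
inline uint64_t total_rng_offset(
    int b, int h, int i, int j,
    int num_heads, int num_queries, int num_keys) {
  return batch_head_rng_offset(b, h, num_heads, num_queries, num_keys) +
      element_rng_offset(i, j, num_keys);
}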
we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. + curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &curand_state_init); + } + // Iterate through keys for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; iter_key_start += kKeysPerBlock) { @@ -569,6 +638,41 @@ struct AttentionKernel { (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + (my_warp_id / MM0::Mma::WarpCount::kM)}; + // multiply by scaling factor + accum = + cutlass::multiplies()(p.scale, accum); + + // apply attention bias if applicable + if (p.attn_bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MM0::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + // attn_bias_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start, + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id()); + MM0::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + MM0::ScalingCoefsUpdater::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + // Mask out last if causal if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) { auto query_start = blockIdx.x * kQueriesPerBlock; @@ -594,9 +698,7 @@ struct AttentionKernel { kFullColumns, ([&] { // Update `mi` from accum stored in registers - // Also updates `accum` with accum[i] <- - // exp(accum[i] * scale - // - mi) + // Also does accum[i] <- exp(accum[i] - mi) MM0::ScalingCoefsUpdater::update< kQueriesPerBlock, kFullColumns, @@ -611,8 +713,7 @@ struct AttentionKernel { thread_id(), warp_id(), p.num_keys - iter_key_start, - iteratorC_tile_offset, - p.scale); + iteratorC_tile_offset); })); })); @@ -628,6 +729,67 @@ struct AttentionKernel { __syncthreads(); + // apply dropout (if applicable) after we've written Pij to smem. + // dropout is applied by multiplying each element of Pij by: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + // + // for backward purposes we want to be able to map each element of the + // attention matrix to the same random uniform number as the one we used + // in forward, without needing to use the same iteration order or having + // to store the dropout matrix. 
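For reference, a single-row version of what the forward loop computes once scaling, bias and dropout are all enabled (a plain C++ sketch rather than the tiled kernel; mask_row holds the already-rescaled dropout values, 0 or 1/(1-p), drawn as described above):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> attention_row_reference(
    const std::vector<float>& q,                    // [K]
    const std::vector<std::vector<float>>& keys,    // [N][K]
    const std::vector<std::vector<float>>& values,  // [N][Kv]
    const std::vector<float>& bias_row,             // [N]
    const std::vector<float>& mask_row,             // [N], 0 or 1/(1-p)
    float scale) {
  const std::size_t N = keys.size();
  const std::size_t Kv = values.front().size();
  std::vector<float> p(N);
  float row_max = -INFINITY;
  for (std::size_t j = 0; j < N; ++j) {
    float s = 0.f;
    for (std::size_t d = 0; d < q.size(); ++d) s += q[d] * keys[j][d];
    p[j] = s * scale + bias_row[j];  // scale first, then Pij += Bij
    row_max = std::max(row_max, p[j]);
  }
  float denom = 0.f;
  for (std::size_t j = 0; j < N; ++j) {  // softmax over the row
    p[j] = std::exp(p[j] - row_max);
    denom += p[j];
  }
  std::vector<float> out(Kv, 0.f);
  for (std::size_t j = 0; j < N; ++j) {
    const float pij = (p[j] / denom) * mask_row[j];  // dropout after softmax
    for (std::size_t d = 0; d < Kv; ++d) out[d] += pij * values[j][d];
  }
  return out;
}

Note the ordering this mirrors: bias is added after the scale factor, and in effect dropout multiplies the softmax output rather than the raw scores.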
its possible to do this in registers but + // it ends up being very slow because each thread having noncontiguous + // strips of the Pij tile means we have to skip around a lot, and also + // have to generate a single random number at a time + if (p.use_dropout) { + auto si = shared_storage.after_mm0.si.accum_ref(); + // each thread handles a contiguous sequence of elements from Sij, all + // coming from the same row. the reason they have to come from the same + // row is that the sampling random numbers from a contiguous random + // number sequence is much more efficient than jumping around, and the + // linear offset of each element of S (the global matrix) maps to an + // offset in a random number sequence. for S, the end of a row and the + // beginning of the next have adjacent offsets, but for Sij, this is not + // necessarily the case. + const int num_threads = blockDim.x * blockDim.y * blockDim.z; + const int threads_per_row = + cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(problem_size_0_n, threads_per_row), 4); + + const int thread_i = thread_id() / threads_per_row; + const int thread_start_j = + (thread_id() % threads_per_row) * elts_per_thread; + + if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + static_cast( + (query_start + thread_i) * p.num_keys + + (iter_key_start + thread_start_j)), + &curand_state); + const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); + + // apply dropout scaling to elements this thread is responsible for, + // in chunks of 4 + for (int sij_start_col_idx = thread_start_j; sij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + problem_size_0_n); + sij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + si.at({thread_i, sij_start_col_idx + quad_idx}) *= + static_cast( + dropout_scale * + ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + } + } + } + __syncthreads(); // p.use_dropout should have same value kernel-wide + } + // // MATMUL: Attn . V // Run the matmul `attn @ V` for a block of attn and V. diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16.cu index 1587a831e0..c3d95b04b7 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16.cu @@ -1,8 +1,8 @@ // This file is auto-generated. 
See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, false); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, false); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, false); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, false, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, false, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, false, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, false, false); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned.cu index 83f4460a56..e6a55db9d6 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned.cu @@ -1,8 +1,8 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, true); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, true); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, true); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, true, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, true, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, true, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, true, false); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout.cu new file mode 100644 index 0000000000..e1877331b3 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, true, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, true, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, true, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, true, true); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k128.cu new file mode 100644 index 0000000000..01f48d47a0 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k128.cu @@ -0,0 +1,24 @@ +// This file is auto-generated. 
See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( + cutlass::bfloat16_t, + true, + true, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70( + cutlass::bfloat16_t, + true, + true, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75( + cutlass::bfloat16_t, + true, + true, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( + cutlass::bfloat16_t, + true, + true, + 128); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k64.cu new file mode 100644 index 0000000000..4483098da6 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k64.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, true, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, true, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, true, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, true, true, 64); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k128.cu index 0a4bf64032..b1310420c1 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k128.cu @@ -1,8 +1,24 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, true, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, true, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, true, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( + cutlass::bfloat16_t, + true, + false, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70( + cutlass::bfloat16_t, + true, + false, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75( + cutlass::bfloat16_t, + true, + false, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( + cutlass::bfloat16_t, + true, + false, + 128); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k64.cu index 3115393b89..444325a965 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k64.cu @@ -1,8 +1,24 @@ // This file is auto-generated. 
See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, true, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, true, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, true, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( + cutlass::bfloat16_t, + true, + false, + 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70( + cutlass::bfloat16_t, + true, + false, + 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75( + cutlass::bfloat16_t, + true, + false, + 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( + cutlass::bfloat16_t, + true, + false, + 64); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout.cu new file mode 100644 index 0000000000..60cd1fd603 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, false, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, false, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, false, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, false, true); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k128.cu new file mode 100644 index 0000000000..3c29ccd19e --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k128.cu @@ -0,0 +1,24 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( + cutlass::bfloat16_t, + false, + true, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70( + cutlass::bfloat16_t, + false, + true, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75( + cutlass::bfloat16_t, + false, + true, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( + cutlass::bfloat16_t, + false, + true, + 128); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k64.cu new file mode 100644 index 0000000000..3e0ee619bc --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k64.cu @@ -0,0 +1,24 @@ +// This file is auto-generated. 
See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( + cutlass::bfloat16_t, + false, + true, + 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70( + cutlass::bfloat16_t, + false, + true, + 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75( + cutlass::bfloat16_t, + false, + true, + 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( + cutlass::bfloat16_t, + false, + true, + 64); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k128.cu index 3eae5e5a32..dc8a811b75 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k128.cu @@ -1,8 +1,24 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, false, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, false, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, false, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( + cutlass::bfloat16_t, + false, + false, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70( + cutlass::bfloat16_t, + false, + false, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75( + cutlass::bfloat16_t, + false, + false, + 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( + cutlass::bfloat16_t, + false, + false, + 128); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k64.cu index 508edac35f..6181148ebf 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k64.cu @@ -1,8 +1,24 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, false, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, false, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, false, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( + cutlass::bfloat16_t, + false, + false, + 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70( + cutlass::bfloat16_t, + false, + false, + 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75( + cutlass::bfloat16_t, + false, + false, + 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( + cutlass::bfloat16_t, + false, + false, + 64); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16.cu index d77808f2ca..08e72e38a6 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16.cu @@ -1,8 +1,8 @@ // This file is auto-generated. 
See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, false); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned.cu index e70770160f..9764de537d 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned.cu @@ -1,8 +1,8 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, false); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout.cu new file mode 100644 index 0000000000..f68340c2b5 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, true); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k128.cu new file mode 100644 index 0000000000..7992f7f21f --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k128.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. 
See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, true, 128); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k64.cu new file mode 100644 index 0000000000..4d0867f23c --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k64.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, true, 64); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k128.cu index 5eda46d818..19d6115030 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k128.cu @@ -1,8 +1,8 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, false, 128); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k64.cu index 507793dce5..b0e177d03f 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k64.cu @@ -1,8 +1,8 @@ // This file is auto-generated. 
See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, false, 64); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout.cu new file mode 100644 index 0000000000..1bef0d76ab --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, true); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k128.cu new file mode 100644 index 0000000000..66b91286d0 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k128.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, true, 128); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k64.cu new file mode 100644 index 0000000000..661c4b97a1 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k64.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, true, 64); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k128.cu index 3e571f7d07..f4485ff6c8 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k128.cu @@ -1,8 +1,8 @@ // This file is auto-generated. 
See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, false, 128); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k64.cu index 0839106577..d6ff34ea13 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k64.cu @@ -1,8 +1,8 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, false, 64); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32.cu index c868a720d0..4477366a6d 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32.cu @@ -1,8 +1,8 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, false); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned.cu index 35f44b109d..b3bf2768f4 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned.cu @@ -1,8 +1,8 @@ // This file is auto-generated. 
See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, false); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, false); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout.cu new file mode 100644 index 0000000000..4dbdc6db48 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, true); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k128.cu new file mode 100644 index 0000000000..1ffc2806ea --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k128.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, true, 128); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k64.cu new file mode 100644 index 0000000000..8ba17928d5 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k64.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, true, 64); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k128.cu index d35c2e412e..6d2d903a85 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k128.cu @@ -1,8 +1,8 @@ // This file is auto-generated. 
See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, false, 128); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k64.cu index 520c449bbb..7906cae500 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k64.cu @@ -1,8 +1,8 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, false, 64); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout.cu new file mode 100644 index 0000000000..ce50ee776a --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, true); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, true); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k128.cu new file mode 100644 index 0000000000..e7b4da246c --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k128.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, true, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, true, 128); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k64.cu new file mode 100644 index 0000000000..de6c8b0999 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k64.cu @@ -0,0 +1,8 @@ +// This file is auto-generated. 
See "generate_kernels.sh" +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, true, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, true, 64); +#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k128.cu index 731252902c..8428f9ddbd 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k128.cu @@ -1,8 +1,8 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, 128); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, false, 128); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, false, 128); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k64.cu index 9dcc68de8f..476685bf41 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k64.cu @@ -1,8 +1,8 @@ // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, 64); -INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, false, 64); +INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, false, 64); #endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.sh b/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.sh index 0d64812db0..08625b7654 100755 --- a/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.sh +++ b/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.sh @@ -6,33 +6,36 @@ IFS="," # BACKWARD kernel="BACKWARD" kernel_lower=`echo "\$kernel" | awk '{print tolower($0)}'` -for aligned in "false" "true"; do - for maxk in 64 128 ""; do - for dtype_name in "f32" "f16" "bf16"; do - case "$dtype_name" in - "f32") dtype="float" ;; - "f16") dtype="cutlass::half_t" ;; - "bf16") dtype="cutlass::bfloat16_t" ;; - esac - [[ $aligned = "true" ]] && s="_aligned" || s="" - [[ $maxk = "" ]] && s="${s}" || s="${s}_k$maxk" - [[ $maxk = "" ]] && maxk_code="" || maxk_code=", $maxk" - FNAME="${kernel_lower}_${dtype_name}${s}.cu" - echo $FNAME - cat < $FNAME +for enable_dropout in "false" "true"; do + for aligned in "false" "true"; do + for maxk in 64 128 ""; do + for dtype_name in "f32" "f16" "bf16"; do + case "$dtype_name" in + "f32") dtype="float" ;; + "f16") dtype="cutlass::half_t" 
;; + "bf16") dtype="cutlass::bfloat16_t" ;; + esac + [[ $aligned = "true" ]] && s="_aligned" || s="" + [[ $enable_dropout = "true" ]] && s="${s}_dropout" || s="${s}" + [[ $maxk = "" ]] && s="${s}" || s="${s}_k$maxk" + [[ $maxk = "" ]] && maxk_code="" || maxk_code=", $maxk" + FNAME="${kernel_lower}_${dtype_name}${s}.cu" + echo $FNAME + cat < $FNAME // This file is auto-generated. See "generate_kernels.sh" #ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD #include "../kernel_backward.h" EOF - for sm in 50 70 75 80; do - echo "INSTANTIATE_ATTENTION_KERNEL_${kernel}_SM${sm}($dtype, $aligned$maxk_code);" >> $FNAME - done; - cat <> $FNAME + for sm in 50 70 75 80; do + echo "INSTANTIATE_ATTENTION_KERNEL_${kernel}_SM${sm}($dtype, $aligned, $enable_dropout$maxk_code);" >> $FNAME + done; + cat <> $FNAME #endif EOF + done; done; done; -done +done; # FORWARD kernel="FORWARD" diff --git a/xformers/csrc/attention/cuda/fmha/mma_from_smem.h b/xformers/csrc/attention/cuda/fmha/mma_from_smem.h index 4e5d6801cc..b280d81476 100644 --- a/xformers/csrc/attention/cuda/fmha/mma_from_smem.h +++ b/xformers/csrc/attention/cuda/fmha/mma_from_smem.h @@ -43,16 +43,20 @@ #include "cutlass/epilogue/threadblock/default_epilogue_simt.h" #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/functional.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" #include "cutlass/matrix_shape.h" #include "cutlass/numeric_conversion.h" #include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" #include "cutlass/transform/threadblock/vector_iterator.h" #include "attention_scaling_coefs_updater.h" #include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" #include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" #include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" #include "epilogue_thread_apply_logsumexp.h" #include "gemm_kernel_utils.h" @@ -246,6 +250,78 @@ class MmaBaseFromSharedMemory { : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} }; +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { + public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // using size 1 is kind of a hack to get around arrays of zero-sized objects + // not being allowed. the compiler is probably smart enough to wipe it out + // anyways. 
+ using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset( + typename TensorRef::TensorCoord const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { + return *this; + } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. +template +class FragmentElementwiseScaler { + public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) { + Fragment converted_scale_frag = cutlass::NumericArrayConverter< + typename Fragment::Element, + typename FragmentScale::Element, + FragmentScale::kElements>()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { + public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { + return frag; + } +}; +} // namespace + //////////////////////////////////////////////////////////////////////////////// // Taken from // https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h @@ -259,6 +335,10 @@ template < // BEGIN smem /// Iterates over the intermediate accumulator tile in shared memory typename WarpIteratorA, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, // Accumulator type typename AccumulatorSharedStorage, // END smem @@ -297,6 +377,15 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory using ElementC = ElementC_; ///< Data type of accumulator matrix @@ -333,8 +422,20 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< private: using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + using WarpFragmentB = typename Operator::FragmentB; + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. 
+ using FragmentAScaler = FragmentElementwiseScaler< + WarpFragmentA, + WarpFragmentAScale, + ScaleOperandA>; + protected: // /// Iterator to write threadblock-scoped tile of A operand to shared memory // SmemIteratorA smem_iterator_A_; @@ -346,7 +447,46 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< /// accumulator tile WarpIteratorA warp_tile_iterator_A_; + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp iterator over A tile held in shared memory + WarpIteratorA warp_iter_a, + // warp iterator over A_scale tile held in shared memory + WarpIteratorAScale warp_iter_a_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(warp_iter_a), + warp_tile_iterator_A_scale_(warp_iter_a_scale), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + /// Construct from tensor references CUTLASS_DEVICE MmaPipelinedFromSharedMemory( @@ -429,19 +569,26 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< __syncthreads(); + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. 
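The zero-cost-when-disabled trick used here can be seen in miniature below (plain C++ with std::conditional standing in for cutlass::platform::conditional; the policy names are invented): when scaling is off, a no-op policy with a dummy one-element fragment is selected, so every load and multiply on the scale path folds away at compile time.

#include <type_traits>

struct RealScalePolicy {
  float frag[4];  // pretend scale fragment
  void load(const float* smem) {
    for (int i = 0; i < 4; ++i) frag[i] = smem[i];
  }
  float apply(float a, int i) const { return a * frag[i]; }
};

struct NoOpScalePolicy {
  float frag[1];                 // dummy, mirrors the size-1 Array hack
  void load(const float*) {}     // does nothing
  float apply(float a, int) const { return a; }
};

template <bool ScaleOperandA>
using ScalePolicy = typename std::conditional<
    ScaleOperandA, RealScalePolicy, NoOpScalePolicy>::type;

// with ScaleOperandA == false the compiler sees only no-op calls and a
// one-float member, so the scaling machinery adds (almost) no cost
template <bool ScaleOperandA>
float scaled_a(float a, int i, const ScalePolicy<ScaleOperandA>& p) {
  return p.apply(a, i);
}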
+ // Pair of fragments used to overlap shared memory loads and math // instructions WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; WarpFragmentB warp_frag_B[2]; warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); warp_frag_B[0].clear(); this->warp_tile_iterator_B_.set_kgroup_index(0); this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); this->warp_tile_iterator_B_.load(warp_frag_B[0]); ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; ++this->warp_tile_iterator_B_; Operator warp_mma; @@ -503,9 +650,12 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< (warp_mma_k + 1) % Base::kWarpGemmIterations); this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load( + warp_frag_A_scale[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; ++this->warp_tile_iterator_B_; if (warp_mma_k == 0) { @@ -521,7 +671,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< warp_mma( accum, - warp_frag_A[warp_mma_k % 2], + FragmentAScaler::apply( + warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]), warp_frag_B[warp_mma_k % 2], accum); } @@ -541,6 +692,10 @@ template < typename Shape1_, /// Iterates over the intermediate accumulator tile in shared memory typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, // Accumulator type typename AccumulatorSharedStorage, /// Iterates over tiles of B operand in global memory @@ -580,7 +735,14 @@ class MmaMultistageFromSharedMemory using SmemIteratorB1 = SmemIteratorB1_; using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate ///< accumulator tile in shared memory - + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; ///< Data type of accumulator matrix using ElementC = ElementC_; ///< Layout of accumulator matrix @@ -628,10 +790,20 @@ class MmaMultistageFromSharedMemory private: using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; using WarpLoadedFragmentB1 = typename Operator1::FragmentB; using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpLoadedFragmentA1, + WarpLoadedFragmentA1Scale, + ScaleOperandA>; + private: // // Data members @@ -641,12 +813,54 @@ class MmaMultistageFromSharedMemory /// accumulator tile WarpIteratorA1 warp_tile_iterator_A1_; + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. 
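On the enabled path, FragmentAScaler::apply amounts to converting the scale fragment to the A fragment's element type and multiplying elementwise; a scalar-array sketch of the same operation (plain C++, illustrative types):

#include <array>
#include <cstddef>

// frag_a plays the role of a Pij fragment, frag_scale the matching Zij
// fragment; in the kernel the multiply happens right before warp_mma, so
// Pij * Zij is never materialized in shared memory
template <typename T, typename TScale, std::size_t N>
std::array<T, N> apply_fragment_scale(
    std::array<T, N> frag_a, const std::array<TScale, N>& frag_scale) {
  for (std::size_t i = 0; i < N; ++i) {
    frag_a[i] *= static_cast<T>(frag_scale[i]);  // convert, then multiply
  }
  return frag_a;
}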
+ WarpIteratorAScale warp_tile_iterator_A1_scale_; + /// Iterator to write threadblock-scoped tile of B operand to shared memory SmemIteratorB1 smem_iterator_B1_; bool prologue_done_; public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp level iterator over operand A tile kept in shared memory + WarpIteratorA1 warp_tile_iterator_A1, + // warp level iterator over operand A elementwise scale tile kept in + // shared memory. + WarpIteratorAScale warp_tile_iterator_A1_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(warp_tile_iterator_A1), + warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + /// Construct from tensor references CUTLASS_DEVICE MmaMultistageFromSharedMemory( @@ -842,9 +1056,13 @@ class MmaMultistageFromSharedMemory cutlass::arch::cp_async_wait(); __syncthreads(); + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + // Pair of fragments used to overlap shared memory loads and math // instructions WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; @@ -854,6 +1072,9 @@ class MmaMultistageFromSharedMemory warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); ++warp_tile_iterator_A1_; + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + this->warp_tile_iterator_B_.set_kgroup_index(0); this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); ++this->warp_tile_iterator_B_; @@ -864,7 +1085,8 @@ class MmaMultistageFromSharedMemory warp_mma1.transform( warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], - warp_loaded_frag_A1[0], + FragmentAScaler::apply( + warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), warp_loaded_frag_B1[0]); // tf32x3 kernels use staging accumulation. 
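Both new constructors place each warp's iterators with the same index arithmetic; stripped of the CUTLASS types it is a flat-to-3D decomposition (plain C++, names illustrative):

struct WarpCoord {
  int m;  // warp's position along M within the threadblock
  int n;  // warp's position along N within the threadblock
  int k;  // warp's position along K within the threadblock
};

// warp_count_m / warp_count_n stand for Base::WarpCount::kM / kN (or the
// WarpCount1 variants in the multistage class)
inline WarpCoord decompose_warp_idx(
    int warp_idx, int warp_count_m, int warp_count_n) {
  const int warp_idx_mn = warp_idx % (warp_count_m * warp_count_n);
  const int warp_idx_k = warp_idx / (warp_count_m * warp_count_n);
  return {warp_idx_mn % warp_count_m, warp_idx_mn / warp_count_m, warp_idx_k};
}

The A and A_scale iterators then both receive the tile offset {m, kWarpGemmIterations * k}, which is what keeps each Zij fragment lined up with the Pij fragment it scales.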
warp_mma uses a temporary @@ -909,17 +1131,22 @@ class MmaMultistageFromSharedMemory warp_mma_k < Base::kWarpGemmIterations1 - 1) { warp_tile_iterator_A1_.load( warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_B_.load( warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); } ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; ++this->warp_tile_iterator_B_; if (warp_mma_k > 0) warp_mma1.transform( warp_transformed_frag_A1[warp_mma_k % 2], warp_transformed_frag_B1[warp_mma_k % 2], - warp_loaded_frag_A1[warp_mma_k % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), warp_loaded_frag_B1[warp_mma_k % 2]); if (platform::is_same< @@ -1009,7 +1236,9 @@ class MmaMultistageFromSharedMemory warp_mma1.transform( warp_transformed_frag_A1[(warp_mma_k + 1) % 2], warp_transformed_frag_B1[(warp_mma_k + 1) % 2], - warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); } } @@ -1119,6 +1348,9 @@ struct DefaultWarpIteratorAFromSharedMemory< template < typename Mma_, typename AccumulatorSharedStorage, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, bool kTransposeA = false> struct DefaultMmaFromSharedMemory; @@ -1151,6 +1383,9 @@ template < /// Transformation applied to B operand typename TransformB_, typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, bool kTransposeA> struct DefaultMmaFromSharedMemory< MmaPipelined< @@ -1165,6 +1400,7 @@ struct DefaultMmaFromSharedMemory< TransformA_, TransformB_>, AccumulatorSharedStorage_, + kScaleOperandA, kTransposeA> { static constexpr int kWarpSize = 32; using SmemAccumulatorLayout = cutlass::layout::RowMajor; @@ -1198,6 +1434,7 @@ struct DefaultMmaFromSharedMemory< using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< Shape_, WarpIteratorA, + kScaleOperandA, AccumulatorSharedStorage_, IteratorB, SmemIteratorB_, @@ -1238,6 +1475,9 @@ template < /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear, typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, bool kTransposeA> struct DefaultMmaFromSharedMemory< MmaMultistage< @@ -1254,6 +1494,7 @@ struct DefaultMmaFromSharedMemory< Stages, SharedMemoryClear>, AccumulatorSharedStorage_, + kScaleOperandA, kTransposeA> { static constexpr int kWarpSize = 32; @@ -1301,6 +1542,7 @@ struct DefaultMmaFromSharedMemory< typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< Shape_, WarpIteratorA, + kScaleOperandA, AccumulatorSharedStorage_, IteratorB, SmemIteratorB_, diff --git a/xformers/csrc/attention/cuda/fmha/pytorch_utils.h b/xformers/csrc/attention/cuda/fmha/pytorch_utils.h new file mode 100644 index 0000000000..b0ec5e9705 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/pytorch_utils.h @@ -0,0 +1,37 @@ +#include +#include + +/** + * kernels expect 4D bias/bias.grad with shape + * (batch_sz, n_heads, n_queries, n_keys). 
common bias shapes users may pass + * are: + * - (n_queries, n_keys) + * - (batch_sz * n_heads, n_queries, n_keys) + * - (batch_sz, n_heads, n_queries, n_keys) + * + * expand the bias as needed - be careful to only create a view with different + * shape/strides, no copies allowed. + */ +inline at::Tensor get_bias_4d_view( + const at::Tensor& bias, + int batch_sz, + int n_heads, + int n_queries, + int n_keys) { + TORCH_CHECK(bias.size(-2) == n_queries); + TORCH_CHECK(bias.size(-1) == n_keys); + switch (bias.dim()) { + case 2: // (n_queries, n_keys) - broadcast across all batches and heads + return bias.unsqueeze(0).unsqueeze(0).expand( + {batch_sz, n_heads, n_queries, n_keys}); + case 3: // (batch_sz * n_heads, n_queries, n_keys) - just reshape + TORCH_CHECK(bias.size(0) == batch_sz * n_heads); + return bias.view({batch_sz, n_heads, n_queries, n_keys}); + case 4: // (batch_sz, n_heads, n_queries, n_keys) - do nothing + TORCH_CHECK(bias.size(0) == batch_sz); + TORCH_CHECK(bias.size(1) == n_heads); + return bias; + default: + TORCH_CHECK(false, "bias can only have ndims in {2, 3, 4}"); + } +} diff --git a/xformers/csrc/attention/cuda/fmha/transform/tile_smem_loader.h b/xformers/csrc/attention/cuda/fmha/transform/tile_smem_loader.h new file mode 100644 index 0000000000..fc7678ccda --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/transform/tile_smem_loader.h @@ -0,0 +1,57 @@ +#include +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +template < + typename scalar_t, // scalar type + typename ThreadblockTileShape, // size of tile to load + int Threads, // number of participating threads + int ElementsPerAccess> // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load( + GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/xformers/ops/fmha/__init__.py b/xformers/ops/fmha/__init__.py index ebac4b9ceb..431365b317 100644 --- a/xformers/ops/fmha/__init__.py +++ b/xformers/ops/fmha/__init__.py @@ -72,6 +72,10 @@ def forward(ctx, op: AttentionOp, *args: Any) -> Any: ctx.op_fw = op_fw ctx.op_bw = op_bw ctx.p = inp.p + # used for cutlass backward with dropout + ctx.rng_seed = op_ctx.rng_seed + ctx.rng_offset = 
op_ctx.rng_offset + ctx.scale = inp.scale ctx.attn_bias_ctx = attn_bias_ctx ctx.n_args = len(args) @@ -88,13 +92,6 @@ def deserialize_bias( @classmethod @torch.autograd.function.once_differentiable def backward(cls, ctx, grad): - assert all( - not ctx.needs_input_grad[i] for i in range(ctx.n_args) if i not in [1, 2, 3] - ), ( - "Only gradients to Q/K/V is implemented. " - "For instance, it's not possible to backpropagate through the attention mask" - ) - # Re-create context query, key, value, out, lse = ctx.saved_tensors attn_bias_tensor = ctx.attn_bias_tensor @@ -107,11 +104,20 @@ def backward(cls, ctx, grad): p=ctx.p, scale=ctx.scale, ) - op_ctx = Context(lse=lse, out=out, rng_state=rng_state) + op_ctx = Context( + lse=lse, + out=out, + rng_state=rng_state, + # rng_seed and rng_offset used for cutlass implementation + rng_seed=ctx.rng_seed, + rng_offset=ctx.rng_offset, + ) grads = _memory_efficient_attention_backward( ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw ) - return (None, grads.dq, grads.dk, grads.dv) + (None,) * (ctx.n_args - 3) + return (None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * ( + ctx.n_args - 2 + ) def memory_efficient_attention( @@ -262,10 +268,10 @@ def memory_efficient_attention_backward( scale: Optional[float] = None, *, op: Optional[Type[AttentionBwOpBase]] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Computes the gradient of the attention. - Returns a tuple (dq, dk, dv) + Returns a tuple (dq, dk, dv, db) See :attr:`xformers.ops.memory_efficient` for an explanation of the arguments. `lse` is the tensor returned by :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad` """ @@ -282,7 +288,7 @@ def memory_efficient_attention_backward( grad, op=op, ) - return (gradients.dq, gradients.dk, gradients.dv) + return (gradients.dq, gradients.dk, gradients.dv, gradients.db) def _memory_efficient_attention( diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py index 122eaa35d5..27cdbaeb5f 100644 --- a/xformers/ops/fmha/common.py +++ b/xformers/ops/fmha/common.py @@ -145,6 +145,10 @@ class Context: out: torch.Tensor op_bw: Optional[Type["AttentionBwOpBase"]] = None rng_state: Optional[torch.Tensor] = None + # used for cutlass backward with dropout + rng_seed: Optional[int] = None + # used for cutlass backward with dropout + rng_offset: Optional[int] = None def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor: pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to @@ -164,6 +168,8 @@ class Gradients: dq: torch.Tensor dk: torch.Tensor dv: torch.Tensor + # bias gradient. 
None if there is no tensor bias or if it doesn't require grad + db: Optional[torch.Tensor] = None class AttentionOpBase(BaseOperator): @@ -277,6 +283,20 @@ class AttentionBwOpBase(AttentionOpBase): torch.half: 2e-2, torch.bfloat16: 0.1, } + SUPPORTS_ATTN_BIAS_GRAD = False + + @classmethod + def supports(cls, d: Inputs) -> bool: + if not super(AttentionBwOpBase, cls).supports(d): + return False + if ( + isinstance(d.attn_bias, torch.Tensor) + and d.attn_bias.requires_grad + and not cls.SUPPORTS_ATTN_BIAS_GRAD + ): + return False + + return True @classmethod def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: diff --git a/xformers/ops/fmha/cutlass.py b/xformers/ops/fmha/cutlass.py index 3bb64413f5..4f7052ce68 100644 --- a/xformers/ops/fmha/cutlass.py +++ b/xformers/ops/fmha/cutlass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. -from typing import Any, List, Optional, Set, Tuple +from typing import Any, List, Mapping, Optional, Set, Tuple import torch @@ -80,8 +80,12 @@ class FwOp(AttentionFwOpBase): SUPPORTED_DEVICES: Set[str] = {"cuda"} SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16} SUPPORTED_MAX_K = 65536 - SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None), LowerTriangularMask} - SUPPORTS_DROPOUT = False + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + torch.Tensor, + LowerTriangularMask, + } + SUPPORTS_DROPOUT = True SUPPORTS_CUSTOM_SCALE = True SUPPORTS_DIFFERENT_VALUE_EMBED = True SUPPORTS_TENSOR_WITH_SEQLEN = True @@ -97,26 +101,36 @@ class FwOp(AttentionFwOpBase): def apply( cls, inp: Inputs, needs_gradient: bool ) -> Tuple[torch.Tensor, Optional[Context]]: - if inp.attn_bias is not None and not isinstance( - inp.attn_bias, LowerTriangularMask - ): + if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: raise NotImplementedError("Unsupported attn_bias type") + uses_attn_bias = isinstance(inp.attn_bias, torch.Tensor) causal = isinstance(inp.attn_bias, LowerTriangularMask) cu_seqlen_k, cu_seqlen_q, max_seqlen_q = _get_seqlen_info(inp) - out, lse = cls.OPERATOR( + out, lse, rng_seed, rng_offset = cls.OPERATOR( query=inp.query, key=inp.key, value=inp.value, + attn_bias=inp.attn_bias if uses_attn_bias else None, cu_seqlens_q=cu_seqlen_q, cu_seqlens_k=cu_seqlen_k, max_seqlen_q=max_seqlen_q, + dropout_p=inp.p, compute_logsumexp=needs_gradient, causal=causal, scale=inp.scale, ) ctx: Optional[Context] = None if needs_gradient: - ctx = Context(lse=lse, out=out) + ctx = Context( + out=out, + lse=lse, + rng_seed=rng_seed, + rng_offset=rng_offset, + # cutlass forward is only compatible with cutlass backward if + # dropout is used (because of the way RNG states are passed and the + # way random numbers are generated during backward) + op_bw=BwOp if inp.p != 0 else None, + ) return out, ctx @classmethod @@ -125,6 +139,13 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]: matmul_alignment_mn = _minimum_gemm_alignment(d) check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn) + + if isinstance(d.attn_bias, torch.Tensor): + bits_per_scalar = torch.finfo(d.attn_bias.dtype).bits + # restriction comes from each thread loading bias 128 bits at a time + if d.attn_bias.shape[-1] % (128 // bits_per_scalar) != 0: + reasons.append("bias tensor not aligned for 128 bit loads") + return reasons @@ -135,12 +156,26 @@ class BwOp(AttentionBwOpBase): SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES SUPPORTED_MAX_K 
= FwOp.SUPPORTED_MAX_K SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES + SUPPORTS_ATTN_BIAS_GRAD = True SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED SUPPORTS_TENSOR_WITH_SEQLEN = False NAME = "cutlassB" + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 5e-4, + # increased from 9e-2, more opportunities for numerical errors when bias is + # used, noticed in gK on SM80 + torch.half: 9.5e-2, + torch.bfloat16: 7e-1, + } + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.float: 1e-4, + torch.half: 2e-2, + torch.bfloat16: 1e-1, + } + _TEST_K: List[int] = [ 32, # 64x64 kernel 128, # 64x128/128x128 kernel @@ -151,6 +186,7 @@ class BwOp(AttentionBwOpBase): def not_supported_reasons(cls, d: Inputs) -> List[str]: reasons = super(BwOp, cls).not_supported_reasons(d) matmul_alignment_mn = _minimum_gemm_alignment(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn) check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn) @@ -169,26 +205,48 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]: f"Sm{sm} does not have enough shared-memory to run this kernel" " - see https://github.com/facebookresearch/xformers/issues/517" ) + + if isinstance(d.attn_bias, torch.Tensor): + bits_per_scalar = torch.finfo(d.attn_bias.dtype).bits + # restriction comes from each thread loading bias 128 bits at a time + if d.attn_bias.shape[-1] % (128 // bits_per_scalar) != 0: + reasons.append("bias tensor not aligned for 128 bit loads") + return reasons @classmethod def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: - if inp.attn_bias is not None and not isinstance( - inp.attn_bias, LowerTriangularMask - ): + if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES: raise NotImplementedError("Unsupported attn_bias type") + uses_attn_bias = isinstance(inp.attn_bias, torch.Tensor) causal = isinstance(inp.attn_bias, LowerTriangularMask) dtype = inp.query.dtype force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5) - (grad_q, grad_k, grad_v,) = cls.OPERATOR( + (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR( grad.to(dtype), inp.query, inp.key, inp.value, + inp.attn_bias if uses_attn_bias else None, ctx.get_padded_lse(32, force_pad_inf=force_pad_inf), ctx.out.to(dtype), + dropout_p=inp.p, + # if not using dropout, seed and offset are irrelevant but still expected + # in function signature so just pass 0 + # seed and offset could be None if a different FW op other than cutlass + # was used. + rng_seed=ctx.rng_seed if inp.p != 0 else 0, + rng_offset=ctx.rng_offset if inp.p != 0 else 0, causal=causal, scale=inp.scale, ) - return Gradients(dq=grad_q, dk=grad_k, dv=grad_v) + + # c++/CUDA implementation returns an uninitialized tensor if bias doesn't + # require grad + if not ( + isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad + ): + grad_bias = None + + return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)
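
The ScaleOperandA_ parameter threaded through MmaPipelinedFromSharedMemory / MmaMultistageFromSharedMemory above only changes what each warp-level MMA consumes: when enabled, FragmentAScaler::apply multiplies every loaded fragment of A elementwise by the matching fragment of A_scale before it reaches warp_mma; when disabled, the scale iterator and fragment are compile-time no-ops. A minimal PyTorch sketch of the resulting math, ignoring tiling, pipelining, and shared-memory staging (the function name and shapes are illustrative, not part of the patch):

import torch

def mma_from_smem_reference(accum, A, A_scale, B, scale_operand_a: bool):
    # With scaling enabled the kernel effectively computes
    # accum += (A * A_scale) @ B, reading A and A_scale tile-by-tile from
    # shared memory; otherwise it is the usual accum += A @ B.
    A_eff = A * A_scale if scale_operand_a else A
    return accum + A_eff @ B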
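
get_bias_4d_view in pytorch_utils.h only builds views, so the kernels can always index the bias as (batch_sz, n_heads, n_queries, n_keys) without copying. A rough PyTorch equivalent of the accepted shapes, for intuition only (the helper name below is made up):

import torch

def bias_as_4d_view(bias, batch_sz, n_heads, n_queries, n_keys):
    # Same shape handling as get_bias_4d_view: views/expands only, no copies.
    assert bias.shape[-2:] == (n_queries, n_keys)
    if bias.dim() == 2:   # (n_queries, n_keys): broadcast over batch and heads
        return bias.unsqueeze(0).unsqueeze(0).expand(batch_sz, n_heads, n_queries, n_keys)
    if bias.dim() == 3:   # (batch_sz * n_heads, n_queries, n_keys)
        assert bias.size(0) == batch_sz * n_heads
        return bias.view(batch_sz, n_heads, n_queries, n_keys)
    if bias.dim() == 4:   # already (batch_sz, n_heads, n_queries, n_keys)
        assert bias.shape[:2] == (batch_sz, n_heads)
        return bias
    raise ValueError("bias can only have 2, 3 or 4 dims")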
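
Taken together, the Python changes above (torch.Tensor added to SUPPORTED_ATTN_BIAS_TYPES, SUPPORTS_ATTN_BIAS_GRAD on cutlassB, and the autograd backward returning grads.db in the attn_bias slot) let a tensor bias receive a gradient through the normal autograd path. A minimal usage sketch, assuming a CUDA GPU the CUTLASS kernels support; shapes and the explicit op selection are illustrative:

import torch
import xformers.ops as xops
from xformers.ops import fmha

B, M, H, K = 2, 64, 4, 64  # batch, sequence length, heads, head dim
q, k, v = (
    torch.randn(B, M, H, K, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)
# 4D bias; its last dim (64) satisfies the 128-bit alignment check for fp16
bias = torch.randn(B, H, M, M, device="cuda", dtype=torch.float16, requires_grad=True)

out = xops.memory_efficient_attention(
    q, k, v, attn_bias=bias, op=(fmha.cutlass.FwOp, fmha.cutlass.BwOp)
)
out.sum().backward()
assert bias.grad is not None and bias.grad.shape == bias.shape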
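
The new AttentionBwOpBase.supports override lets dispatch refuse a backward op that cannot produce a bias gradient: SUPPORTS_ATTN_BIAS_GRAD defaults to False and is only set to True on cutlassB here. A self-contained sketch of probing support before running (shapes illustrative; the asserted result assumes a GPU with enough shared memory for this head dim):

import torch
from xformers.ops import fmha

q, k, v = (torch.randn(2, 64, 4, 64, device="cuda", dtype=torch.float16) for _ in range(3))
bias = torch.randn(2, 4, 64, 64, device="cuda", dtype=torch.float16, requires_grad=True)
inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias)

# cutlassB advertises bias-gradient support, so these inputs pass; an op with
# SUPPORTS_ATTN_BIAS_GRAD == False would be rejected because the bias requires grad.
assert fmha.cutlass.BwOp.supports(inp)
print(fmha.cutlass.BwOp.not_supported_reasons(inp))  # [] when shapes/alignment are fine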
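
Both cutlass FwOp.not_supported_reasons and BwOp.not_supported_reasons above reject tensor biases whose last dimension does not fill whole 128-bit accesses: for fp16/bf16 the number of keys must be a multiple of 8, for fp32 a multiple of 4. A small sketch of the same check on the Python side (the helper name is illustrative):

import torch

def bias_last_dim_is_aligned(attn_bias: torch.Tensor) -> bool:
    # Each thread loads the bias 128 bits at a time, so the last dim must be a
    # multiple of 128 // bits_per_scalar elements.
    elems_per_128b = 128 // torch.finfo(attn_bias.dtype).bits
    return attn_bias.shape[-1] % elems_per_128b == 0

bias_last_dim_is_aligned(torch.empty(1, 1, 32, 32, dtype=torch.float16))  # True  (32 % 8 == 0)
bias_last_dim_is_aligned(torch.empty(1, 1, 33, 33, dtype=torch.float16))  # False (33 % 8 != 0)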
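
With SUPPORTS_DROPOUT now True, cutlassF accepts p > 0 and returns rng_seed / rng_offset in its Context, and the Context pins op_bw to cutlassB whenever dropout is active, since only that backward can regenerate the same random numbers from the seed/offset pair. From the caller's side, two runs under the same torch seed draw the same dropout mask. Illustrative sketch, assuming a supported CUDA GPU:

import torch
import xformers.ops as xops
from xformers.ops import fmha

q, k, v = (torch.randn(2, 33, 4, 64, device="cuda", dtype=torch.float16) for _ in range(3))

def run(p=0.3):
    return xops.memory_efficient_attention(
        q, k, v, p=p, op=(fmha.cutlass.FwOp, fmha.cutlass.BwOp)
    )

torch.manual_seed(42)
out1 = run()
torch.manual_seed(42)
out2 = run()
assert torch.allclose(out1, out2)  # same seed -> same dropout pattern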