diff --git a/CHANGELOG.md b/CHANGELOG.md index 975eca67ab..d2763446b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## TBD ### Fixed ### Added +- Added tensor attn bias support to CUTLASS FlashAttention +- Added tensor attn bias grad support to CUTLASS FlashAttention +- Added dropout support to CUTLASS FlashAttention ## [0.0.16] - 2023-01-12 @@ -42,14 +45,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - fMHA: Added CUTLASS-based kernel for `xformers.ops.memory_efficient_attention`. This kernel is automatically depending on the inputs, and works on any GPU after P100 [facebookresearch/xformers#362] -## [0.0.15] - 2022-12-13 -### Fixed - -### Added -- Added tensor attn bias support to CUTLASS FlashAttention -- Added tensor attn bias grad support to CUTLASS FlashAttention -- Added dropout support to CUTLASS FlashAttention - ## [0.0.12] - 2022-08-08 ### Fixed - Removed duplicated biases in the FusedMLP layers [facebookresearch/xformers#317] diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 1ebcbd9fc1..fad02e6cb0 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -1216,3 +1216,20 @@ def test_unsupported_stride_alignment(op: Type[fmha.AttentionFwOpBase]): except ValueError: q = q.contiguous() fmha.memory_efficient_attention(q, q, q, op=(op, None)) + + +@sm75_or_better_only +def test_unsupported_dropout_combine_flash_cutlass() -> None: + q = torch.empty( + [1, 4, 1, 16], device="cuda", dtype=torch.float16, requires_grad=True + ) + with pytest.raises(ValueError): + out = fmha.memory_efficient_attention( + q, q, q, p=0.1, op=(fmha.cutlass.FwOp, fmha.flash.BwOp) + ) + out.backward(out) + with pytest.raises(ValueError): + out = fmha.memory_efficient_attention( + q, q, q, p=0.1, op=(fmha.flash.FwOp, fmha.cutlass.BwOp) + ) + out.backward(out) diff --git 
a/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h b/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h index b439ca100a..9265b52b3c 100644 --- a/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h +++ b/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h @@ -49,7 +49,8 @@ struct RegisterOps { int8_t thread_id, int8_t warp_id, int16_t max_col, - typename T::TensorCoord const& tile_offset) { + typename T::TensorCoord const& tile_offset, + float scaling) { // Convert to `accum_t` (rather than double) constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E if (!kIsFirst) { @@ -77,10 +78,10 @@ struct RegisterOps { [&](int accum_m) { // Having 4x atomicMax seems faster than reduce within warp // first... - atomicMaxFloat(&mi[accum_m], max); + atomicMaxFloat(&mi[accum_m], max * scaling); }); } - frag = cutlass::multiplies()(kLog2e, frag); + frag = cutlass::multiplies()(scaling * kLog2e, frag); // Make sure we all share the update values for `mi` __syncthreads(); diff --git a/xformers/csrc/attention/cuda/fmha/debug_utils.h b/xformers/csrc/attention/cuda/fmha/debug_utils.h index dc98bf78f9..add5b5a064 100644 --- a/xformers/csrc/attention/cuda/fmha/debug_utils.h +++ b/xformers/csrc/attention/cuda/fmha/debug_utils.h @@ -162,3 +162,36 @@ constexpr __string_view __get_type_name() { int(ps.m()), \ int(ps.n()), \ int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum( + AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + Iterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) 
{ + if (row == accum_m && col == accum_n) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/xformers/csrc/attention/cuda/fmha/kernel_backward.h b/xformers/csrc/attention/cuda/fmha/kernel_backward.h index d6690497c7..9626fd03a9 100644 --- a/xformers/csrc/attention/cuda/fmha/kernel_backward.h +++ b/xformers/csrc/attention/cuda/fmha/kernel_backward.h @@ -208,7 +208,7 @@ struct AttentionBackwardKernel { int32_t q_strideM; int32_t k_strideM; int32_t v_strideM; - int32_t bias_strideM; + int32_t bias_strideM = 0; int32_t gO_strideM; int32_t gB_strideM; int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise @@ -238,12 +238,12 @@ struct AttentionBackwardKernel { int32_t q_strideH; int32_t k_strideH; int32_t v_strideH; - int32_t bias_strideH; + int32_t bias_strideH = 0; int64_t o_strideB; int64_t q_strideB; int64_t k_strideB; int64_t v_strideB; - int64_t bias_strideB; + int64_t bias_strideB = 0; int64_t lse_strideM; int32_t num_batches; @@ -1016,6 +1016,7 @@ struct AttentionBackwardKernel { CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment); XFORMERS_CHECK(p.lse_strideM % 8 == 0, "LSE is not correctly aligned"); XFORMERS_CHECK( p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned"); @@ -1023,6 +1024,15 @@ struct AttentionBackwardKernel { p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned"); XFORMERS_CHECK( p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideB % kMinimumAlignment == 0, + "attn_bias is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideH % kMinimumAlignment == 0, + "attn_bias is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideM % kMinimumAlignment == 0, + "attn_bias is not correctly 
aligned"); return true; } diff --git a/xformers/csrc/attention/cuda/fmha/kernel_forward.h b/xformers/csrc/attention/cuda/fmha/kernel_forward.h index 8847655b79..11f90fdaff 100644 --- a/xformers/csrc/attention/cuda/fmha/kernel_forward.h +++ b/xformers/csrc/attention/cuda/fmha/kernel_forward.h @@ -1,9 +1,5 @@ #ifdef HAS_PYTORCH -#include -#include #include -#include -#include #include #endif @@ -126,28 +122,30 @@ struct AttentionKernel { int32_t q_strideM; int32_t k_strideM; int32_t v_strideM; - int32_t bias_strideM; + int32_t bias_strideM = 0; // Everything below is only used in `advance_to_block` // and shouldn't use registers int32_t q_strideH; int32_t k_strideH; int32_t v_strideH; - int32_t bias_strideH; + int32_t bias_strideH = 0; int64_t q_strideB; int64_t k_strideB; int64_t v_strideB; - int32_t bias_strideB; + int32_t bias_strideB = 0; int32_t num_batches; int32_t num_heads; // dropout bool use_dropout; - at::PhiloxCudaState rng_engine_inputs; unsigned long long dropout_batch_head_rng_offset; float dropout_prob; +#ifdef HAS_PYTORCH + at::PhiloxCudaState rng_engine_inputs; +#endif CUTLASS_HOST_DEVICE int32_t o_strideM() const { return head_dim_value * num_heads; @@ -466,6 +464,7 @@ struct AttentionKernel { CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); XFORMERS_CHECK( p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned"); XFORMERS_CHECK( @@ -478,6 +477,15 @@ struct AttentionKernel { p.k_strideH % kAlignmentK == 0, "key is not correctly aligned"); XFORMERS_CHECK( p.v_strideH % kAlignmentV == 0, "value is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideB % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideH % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideM % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); 
return true; } @@ -527,6 +535,7 @@ struct AttentionKernel { {0, col}); }; +#ifdef HAS_PYTORCH curandStatePhilox4_32_10_t curand_state_init; if (p.use_dropout) { const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); @@ -545,6 +554,7 @@ struct AttentionKernel { std::get<1>(seeds) + p.dropout_batch_head_rng_offset, &curand_state_init); } +#endif // Iterate through keys for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; @@ -713,7 +723,8 @@ struct AttentionKernel { thread_id(), warp_id(), p.num_keys - iter_key_start, - iteratorC_tile_offset); + iteratorC_tile_offset, + 1.0f); })); })); @@ -729,6 +740,7 @@ struct AttentionKernel { __syncthreads(); +#ifdef HAS_PYTORCH // apply dropout (if applicable) after we've written Pij to smem. // dropout is applied by multiplying each element of Pij by: // - 0 with probability dropout_p @@ -789,6 +801,7 @@ struct AttentionKernel { } __syncthreads(); // p.use_dropout should have same value kernel-wide } +#endif // // MATMUL: Attn . V diff --git a/xformers/ops/fmha/__init__.py b/xformers/ops/fmha/__init__.py index 431365b317..fe2b1902a7 100644 --- a/xformers/ops/fmha/__init__.py +++ b/xformers/ops/fmha/__init__.py @@ -72,9 +72,6 @@ def forward(ctx, op: AttentionOp, *args: Any) -> Any: ctx.op_fw = op_fw ctx.op_bw = op_bw ctx.p = inp.p - # used for cutlass backward with dropout - ctx.rng_seed = op_ctx.rng_seed - ctx.rng_offset = op_ctx.rng_offset ctx.scale = inp.scale ctx.attn_bias_ctx = attn_bias_ctx @@ -108,9 +105,6 @@ def backward(cls, ctx, grad): lse=lse, out=out, rng_state=rng_state, - # rng_seed and rng_offset used for cutlass implementation - rng_seed=ctx.rng_seed, - rng_offset=ctx.rng_offset, ) grads = _memory_efficient_attention_backward( ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw @@ -268,7 +262,7 @@ def memory_efficient_attention_backward( scale: Optional[float] = None, *, op: Optional[Type[AttentionBwOpBase]] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> 
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes the gradient of the attention. - Returns a tuple (dq, dk, dv, db) + Returns a tuple (dq, dk, dv) @@ -288,7 +282,7 @@ def memory_efficient_attention_backward( grad, op=op, ) - return (gradients.dq, gradients.dk, gradients.dv, gradients.db) + return (gradients.dq, gradients.dk, gradients.dv) def _memory_efficient_attention( diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py index 27cdbaeb5f..e6c599ac8e 100644 --- a/xformers/ops/fmha/common.py +++ b/xformers/ops/fmha/common.py @@ -145,10 +145,6 @@ class Context: out: torch.Tensor op_bw: Optional[Type["AttentionBwOpBase"]] = None rng_state: Optional[torch.Tensor] = None - # used for cutlass backward with dropout - rng_seed: Optional[int] = None - # used for cutlass backward with dropout - rng_offset: Optional[int] = None def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor: pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to diff --git a/xformers/ops/fmha/cutlass.py b/xformers/ops/fmha/cutlass.py index 4f7052ce68..724fea1580 100644 --- a/xformers/ops/fmha/cutlass.py +++ b/xformers/ops/fmha/cutlass.py @@ -124,13 +124,15 @@ def apply( ctx = Context( out=out, lse=lse, - rng_seed=rng_seed, - rng_offset=rng_offset, # cutlass forward is only compatible with cutlass backward if # dropout is used (because of the way RNG states are passed and the # way random numbers are generated during backward) op_bw=BwOp if inp.p != 0 else None, ) + if inp.p != 0: + ctx.rng_state = torch.tensor( + [rng_seed, rng_offset], dtype=torch.int64, device="cpu" + ) return out, ctx @classmethod @@ -222,6 +224,17 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: causal = isinstance(inp.attn_bias, LowerTriangularMask) dtype = inp.query.dtype + rng_seed = rng_offset = 0 + if inp.p != 0.0: + if ( + ctx.rng_state is None + or ctx.rng_state.dtype != torch.int64 + or ctx.rng_state.device.type != "cpu" + or ctx.rng_state.shape != (2,) + ): + 
raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}") + rng_seed, rng_offset = ctx.rng_state.tolist() + force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5) (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR( grad.to(dtype), @@ -236,8 +249,8 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: # in function signature so just pass 0 - # seed and offset could be None if a different FW op other than cutlass - # was used. - rng_seed=ctx.rng_seed if inp.p != 0 else 0, - rng_offset=ctx.rng_offset if inp.p != 0 else 0, + # with dropout, seed and offset are unpacked from ctx.rng_state above + # (an rng_state that is not a CPU int64 pair is rejected there). + rng_seed=rng_seed, + rng_offset=rng_offset, causal=causal, scale=inp.scale, )