- FW: Drop support for dropout if pytorch is not installed
- Use rng_state in Context to store seed/offset for dropout
- Add test to ensure we can't combine flash+cutlass's dropouts

ghstack-source-id: c5e05a1994b9c20fc27b071c3bfefbb4174987a2
Pull Request resolved: https://github.com/fairinternal/xformers/pull/434

__original_commit__ = fairinternal/xformers@408fefe5506c92b9c58444620c45bf5159b7fb39
danthe3rd authored and xFormers Bot committed Jan 19, 2023
1 parent 2e65321 commit 8dab253
Showing 9 changed files with 111 additions and 39 deletions.
11 changes: 3 additions & 8 deletions CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## TBD
### Fixed
### Added
- Added tensor attn bias support to CUTLASS FlashAttention
- Added tensor attn bias grad support to CUTLASS FlashAttention
- Added dropout support to CUTLASS FlashAttention


## [0.0.16] - 2023-01-12
@@ -42,14 +45,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- fMHA: Added CUTLASS-based kernel for `xformers.ops.memory_efficient_attention`. This kernel is selected automatically depending on the inputs, and works on any GPU after P100 [facebookresearch/xformers#362]

## [0.0.15] - 2022-12-13
### Fixed

### Added
- Added tensor attn bias support to CUTLASS FlashAttention
- Added tensor attn bias grad support to CUTLASS FlashAttention
- Added dropout support to CUTLASS FlashAttention

## [0.0.12] - 2022-08-08
### Fixed
- Removed duplicated biases in the FusedMLP layers [facebookresearch/xformers#317]
17 changes: 17 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -1216,3 +1216,20 @@ def test_unsupported_stride_alignment(op: Type[fmha.AttentionFwOpBase]):
except ValueError:
q = q.contiguous()
fmha.memory_efficient_attention(q, q, q, op=(op, None))


@sm75_or_better_only
def test_unsupported_dropout_combine_flash_cutlass() -> None:
q = torch.empty(
[1, 4, 1, 16], device="cuda", dtype=torch.float16, requires_grad=True
)
with pytest.raises(ValueError):
out = fmha.memory_efficient_attention(
q, q, q, p=0.1, op=(fmha.cutlass.FwOp, fmha.flash.BwOp)
)
out.backward(out)
with pytest.raises(ValueError):
out = fmha.memory_efficient_attention(
q, q, q, p=0.1, op=(fmha.flash.FwOp, fmha.cutlass.BwOp)
)
out.backward(out)
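
A sketch (not part of this commit) of the combination the new test rules out versus the one that is expected to keep working: with dropout enabled (p > 0), forward and backward are expected to come from the same backend, because the RNG state recorded by one backend's forward cannot be replayed by the other backend's backward. Names and shapes below mirror the test above; keeping both ops on CUTLASS, or omitting `op` entirely so xformers picks a consistent pair, is the supported path.

import torch
import xformers.ops.fmha as fmha

q = torch.randn(
    [1, 4, 1, 16], device="cuda", dtype=torch.float16, requires_grad=True
)
# Same backend for forward and backward: allowed together with dropout.
out = fmha.memory_efficient_attention(
    q, q, q, p=0.1, op=(fmha.cutlass.FwOp, fmha.cutlass.BwOp)
)
out.backward(torch.ones_like(out))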
@@ -49,7 +49,8 @@ struct RegisterOps {
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename T::TensorCoord const& tile_offset) {
typename T::TensorCoord const& tile_offset,
float scaling) {
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
@@ -77,10 +78,10 @@
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max);
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<typename T::Fragment>()(kLog2e, frag);
frag = cutlass::multiplies<typename T::Fragment>()(scaling * kLog2e, frag);

// Make sure we all share the update values for `mi`
__syncthreads();
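
A side note (not part of the diff) on the new `scaling` argument, reading only from the hunk above: the kernel evaluates the softmax exponential in base 2 (hence kLog2e = log2 e), and the change folds an extra factor s into both the scores and the per-row running max, i.e.

$\exp(s\,x_{ij} - m_i) \;=\; 2^{\,s\,\log_2(e)\,x_{ij} \;-\; \log_2(e)\,m_i}, \qquad m_i = \max_j s\,x_{ij}.$

The call site added in kernel_forward.h further below passes 1.0f for this argument, so the default numerics are unchanged.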
33 changes: 33 additions & 0 deletions xformers/csrc/attention/cuda/fmha/debug_utils.h
@@ -162,3 +162,36 @@ constexpr __string_view __get_type_name() {
int(ps.m()), \
int(ps.n()), \
int(ps.k()))

template <typename Iterator, typename LaneOffsetT, typename AccumT>
CUTLASS_DEVICE void print_warp_accum(
AccumT accum,
LaneOffsetT lane_offset,
int32_t num_rows,
int32_t num_cols) {
bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
for (int row = 0; row < num_rows; ++row) {
for (int col = 0; col < num_cols; ++col) {
if (col % 32 == 0) {
if (is_main) {
printf("\nmat[%3d, %3d:%3d]", row, col, col + 32);
}
__syncthreads();
}
Iterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
if (row == accum_m && col == accum_n) {
printf(" %6.1f", float(accum[idx]));
}
},
[&](int accum_m) {});
__syncthreads();
}
if (is_main) {
printf("\n");
}
}
}
16 changes: 13 additions & 3 deletions xformers/csrc/attention/cuda/fmha/kernel_backward.h
@@ -208,7 +208,7 @@ struct AttentionBackwardKernel {
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
int32_t bias_strideM;
int32_t bias_strideM = 0;
int32_t gO_strideM;
int32_t gB_strideM;
int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise
@@ -238,12 +238,12 @@ struct AttentionBackwardKernel {
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t bias_strideH;
int32_t bias_strideH = 0;
int64_t o_strideB;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int64_t bias_strideB;
int64_t bias_strideB = 0;
int64_t lse_strideM;
int32_t num_batches;

@@ -1016,13 +1016,23 @@ struct AttentionBackwardKernel {
CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment);
XFORMERS_CHECK(p.lse_strideM % 8 == 0, "LSE is not correctly aligned");
XFORMERS_CHECK(
p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned");
XFORMERS_CHECK(
p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned");
XFORMERS_CHECK(
p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned");
XFORMERS_CHECK(
p.bias_strideB % kMinimumAlignment == 0,
"attn_bias is not correctly aligned");
XFORMERS_CHECK(
p.bias_strideH % kMinimumAlignment == 0,
"attn_bias is not correctly aligned");
XFORMERS_CHECK(
p.bias_strideM % kMinimumAlignment == 0,
"attn_bias is not correctly aligned");
return true;
}

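
One way callers could satisfy the new attn_bias stride checks (illustrative, not from this commit), assuming the required alignment works out to 8 fp16 elements; the kernel's actual requirement is `kMinimumAlignment`, checked on the batch, head and row strides. Allocating the bias with its last dimension padded to a multiple of the alignment and then slicing keeps every stride aligned while preserving the logical shape.

import torch

B, H, Mq, Mk = 2, 8, 100, 100
align = 8  # assumed alignment in elements; the kernel checks kMinimumAlignment
Mk_padded = (Mk + align - 1) // align * align
bias = torch.randn(
    B, H, Mq, Mk_padded, device="cuda", dtype=torch.float16
)[:, :, :, :Mk]
# bias.stride() == (H * Mq * Mk_padded, Mq * Mk_padded, Mk_padded, 1): the batch,
# head and row strides are all multiples of 8, so the checks above should pass.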
31 changes: 22 additions & 9 deletions xformers/csrc/attention/cuda/fmha/kernel_forward.h
@@ -1,9 +1,5 @@
#ifdef HAS_PYTORCH
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#endif

@@ -126,28 +122,30 @@ struct AttentionKernel {
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
int32_t bias_strideM;
int32_t bias_strideM = 0;

// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t bias_strideH;
int32_t bias_strideH = 0;

int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int32_t bias_strideB;
int32_t bias_strideB = 0;

int32_t num_batches;
int32_t num_heads;

// dropout
bool use_dropout;
at::PhiloxCudaState rng_engine_inputs;
unsigned long long dropout_batch_head_rng_offset;
float dropout_prob;
#ifdef HAS_PYTORCH
at::PhiloxCudaState rng_engine_inputs;
#endif

CUTLASS_HOST_DEVICE int32_t o_strideM() const {
return head_dim_value * num_heads;
@@ -466,6 +464,7 @@ struct AttentionKernel {
CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
XFORMERS_CHECK(
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(
@@ -478,6 +477,15 @@
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
XFORMERS_CHECK(
p.bias_strideB % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
XFORMERS_CHECK(
p.bias_strideH % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
XFORMERS_CHECK(
p.bias_strideM % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
return true;
}

@@ -527,6 +535,7 @@
{0, col});
};

#ifdef HAS_PYTORCH
curandStatePhilox4_32_10_t curand_state_init;
if (p.use_dropout) {
const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
@@ -545,6 +554,7 @@
std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
&curand_state_init);
}
#endif

// Iterate through keys
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
@@ -713,7 +723,8 @@
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset);
iteratorC_tile_offset,
1.0f);
}));
}));

@@ -729,6 +740,7 @@

__syncthreads();

#ifdef HAS_PYTORCH
// apply dropout (if applicable) after we've written Pij to smem.
// dropout is applied by multiplying each element of Pij by:
// - 0 with probability dropout_p
@@ -789,6 +801,7 @@
}
__syncthreads(); // p.use_dropout should have same value kernel-wide
}
#endif

//
// MATMUL: Attn . V
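
The truncated comment above describes inverted dropout on the attention tile Pij: each element is zeroed with probability p and scaled by 1/(1 - p) otherwise, so the expected value of the tile is unchanged. A plain-PyTorch restatement of that rule (illustrative only; it does not reproduce the kernel's Philox-generated mask, which is what the HAS_PYTORCH-guarded curand code supplies):

import torch

p = 0.1                                            # dropout probability
pij = torch.softmax(torch.randn(4, 4), dim=-1)     # stand-in for one attention tile Pij
keep = (torch.rand_like(pij) >= p).to(pij.dtype)   # 1 with probability 1 - p, else 0
pij_dropped = pij * keep / (1.0 - p)               # zero w.p. p, scaled by 1/(1-p) otherwise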
10 changes: 2 additions & 8 deletions xformers/ops/fmha/__init__.py
@@ -72,9 +72,6 @@ def forward(ctx, op: AttentionOp, *args: Any) -> Any:
ctx.op_fw = op_fw
ctx.op_bw = op_bw
ctx.p = inp.p
# used for cutlass backward with dropout
ctx.rng_seed = op_ctx.rng_seed
ctx.rng_offset = op_ctx.rng_offset

ctx.scale = inp.scale
ctx.attn_bias_ctx = attn_bias_ctx
@@ -108,9 +105,6 @@ def backward(cls, ctx, grad):
lse=lse,
out=out,
rng_state=rng_state,
# rng_seed and rng_offset used for cutlass implementation
rng_seed=ctx.rng_seed,
rng_offset=ctx.rng_offset,
)
grads = _memory_efficient_attention_backward(
ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
@@ -268,7 +262,7 @@ def memory_efficient_attention_backward(
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionBwOpBase]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the gradient of the attention.
Returns a tuple (dq, dk, dv, db)
@@ -288,7 +282,7 @@
grad,
op=op,
)
return (gradients.dq, gradients.dk, gradients.dv, gradients.db)
return (gradients.dq, gradients.dk, gradients.dv)


def _memory_efficient_attention(
4 changes: 0 additions & 4 deletions xformers/ops/fmha/common.py
@@ -145,10 +145,6 @@ class Context:
out: torch.Tensor
op_bw: Optional[Type["AttentionBwOpBase"]] = None
rng_state: Optional[torch.Tensor] = None
# used for cutlass backward with dropout
rng_seed: Optional[int] = None
# used for cutlass backward with dropout
rng_offset: Optional[int] = None

def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to
21 changes: 17 additions & 4 deletions xformers/ops/fmha/cutlass.py
@@ -124,13 +124,15 @@ def apply(
ctx = Context(
out=out,
lse=lse,
rng_seed=rng_seed,
rng_offset=rng_offset,
# cutlass forward is only compatible with cutlass backward if
# dropout is used (because of the way RNG states are passed and the
# way random numbers are generated during backward)
op_bw=BwOp if inp.p != 0 else None,
)
if inp.p != 0:
ctx.rng_state = torch.tensor(
[rng_seed, rng_offset], dtype=torch.int64, device="cpu"
)
return out, ctx

@classmethod
@@ -222,6 +224,17 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
causal = isinstance(inp.attn_bias, LowerTriangularMask)
dtype = inp.query.dtype

rng_seed = rng_offset = 0
if inp.p != 0.0:
if (
ctx.rng_state is None
or ctx.rng_state.dtype != torch.int64
or ctx.rng_state.device.type != "cpu"
or ctx.rng_state.shape != (2,)
):
raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
rng_seed, rng_offset = ctx.rng_state.tolist()

force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5)
(grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR(
grad.to(dtype),
@@ -236,8 +249,8 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
# in function signature so just pass 0
# seed and offset could be None if a different FW op other than cutlass
# was used.
rng_seed=ctx.rng_seed if inp.p != 0 else 0,
rng_offset=ctx.rng_offset if inp.p != 0 else 0,
rng_seed=rng_seed,
rng_offset=rng_offset,
causal=causal,
scale=inp.scale,
)
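
Condensing the two hunks above into one place (illustrative sketch, not from the commit): when p != 0, the forward op packs the Philox seed and offset into a small CPU tensor stored on the autograd Context as rng_state, and the backward op validates its dtype, device and shape before unpacking it for the CUTLASS backward operator.

import torch

rng_seed, rng_offset = 42, 16  # hypothetical values produced by the forward kernel
rng_state = torch.tensor([rng_seed, rng_offset], dtype=torch.int64, device="cpu")

# What BwOp.apply checks before unpacking:
assert rng_state.dtype == torch.int64
assert rng_state.device.type == "cpu"
assert rng_state.shape == (2,)
seed, offset = rng_state.tolist()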
