Clean up softcapping bwd a bit
tridao committed Jul 23, 2024
1 parent 751c762 commit 5ca83a9
Showing 3 changed files with 7 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -353,7 +353,7 @@ Thanks to @beginlner for this contribution.
### 2.6: Softcapping.

Support attention with softcapping, as used in Gemma-2 and Grok models.
-Thanks to @Narsil for this contribution.
+Thanks to @Narsil and @lucidrains for this contribution.

## Performance

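For context on what the backward kernels below differentiate: softcapping squashes each attention logit through a scaled tanh before the softmax, so the backward pass needs an extra per-element tanh-derivative factor. The snippet below is a minimal host-side C++ sketch of that forward transform and its derivative; the function names and the example cap value are illustrative, not part of the flash-attn API.

```cpp
#include <cmath>
#include <cstdio>

// Forward: squash a raw attention logit into (-softcap, +softcap).
float softcap_fwd(float raw_score, float softcap) {
    return softcap * std::tanh(raw_score / softcap);
}

// Backward: derivative of softcap * tanh(x / softcap) with respect to x.
// With t = tanh(x / softcap), this is 1 - t^2 -- the factor a
// calculate_dtanh-style routine produces element-wise.
float softcap_dtanh(float raw_score, float softcap) {
    float t = std::tanh(raw_score / softcap);
    return 1.0f - t * t;
}

int main() {
    const float softcap = 30.0f;  // example cap value, not a model-specific constant
    for (float x : {-100.0f, -10.0f, 0.0f, 10.0f, 100.0f}) {
        std::printf("x = %7.1f  capped = %8.4f  dtanh = %6.4f\n",
                    x, softcap_fwd(x, softcap), softcap_dtanh(x, softcap));
    }
    return 0;
}
```

Large-magnitude logits saturate near ±softcap and get a derivative close to zero, which is the factor the backward kernel has to fold into dS.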
21 changes: 5 additions & 16 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -480,16 +480,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// if (cute::thread(32, 0)) { print(scores); }

// Softcapping - calculating dTanh and scaling dS later with it
-auto dtanh = ([&]{
-    if constexpr (Is_softcap) {
-        Tensor _dtanh = make_tensor_like(scores);
-        flash::calculate_dtanh(scores, _dtanh, params.softcap);
-        return _dtanh;
-    }
-    else {
-        return nullptr;
-    }
-}());
+Tensor dtanh = make_tensor_like(scores);
+if constexpr (Is_softcap) {
+    flash::calculate_dtanh(scores, dtanh, params.softcap);
+}

// Alibi
if (Has_alibi) {
@@ -591,13 +585,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
-
float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
-
-if constexpr (Is_softcap) {
-    scaled_ds *= dtanh(mi, ni);
-}
-
+if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
dS(mi, ni) = scaled_ds;
}
}
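The cleanup above reduces to one pattern: always declare the dtanh buffer, fill it only when Is_softcap is set, and fold the extra factor into dS with a single if constexpr one-liner, so the non-softcap instantiation compiles the work away. The standalone C++ sketch below mirrors that control flow with plain arrays standing in for cute::Tensor and a simplified stand-in for pointwise_mult; the real kernel's tensor layouts, scaling, and exactly what scores hold at each stage are not reproduced here.

```cpp
#include <array>
#include <cstdio>

// Control-flow sketch of the cleaned-up softcap backward path, with plain
// arrays in place of cute::Tensor. Assumptions (not taken verbatim from the
// kernel): `capped` holds softcap * tanh(raw / softcap), `probs` holds the
// softmax probabilities, and pointwise_mult is reduced to the usual
// softmax-backward form p * (dp - dp_sum).
template <bool Is_softcap, int N>
void scale_dS(const std::array<float, N>& capped,
              const std::array<float, N>& probs,
              std::array<float, N>& dS,
              float dP_sum,
              float softcap) {
    std::array<float, N> dtanh{};        // always declared, like in the cleaned-up kernel
    if constexpr (Is_softcap) {
        for (int i = 0; i < N; ++i) {
            float t = capped[i] / softcap;      // t = tanh(raw / softcap)
            dtanh[i] = 1.0f - t * t;            // d/draw of softcap * tanh(raw / softcap)
        }
    }
    for (int i = 0; i < N; ++i) {
        float scaled_ds = probs[i] * (dS[i] - dP_sum);            // stand-in for pointwise_mult
        if constexpr (Is_softcap) { scaled_ds *= dtanh[i]; }      // the one-liner from the diff
        dS[i] = scaled_ds;
    }
}

int main() {
    std::array<float, 4> capped{-2.0f, 0.5f, 1.0f, 3.0f};
    std::array<float, 4> probs{0.1f, 0.2f, 0.3f, 0.4f};
    std::array<float, 4> dS{0.5f, -0.2f, 0.1f, 0.3f};
    scale_dS<true, 4>(capped, probs, dS, 0.175f, 30.0f);
    for (float v : dS) { std::printf("%f\n", v); }
    return 0;
}
```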
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -99,7 +99,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
-auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
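The functional change in this launcher is the first boolean: the dropout flag is masked off whenever softcapping is enabled, so the dropout + softcap combination of the backward kernel is never instantiated (in the same spirit as the surrounding comments about reducing the number of templates). The sketch below illustrates that dispatch idea with hypothetical names; it is not the library's actual launcher.

```cpp
#include <cstdio>

// Hypothetical stand-ins for the templated kernel and its launcher, showing
// how passing `Is_dropout && !Is_softcap` prunes one template combination.
template <bool Is_dropout, bool Is_softcap>
void fake_bwd_kernel() {
    std::printf("instantiated with dropout=%d softcap=%d\n", Is_dropout, Is_softcap);
}

template <bool Is_dropout, bool Is_softcap>
void launch() {
    // The commit's change in a nutshell: dropout is masked off when softcapping is on,
    // so the <true, true> specialization of the kernel is never generated.
    auto kernel = &fake_bwd_kernel<Is_dropout && !Is_softcap, Is_softcap>;
    kernel();
}

int main() {
    launch<true, false>();   // dropout-only path
    launch<true, true>();    // softcap on: runs the dropout=false, softcap=true kernel
    launch<false, true>();   // same kernel instantiation as the line above
    return 0;
}
```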
