Skip to content

Commit

Permalink
Split bwd into more .cu files to speed up compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Jul 23, 2024
1 parent 5ca83a9 commit 65f723b
Show file tree
Hide file tree
Showing 89 changed files with 304 additions and 153 deletions.
4 changes: 3 additions & 1 deletion csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
// Host-side dispatcher for the backward pass: picks a run_mha_bwd_ template
// instantiation from runtime fields of params and calls it on the given stream.
// NOTE(review): this is a diff hunk rendered without +/- markers. Per its header
// (3 additions & 1 deletion) the two-argument call below is the pre-commit line
// and the BOOL_SWITCH wrapper is the post-commit replacement.
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
// Branches on !params.is_bf16. elem_type is not declared in this hunk;
// presumably the macro introduces it -- confirm against the macro definition.
FP16_SWITCH(!params.is_bf16, [&] {
// Branches on params.d. kHeadDim is likewise presumably supplied by the macro -- confirm.
HEADDIM_SWITCH(params.d, [&] {
// Pre-commit call (removed): run_mha_bwd_ took only <element type, head dim>.
run_mha_bwd_<elem_type, kHeadDim>(params, stream);
// Post-commit (added): also branches on params.is_causal and passes the result
// as a third template argument named Is_causal.
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
});
});
});
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,4 @@ struct Flash_bwd_params : public Flash_fwd_params {
// Forward-pass entry points, templated on element type, head dimension and Is_causal.
// These two declarations are unchanged context lines of the hunk.
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (1 addition & 1 deletion) the first declaration below is the pre-commit line (removed).
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
// Post-commit declaration (added): gains a bool Is_causal template parameter,
// giving it the same template parameter list as run_mha_fwd_.
template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 128,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim128 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim128<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 128.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 128,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim128 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::half_t, 128, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim128<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 128.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim128<cutlass::half_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 160,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim160 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim160<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 160.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 160,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim160 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::half_t, 160, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim160<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 160.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 192,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim192 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::bfloat16_t, 192, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim192<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 192.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 192,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim192 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::half_t, 192, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim192<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 192.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim192<cutlass::half_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 256,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim256 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::bfloat16_t, 256, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim256<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 256.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 256,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim256 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::half_t, 256, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim256<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 256.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim256<cutlass::half_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 32,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim32 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::bfloat16_t, 32, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim32<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 32.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::bfloat16_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 32,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim32 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::half_t, 32, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim32<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 32.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::half_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim32<cutlass::half_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 64,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim64 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::bfloat16_t, 64, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim64<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 64.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 64,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim64 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::half_t, 64, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim64<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 64.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim64<cutlass::half_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 96,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim96 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::bfloat16_t, 96, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim96<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for bf16 elements, head dimension 96.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 96,
// with the third (Is_causal) template argument set to true.
// Forwards to run_mha_bwd_hdim96 with the same element type and flag.
template <>
void run_mha_bwd_<cutlass::half_t, 96, true>(
    Flash_bwd_params &params,
    cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim96<Element, kCausal>(params, stream);
}
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

// Explicit specialization of run_mha_bwd_ for fp16 elements, head dimension 96.
// NOTE(review): diff hunk rendered without +/- markers. Per its header
// (3 additions & 3 deletions) the first signature/body pair is the pre-commit
// text and the second pair is its post-commit replacement.
template<>
// Pre-commit signature (removed): two template arguments only.
void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
// Pre-commit body (removed).
run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
// Post-commit signature (added): third template argument fixed to false.
void run_mha_bwd_<cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
// Post-commit body (added): forwards the same element type and false flag.
run_mha_bwd_hdim96<cutlass::half_t, false>(params, stream);
}
Loading

0 comments on commit 65f723b

Please sign in to comment.