Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel][Model] Varlen prefill + Prefill chunking support for mamba kernels and Jamba model #8533

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
b8e4cc6
reduce x size from float2 to float
mzusman Aug 29, 2024
8991183
Support single chunk as input
mzusman Sep 1, 2024
ea0089f
working with grid dimensions x = batch size, y = max seqlen, same amount
mzusman Sep 5, 2024
cf60b69
final and initial states support
mzusman Sep 10, 2024
3998748
working with cache indices
mzusman Sep 10, 2024
4038d84
WIP - add varlen to ssm
mzusman Sep 11, 2024
949718f
Working version with splits , TBD clean up
mzusman Sep 12, 2024
3e90085
also working with init prefill
mzusman Sep 12, 2024
b01f705
Remove last channel kernel
mzusman Sep 16, 2024
b9144f2
Clean up causal_conv1d
mzusman Sep 16, 2024
d2b97fa
Clean up selective_scan kernels and torch bindings
mzusman Sep 16, 2024
869aaf1
fix tests
mzusman Sep 16, 2024
0addc82
Fix wrappers
mzusman Sep 16, 2024
c4fe338
take off requirement for stride -1 == 1
mzusman Sep 17, 2024
d6fe5fd
Update causal_conv1d_update to use the new kernel
mzusman Sep 17, 2024
f2411b3
ssm state to be able to use different dtypes (itype)
mzusman Sep 17, 2024
216462f
more causal_conv1d_update fixes
mzusman Sep 17, 2024
3c6ec5c
Add prefill chunking to jamba modeling file
mzusman Sep 17, 2024
4753054
Fix formatting
mzusman Sep 17, 2024
1581443
remove print
mzusman Sep 17, 2024
1e08a4e
remove cruft and add comments
mzusman Sep 22, 2024
674e9f9
Add guards
mzusman Sep 22, 2024
1b9b3ba
Renaming and fix bug for short sequences
mzusman Sep 22, 2024
8e4f92d
renaming and add test for random cache indices and random has initial…
mzusman Sep 22, 2024
3a8632d
Add comments
mzusman Sep 22, 2024
6a9acb7
Add comment on jamba tests
mzusman Sep 22, 2024
9408615
has initial state as bool and add comments to jamba
mzusman Sep 22, 2024
254351e
Merge remote-tracking branch 'github/main' into varlen_mamba_causal_c…
mzusman Sep 22, 2024
d3d4e0f
Format
mzusman Sep 22, 2024
801cd7a
Some alignments with the changed from upstream
mzusman Sep 22, 2024
ddf1d5c
Remove cruft
mzusman Sep 25, 2024
8278263
Fix prefill chunking test
mzusman Sep 25, 2024
fbd1756
Remove unused returns
mzusman Sep 25, 2024
2153e03
Use varlen in Jamba
mzusman Sep 25, 2024
64c2f4b
Use decode kernels
mzusman Sep 25, 2024
b4515f7
Remove comment
mzusman Sep 25, 2024
693e818
Merge remote-tracking branch 'github/main' into varlen_mamba_causal_c…
mzusman Sep 25, 2024
82b3a2a
Fix opcheck for mamba ssm and causal conv1d
mzusman Sep 25, 2024
fee189d
Put back the figure
mzusman Sep 25, 2024
d74ee9c
WIP - fix opcheck tests
mzusman Sep 26, 2024
d4ddb12
Formating and fix opcheck tests
mzusman Sep 26, 2024
5528cfa
Add final state out the selective_scan_ref could fix tests fail
mzusman Sep 26, 2024
9fb3353
Fix test failures
mzusman Sep 27, 2024
f9bc2fa
Merge remote-tracking branch 'github/main' into varlen_mamba_causal_c…
mzusman Sep 27, 2024
9c6d140
format
mzusman Sep 27, 2024
1362e9b
renaming and sort out the set_params functions
mzusman Sep 29, 2024
b8580f5
add comment on final state assignment
mzusman Sep 29, 2024
209a6a9
renaming
mzusman Sep 29, 2024
c670e6d
renaming
mzusman Sep 29, 2024
30e7239
Renaming
mzusman Sep 29, 2024
9e7ecf2
Jamba adaptations
mzusman Sep 29, 2024
893fdf9
Fix jamba calls
mzusman Sep 29, 2024
e1c018b
Formating and renaming
mzusman Sep 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
527 changes: 211 additions & 316 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions csrc/mamba/causal_conv1d/causal_conv1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct ConvParamsBase {
index_t out_c_stride;
index_t out_l_stride;

int conv_state_len;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
Expand All @@ -35,6 +36,10 @@ struct ConvParamsBase {
void *__restrict__ out_ptr;

void *__restrict__ conv_state_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;

// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
Expand All @@ -52,6 +57,11 @@ struct ConvParamsBase {
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;

void * conv_states_ptr;
index_t conv_states_batch_stride;
index_t conv_states_l_stride;
index_t conv_states_c_stride;
};


Expand Down
29 changes: 9 additions & 20 deletions csrc/mamba/mamba_ssm/selective_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,14 @@ struct SSMParamsBase {
void *__restrict__ delta_ptr;
void *__restrict__ delta_bias_ptr;
void *__restrict__ out_ptr;
void *__restrict__ x_ptr;
void *__restrict__ ssm_states_ptr;
void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr;
void *__restrict__ index_ptr;

void *__restrict__ query_start_loc_ptr;
void *__restrict__ cache_indices_ptr;
void *__restrict__ has_initial_state_ptr;

};


Expand Down Expand Up @@ -201,7 +205,7 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadT::TempStorage &smem_load,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
Expand All @@ -217,21 +221,6 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
}
}

template<typename Ktraits>
// Cooperatively loads up to kNItems int indices per thread from global memory
// into per-thread registers via a CUB BlockLoad, staged through shared memory.
// u          : pointer to the index data in global memory for this block.
// u_vals     : per-thread destination registers (kNItems ints each).
// smem_load_index : shared-memory staging buffer for the block-wide load.
// seqlen     : number of valid elements; used to guard the tail in the
//              non-even-length path (out-of-range slots are filled with 0).
inline __device__ void load_index(int *u,
int (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
// Even-length fast path: reinterpret the staging storage for the
// vectorized loader and move data as uint4 (16-byte) chunks.
// NOTE(review): assumes u is 16-byte aligned and kNItems is a multiple
// of 4 ints (kNLoadsIndex uint4s) — presumably guaranteed by Ktraits;
// confirm against the kernel launch configuration.
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
reinterpret_cast<uint4*>(u),
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
);
} else {
// Guarded scalar path: loads only seqlen valid items, padding the
// remainder of u_vals with 0 (CUB BlockLoad's default-value overload).
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
}
}

template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
Expand All @@ -240,7 +229,7 @@ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
int seqlen) {
constexpr int kNItems = Ktraits::kNItems;
typename Ktraits::input_t B_vals_load[kNItems];
if constexpr (Ktraits::kIsEvenLen) {
if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
Expand All @@ -263,7 +252,7 @@ inline __device__ void store_output(typename Ktraits::input_t *out,
typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
if constexpr (Ktraits::kIsEvenLen) {
if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
Expand Down
297 changes: 179 additions & 118 deletions csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Large diffs are not rendered by default.

31 changes: 18 additions & 13 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,25 +215,30 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);

std::vector<torch::Tensor> selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& C,
const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_,
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
const c10::optional<torch::Tensor>& index_,
const c10::optional<torch::Tensor>& x);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& C,
const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_,
const c10::optional<torch::Tensor>& delta_bias_,
bool delta_softplus,
const c10::optional<torch::Tensor>& query_start_loc,
const c10::optional<torch::Tensor>& cache_indices,
const c10::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states);

at::Tensor causal_conv1d_update(
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias, bool silu_activation,
const c10::optional<at::Tensor>& conv_state_indices);
const c10::optional<at::Tensor>& bias_, bool silu_activation,
const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& conv_state_indices_);

at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& seq_idx_,
const c10::optional<at::Tensor>& initial_states_,
const c10::optional<at::Tensor>& final_states_out_,
const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation);

#ifndef USE_ROCM
Expand Down
17 changes: 11 additions & 6 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,26 +273,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor!? x) -> Tensor[]");
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices) -> Tensor");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor!? final_states_out_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
Expand Down
Loading
Loading