Mixtral FastGen Support (#4828)
Adds support for Mixtral with FastGen. Key features implemented:

1. Top-2 MoE support (a rough gating sketch follows the changed-file summary below)
2. Better support for RoPE thetas
3. The Mixtral model implementation

---------

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
cmikeh2 and mrwyattii authored Dec 21, 2023
1 parent 1864391 commit c00388a
Showing 57 changed files with 1,193 additions and 340 deletions.
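As referenced in the commit message, the routing behind feature 1 can be sketched at the PyTorch level. This is a hedged illustration only: top2_route is a made-up helper, not part of DeepSpeed, and the real implementation is the CUDA top_k_gating and moe_gather ops changed further down (where the normalize_scales flag also appears).

# Hedged sketch of top-2 MoE routing; not the DeepSpeed kernels.
import torch

def top2_route(router_logits: torch.Tensor, normalize_scales: bool = True):
    """router_logits: [n_tokens, n_experts] raw router outputs."""
    probs = torch.softmax(router_logits, dim=-1)        # per-expert probabilities
    scores, expert_ids = probs.topk(2, dim=-1)          # two best experts per token
    if normalize_scales:
        # Renormalize so each token's two gate values sum to 1.
        scores = scores / scores.sum(dim=-1, keepdim=True)
    return scores, expert_ids                           # both shaped [n_tokens, 2]

Each token's layer output is then the score-weighted sum of its two experts' outputs, which is exactly the combine step that moe_gather performs on the GPU.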
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -61,7 +61,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool:
# We need to download the checkpoint files from HF
if model_has_safetensors(self.model_name_or_path):
# Prioritize downloading safetensors if they are available
allow_patterns = ["*.safetensors", "*.json", "*.pt"]
allow_patterns = ["*.safetensors", "*.json"]
else:
# Fallback to bin files when safetensors are not present
allow_patterns = ["*.bin", "*.json", "*.pt"]
7 changes: 7 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
@@ -17,6 +17,7 @@
OPTPolicy,
Llama2Policy,
MistralPolicy,
MixtralPolicy,
FalconPolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
@@ -105,6 +106,12 @@ def build_hf_engine(path: str,
assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \
f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}"
policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "mixtral":
# Ensure we're using the correct version of transformers for Mixtral
import transformers
assert version.parse(transformers.__version__) >= version.parse("4.36.1"), \
f"Mistral requires transformers >= 4.36.1, you have version {transformers.__version__}"
policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "falcon":
policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine)
else:
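With the MixtralPolicy registered, a Mixtral checkpoint can be served through FastGen the same way as the other supported models. The snippet below is a hedged usage sketch through DeepSpeed-MII, which drives the inference v2 engine built here; the checkpoint name and generation arguments are illustrative assumptions, not part of this diff.

# Hedged usage sketch (assumes the deepspeed-mii package is installed and that
# mii.pipeline accepts a Hugging Face model name; values are illustrative only).
import mii

pipe = mii.pipeline("mistralai/Mixtral-8x7B-v0.1")   # needs transformers >= 4.36.1
response = pipe(["DeepSpeed is"], max_new_tokens=64)
print(response)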
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/kernels/ragged_ops/__init__.py
@@ -10,4 +10,4 @@
from .logits_gather import *
from .moe_gather import *
from .moe_scatter import *
from .top_1_gating import *
from .top_k_gating import *
15 changes: 15 additions & 0 deletions deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#define TOP_K_SWITCH(N_TOP_K, ...) \
[&] { \
if (1 == N_TOP_K) { \
constexpr int CONST_TOP_K = 1; \
__VA_ARGS__(); \
} else if (2 == N_TOP_K) { \
constexpr int CONST_TOP_K = 2; \
__VA_ARGS__(); \
} \
}()
(next changed file)
@@ -13,6 +13,7 @@
(C_TYPE*)k.data_ptr(), \
(C_TYPE*)v.data_ptr(), \
(C_TYPE*)inv_freq_ptr, \
theta_base, \
batch_wrapper, \
qkv_stride, \
kv_cache_stride, \
@@ -51,6 +52,8 @@ void kv_trained_rotary_embeddings(torch::Tensor& kv_cache,
TORCH_CHECK(n_tokens == k.size(0));
TORCH_CHECK(n_tokens == v.size(0));

const float theta_base = 0.f;

// Dimensions
const int32_t block_size = kv_cache.size(1);
const int32_t n_kv_heads = kv_cache.size(3);
@@ -91,6 +94,7 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
const float theta_base,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
(next changed file)
@@ -27,6 +27,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
T* k,
T* v,
const T* inv_freq,
const float theta_base,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
const int kv_cache_stride,
@@ -114,7 +115,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
// Conversion to T and back means that both branches of this if statement
// will produce the same results if using the same algo for producing the
// freqs.
T trunc_freq = conversion::to<T>(1.0 / powf(10000.0, inv_freq_flt));
T trunc_freq = conversion::to<T>(1.0 / powf(theta_base, inv_freq_flt));
inv_freq_flt = conversion::to<float>(trunc_freq) * (float)global_token_idx;
}

@@ -158,7 +159,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
} else {
inv_freq_flt =
(float)((head_neuron_idx % half_head_size) * 2) / (float)headSize;
inv_freq_flt = 1.0 / powf(10000.0, inv_freq_flt) * (float)global_token_idx;
inv_freq_flt = 1.0 / powf(theta_base, inv_freq_flt) * (float)global_token_idx;
}

float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f;
@@ -186,6 +187,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
k, \
v, \
inv_freq, \
theta_base, \
batch_desc, \
qkv_stride, \
kv_cache_stride, \
@@ -198,6 +200,7 @@ void launch_kv_rotary_kernel(T* kv_cache,
T* k,
T* v,
T* inv_freq,
const float theta_base,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
const int kv_cache_stride,
@@ -245,6 +248,7 @@
TYPE * k, \
TYPE * v, \
TYPE * inv_freq, \
const float theta_base, \
const BatchWrapperCPP batch_desc, \
const int qkv_stride, \
const int kv_cache_stride, \
@@ -262,10 +266,20 @@ INSTANTIATE_KV_ROTARY_KERNEL(__half)
INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16)
#endif

#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \
if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \
kv_rotary_pos_kernel<T, Q_RATIO, HEAD_SIZE, false><<<grid, block, 0, stream>>>( \
kv_cache, q, k, v, nullptr, batch_desc, qkv_stride, kv_cache_stride, v_offset, 0);
#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \
if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \
kv_rotary_pos_kernel<T, Q_RATIO, HEAD_SIZE, false> \
<<<grid, block, 0, stream>>>(kv_cache, \
q, \
k, \
v, \
nullptr, \
0.f, \
batch_desc, \
qkv_stride, \
kv_cache_stride, \
v_offset, \
0);

template <typename T>
void launch_kv_copy_kernel(T* kv_cache,
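The kernel changes above replace the hard-coded 10000.0 rotary base with a theta_base argument, so models that configure a non-default rope_theta get correct position frequencies. Below is a hedged, host-side Python sketch of the per-token frequency math mirrored by the inv_freq_flt / powf(theta_base, ...) lines in the kernel; the function name and values are illustrative, not from this diff.

# Hedged sketch of the rotary-frequency math; not the CUDA kernel itself.
import torch

def rope_angles(position: int, head_size: int, theta_base: float = 10000.0) -> torch.Tensor:
    """Rotation angles for one token, one angle per pair of head channels."""
    # Exponents 2i / head_size, matching (head_neuron_idx % half_head_size) * 2 / headSize.
    exponents = torch.arange(0, head_size, 2, dtype=torch.float32) / head_size
    inv_freq = 1.0 / (theta_base ** exponents)   # per-pair inverse frequencies
    return inv_freq * position                   # angles fed into sin/cos

# Mixtral ships rope_theta = 1e6 in its Hugging Face config, versus the common default of 10000.0.
angles_default = rope_angles(position=42, head_size=128)
angles_mixtral = rope_angles(position=42, head_size=128, theta_base=1e6)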
(next changed file)
@@ -18,6 +18,7 @@ void launch_kv_rotary_kernel(T* kv_cache,
T* k,
T* v,
T* inv_freq,
const float theta_base,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
const int kv_cache_stride,
(next changed file)
@@ -45,6 +45,7 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
const float theta_base,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
torch::Tensor& tokens_to_seq,
(next changed file)
@@ -21,7 +21,12 @@ class BlockedRotaryEmbeddings(DSKernelBase):
supported_head_sizes = [64, 128]
supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71]

def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
def __init__(self,
head_size: int,
n_q_heads: int,
n_kv_heads: int,
dtype: torch.dtype,
theta_base: float = 10000.0) -> None:
"""
Args:
head_size: The size of the attention head.
@@ -51,6 +56,7 @@ def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch
self.head_size = head_size
self.n_q_heads = n_q_heads
self.n_kv_heads = n_kv_heads
self.theta_base = theta_base

def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None:
"""
@@ -66,5 +72,5 @@ def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: Ragg
k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)]
v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):]

self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(),
ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs())
self.kernel(kv_cache, q, k, v, self.theta_base, ragged_batch.batch_metadata_buffer(),
ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs())
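The Python wrapper now threads theta_base through to the kernel call. A hedged construction example follows; the import path is assumed from the package layout shown in this commit, and the head configuration is an illustrative Mixtral-style grouped-query layout (q_ratio 32/8 = 4 is in supported_q_ratios), not a value taken from this diff.

# Hedged example of constructing the updated wrapper with a non-default theta.
import torch
from deepspeed.inference.v2.kernels.ragged_ops import BlockedRotaryEmbeddings  # assumed export

rotary = BlockedRotaryEmbeddings(head_size=128,
                                 n_q_heads=32,
                                 n_kv_heads=8,
                                 dtype=torch.float16,
                                 theta_base=1e6)
# The call signature is unchanged apart from theta_base being stored on the object:
# rotary(kv_cache, qkv, ragged_batch)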
(next changed file)
@@ -16,6 +16,8 @@
n_channels, \
n_experts, \
n_tokens, \
n_top_k, \
normalize_scales, \
at::cuda::getCurrentCUDAStream()); \
return; \
}
@@ -27,17 +29,21 @@ void moe_gather(torch::Tensor& layer_output,
const torch::Tensor& moe_output,
const torch::Tensor& scores,
const torch::Tensor& mapped_slots,
const torch::Tensor& expert_count)
const torch::Tensor& expert_count,
const bool normalize_scales)
{
const int32_t n_channels = layer_output.size(1);
const int32_t n_experts = expert_count.size(0);
const int32_t n_tokens = layer_output.size(0);
const int32_t n_top_k = mapped_slots.size(1);

TORCH_CHECK(moe_output.size(0) == n_tokens);
TORCH_CHECK(moe_output.size(0) == n_tokens * n_top_k);
TORCH_CHECK(moe_output.size(1) == n_channels);
TORCH_CHECK(scores.size(0) == n_tokens);
TORCH_CHECK(mapped_slots.size(0) == n_tokens);

TORCH_CHECK(scores.size(1) == n_top_k);

TORCH_CHECK(layer_output.scalar_type() == moe_output.scalar_type());
TORCH_CHECK(scores.scalar_type() == torch::kFloat32);
TORCH_CHECK(mapped_slots.scalar_type() == torch::kInt32);
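moe_gather now receives per-token top-k scores shaped [n_tokens, n_top_k] and expert outputs laid out as [n_tokens * n_top_k, n_channels], optionally renormalizing the scores before the weighted combine, as the TORCH_CHECKs above confirm. Below is a hedged PyTorch-level sketch of the same reduction; it assumes a simple row layout in which token t's k-th expert output sits at row t * n_top_k + k, which is an assumption for illustration and not necessarily the kernel's real mapped_slots layout.

# Hedged sketch of the gather/combine step that moe_gather performs on the GPU.
import torch

def gather_combine(moe_output: torch.Tensor, scores: torch.Tensor,
                   normalize_scales: bool = True) -> torch.Tensor:
    n_tokens, n_top_k = scores.shape
    n_channels = moe_output.size(1)
    if normalize_scales:
        scores = scores / scores.sum(dim=-1, keepdim=True)
    per_expert = moe_output.view(n_tokens, n_top_k, n_channels)
    return (scores.unsqueeze(-1) * per_expert).sum(dim=1)   # [n_tokens, n_channels]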
(remaining changed files not shown)
