From 24bc8b2c279237c6faca54b421e8aade6fb65777 Mon Sep 17 00:00:00 2001 From: ryan u Date: Fri, 31 Jan 2025 22:44:28 +0900 Subject: [PATCH 1/9] revise load_hook to work properly; make moe func trainable; use llama instead of mixtral --- .../deepseek_v3/modeling_deepseek_v3.py | 336 +++++------------- .../models/deepseek_v3/modular_deepseek_v3.py | 125 +++---- 2 files changed, 140 insertions(+), 321 deletions(-) diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index f0ee6bf50f83..6c59e03b63f5 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -12,17 +12,11 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, - SequenceClassifierOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -203,60 +197,27 @@ def forward(self, hidden_states): orig_shape = hidden_states.shape topk_indices, topk_weights, router_logits = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = self.moe_infer(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) return hidden_states, router_logits - def moe_infer(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): - """ - Perform inference using a Mixture of Experts (MoE) model. - Args: - hidden_states (torch.Tensor): Input hidden states. - topk_indices (torch.Tensor): Indices of the top-k experts for each token. - topk_weights (torch.Tensor): Weights associated with the top-k experts. + def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) - Returns: - torch.Tensor: Output of the MoE model. - """ - num_experts = len(self.experts) - batch_size, num_topk = topk_indices.shape - with torch.no_grad(): - # Count the number of tokens assigned to each expert - expert_counts = topk_indices.new_zeros((batch_size, num_experts)) - expert_counts.scatter_(1, topk_indices, 1) - tokens_per_expert = expert_counts.sum(dim=0) - - # Sort tokens by their assigned expert - sorted_indices = topk_indices.view(-1).argsort() - sorted_tokens = hidden_states[sorted_indices // num_topk] - tokens_per_expert = tokens_per_expert.cpu().numpy() - - # Process tokens through their assigned experts - expert_outputs = [] - current_pos = 0 - - for expert_idx, num_tokens in enumerate(tokens_per_expert): - if num_tokens == 0: - continue - - next_pos = current_pos + num_tokens + for expert_idx in range(len(self.experts)): expert = self.experts[expert_idx] - expert_tokens = sorted_tokens[current_pos:next_pos] - expert_outputs.append(expert(expert_tokens)) - current_pos = next_pos + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) - # Combine the outputs from all experts - expert_outputs = torch.cat(expert_outputs, dim=0) if expert_outputs else sorted_tokens.new_empty(0) - - # Reorder the outputs to match the original token sequence - reordered_outputs = torch.empty_like(expert_outputs) - reordered_outputs[sorted_indices] = expert_outputs - - # Reshape and apply the expert weights - reordered_outputs = reordered_outputs.view(batch_size, num_topk, -1).type(topk_weights.dtype) - moe_output = torch.matmul(topk_weights.unsqueeze(1), reordered_outputs) - moe_output = moe_output.sum(dim=1).type(hidden_states.dtype) - return moe_output + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + return final_hidden_states.type(hidden_states.dtype) def rotate_half(x): @@ -642,17 +603,6 @@ def _init_weights(self, module): """ -def permute_for_rope(input_tensor, n_heads, dim1, dim2): - """ - When you go from the complex ROPE formulation to sin and cos one, you need - to permute the query and key weights (to avoid doing it on the fly) - """ - input_tensor = input_tensor.reshape(dim1, dim2) - input_tensor = input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2) - input_tensor = input_tensor.transpose(1, 2).reshape(dim1, dim2) - return input_tensor - - @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", DEEPSEEK_V3_START_DOCSTRING, @@ -694,48 +644,43 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -751,9 +696,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -765,7 +709,6 @@ def forward( position_ids, past_key_values, output_attentions, - output_router_logits, use_cache, cache_position, position_embeddings, @@ -777,7 +720,6 @@ def forward( position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, - output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -789,21 +731,17 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - if output_router_logits: - all_router_logits += (layer_outputs[-1],) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) - output = MoeModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, - router_logits=all_router_logits, ) return output if return_dict else output.to_tuple() @@ -816,15 +754,7 @@ def _update_causal_mask( output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of DeepseekV3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -833,30 +763,21 @@ def _update_causal_mask( # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: + if using_static_cache: target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -873,8 +794,6 @@ def _update_causal_mask( device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, ) if ( @@ -886,6 +805,7 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @@ -899,8 +819,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( device: torch.device, cache_position: torch.Tensor, batch_size: int, - config: DeepseekV3Config, - past_key_values: Cache, + **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -908,11 +827,13 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -921,10 +842,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. - config (`DeepseekV3Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -934,128 +851,68 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) + return causal_mask def load_hook(self, state_dict, prefix, *args): """ - Weights have to be permutted for correct rope formulation. We can't do this in the weights - as every other framework already uses the `Llama` orginal function (which is copyrighted btw). + Weights have to be permuted for correct rope formulation. We can't do this in the weights + as every other framework already uses the `Llama` original function (which is copyrighted btw). And I am not even sure it's better.... anyways end of my rant """ + + def permute_for_rope(input_tensor): + """ + When you go from the complex ROPE formulation to sin and cos one, you need + to permute the query and key weights (to avoid doing it on the fly) + """ + n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2] + input_tensor = input_tensor.reshape(n_heads * dim1, dim2) + input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2) + input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2) + return input_tensor + + def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): + weight = state_dict[key] + weight = weight.view(num_heads, head_dim, -1) + weight_rot = weight[:, -rope_dim:] + weight_rot = permute_for_rope(weight_rot) + weight[:, -rope_dim:] = weight_rot + weight = weight.view(-1, weight.shape[-1]) + state_dict[key] = weight + for k in state_dict: if "q_b_proj." in k: - weight = state_dict.pop(k[: self.qk_nope_head_dim]) - if "k_b_proj." in k: - weight = state_dict.pop(k[self.qk_nope_head_dim :]) - state_dict[k] = permute_for_rope(weight, weight.shape[0], weight.shape[1], weight.shape[2]) + permute_layer_for_rope( + k, + num_heads=self.config.num_attention_heads, + head_dim=self.config.q_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + if "kv_a_proj_with_mqa." in k: + permute_layer_for_rope( + k, + num_heads=1, + head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -def load_balancing_loss_func( - gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], - num_experts: Optional[int] = None, - top_k=2, - attention_mask: Optional[torch.Tensor] = None, -) -> Union[torch.Tensor, int]: - r""" - Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. - - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss - function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between - experts is too unbalanced. - - Args: - gate_logits: - Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of - shape [batch_size X sequence_length, num_experts]. - num_experts: - Number of experts - top_k: - The number of experts to route per-token, can be also interpreted as the `top-k` routing - parameter. - attention_mask (`torch.Tensor`, *optional*): - The attention_mask used in forward function - shape [batch_size X sequence_length] if not None. - - Returns: - The auxiliary loss. - """ - if gate_logits is None or not isinstance(gate_logits, tuple): - return 0 - - if isinstance(gate_logits, tuple): - compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) - - _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - - expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) - - if attention_mask is None: - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) - - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.mean(routing_weights, dim=0) - else: - batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) - - # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask - expert_attention_mask = ( - attention_mask[None, :, :, None, None] - .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) - .reshape(-1, top_k, num_experts) - .to(compute_device) - ) - - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 - ) - - # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert - router_per_expert_attention_mask = ( - attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) - .to(compute_device) - ) - - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( - router_per_expert_attention_mask, dim=0 - ) - - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) - return overall_loss * num_experts - - class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -1065,9 +922,6 @@ def __init__(self, config): self.model = DeepseekV3Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.router_aux_loss_coef = config.router_aux_loss_coef - self.num_experts = config.num_local_experts - self.num_experts_per_tok = config.num_experts_per_tok # Initialize weights and apply final processing self.post_init() @@ -1098,13 +952,12 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, @@ -1131,8 +984,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM - >>> model = DeepseekV3ForCausalLM.from_pretrained("mistralai/DeepseekV3-8x7B-v0.1") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/DeepseekV3-8x7B-v0.1") + >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1142,12 +995,7 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -1163,7 +1011,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, **kwargs, @@ -1176,33 +1023,18 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], - self.num_experts, - self.num_experts_per_tok, - attention_mask, - ) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( + return CausalLMOutputWithPast( loss=loss, - aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - router_logits=outputs.router_logits, ) diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 9d76a96b6a87..db3bb276742a 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -13,14 +13,15 @@ from ...processing_utils import Unpack from ...utils import logging from ..llama.modeling_llama import ( + LlamaForCausalLM, LlamaForSequenceClassification, + LlamaModel, LlamaPreTrainedModel, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward, ) -from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel from .configuration_deepseek_v3 import DeepseekV3Config @@ -122,60 +123,27 @@ def forward(self, hidden_states): orig_shape = hidden_states.shape topk_indices, topk_weights, router_logits = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = self.moe_infer(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) return hidden_states, router_logits - def moe_infer(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): - """ - Perform inference using a Mixture of Experts (MoE) model. - Args: - hidden_states (torch.Tensor): Input hidden states. - topk_indices (torch.Tensor): Indices of the top-k experts for each token. - topk_weights (torch.Tensor): Weights associated with the top-k experts. - - Returns: - torch.Tensor: Output of the MoE model. - """ - num_experts = len(self.experts) - batch_size, num_topk = topk_indices.shape - with torch.no_grad(): - # Count the number of tokens assigned to each expert - expert_counts = topk_indices.new_zeros((batch_size, num_experts)) - expert_counts.scatter_(1, topk_indices, 1) - tokens_per_expert = expert_counts.sum(dim=0) - - # Sort tokens by their assigned expert - sorted_indices = topk_indices.view(-1).argsort() - sorted_tokens = hidden_states[sorted_indices // num_topk] - tokens_per_expert = tokens_per_expert.cpu().numpy() - - # Process tokens through their assigned experts - expert_outputs = [] - current_pos = 0 - - for expert_idx, num_tokens in enumerate(tokens_per_expert): - if num_tokens == 0: - continue - - next_pos = current_pos + num_tokens - expert = self.experts[expert_idx] - expert_tokens = sorted_tokens[current_pos:next_pos] - expert_outputs.append(expert(expert_tokens)) - current_pos = next_pos - - # Combine the outputs from all experts - expert_outputs = torch.cat(expert_outputs, dim=0) if expert_outputs else sorted_tokens.new_empty(0) + def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) - # Reorder the outputs to match the original token sequence - reordered_outputs = torch.empty_like(expert_outputs) - reordered_outputs[sorted_indices] = expert_outputs + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) - # Reshape and apply the expert weights - reordered_outputs = reordered_outputs.view(batch_size, num_topk, -1).type(topk_weights.dtype) - moe_output = torch.matmul(topk_weights.unsqueeze(1), reordered_outputs) - moe_output = moe_output.sum(dim=1).type(hidden_states.dtype) - return moe_output + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + return final_hidden_states.type(hidden_states.dtype) class DeepseekV3Attention(nn.Module): @@ -365,18 +333,7 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel): pass -def permute_for_rope(input_tensor, n_heads, dim1, dim2): - """ - When you go from the complex ROPE formulation to sin and cos one, you need - to permute the query and key weights (to avoid doing it on the fly) - """ - input_tensor = input_tensor.reshape(dim1, dim2) - input_tensor = input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2) - input_tensor = input_tensor.transpose(1, 2).reshape(dim1, dim2) - return input_tensor - - -class DeepseekV3Model(MixtralModel): +class DeepseekV3Model(LlamaModel): def __init__(self, config): super().__init__(config) self._register_load_state_dict_pre_hook(self.load_hook) @@ -384,19 +341,49 @@ def __init__(self, config): def load_hook(self, state_dict, prefix, *args): """ - Weights have to be permutted for correct rope formulation. We can't do this in the weights - as every other framework already uses the `Llama` orginal function (which is copyrighted btw). + Weights have to be permuted for correct rope formulation. We can't do this in the weights + as every other framework already uses the `Llama` original function (which is copyrighted btw). And I am not even sure it's better.... anyways end of my rant """ + + def permute_for_rope(input_tensor): + """ + When you go from the complex ROPE formulation to sin and cos one, you need + to permute the query and key weights (to avoid doing it on the fly) + """ + n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2] + input_tensor = input_tensor.reshape(n_heads * dim1, dim2) + input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2) + input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2) + return input_tensor + + def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): + weight = state_dict[key] + weight = weight.view(num_heads, head_dim, -1) + weight_rot = weight[:, -rope_dim:] + weight_rot = permute_for_rope(weight_rot) + weight[:, -rope_dim:] = weight_rot + weight = weight.view(-1, weight.shape[-1]) + state_dict[key] = weight + for k in state_dict: if "q_b_proj." in k: - weight = state_dict.pop(k[: self.qk_nope_head_dim]) - if "k_b_proj." in k: - weight = state_dict.pop(k[self.qk_nope_head_dim :]) - state_dict[k] = permute_for_rope(weight, weight.shape[0], weight.shape[1], weight.shape[2]) + permute_layer_for_rope( + k, + num_heads=self.config.num_attention_heads, + head_dim=self.config.q_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + if "kv_a_proj_with_mqa." in k: + permute_layer_for_rope( + k, + num_heads=1, + head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) -class DeepseekV3ForCausalLM(MixtralForCausalLM): +class DeepseekV3ForCausalLM(LlamaForCausalLM): pass From 5c0cd917b02d5ba9d4fb373e01f3fe24b36c40ff Mon Sep 17 00:00:00 2001 From: ryan u Date: Fri, 31 Jan 2025 23:37:18 +0900 Subject: [PATCH 2/9] fix attention forward --- .../models/deepseek_v3/modeling_deepseek_v3.py | 7 ++++--- src/transformers/models/deepseek_v3/modular_deepseek_v3.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 6c59e03b63f5..0c5c510d2277 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -361,7 +361,7 @@ def forward( batch_size, seq_length = input_shape q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) - q_rot, q_pass = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) @@ -373,9 +373,10 @@ def forward( cos, sin = position_embeddings q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(batch_size, self.num_heads, seq_length, self.qk_rope_head_dim) - query_states = torch.cat(q_rot, q_pass, dim=-1) - key_states = torch.cat(k_rot, k_pass, dim=-1) + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index db3bb276742a..5ac409e3b7f0 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -209,7 +209,7 @@ def forward( batch_size, seq_length = input_shape q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) - q_rot, q_pass = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) @@ -221,9 +221,10 @@ def forward( cos, sin = position_embeddings q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(batch_size, self.num_heads, seq_length, self.qk_rope_head_dim) - query_states = torch.cat(q_rot, q_pass, dim=-1) - key_states = torch.cat(k_rot, k_pass, dim=-1) + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) From 8e994dd83a68f20081b60221d35a7e3b2c43ab61 Mon Sep 17 00:00:00 2001 From: ryan u Date: Sat, 1 Feb 2025 14:01:23 +0900 Subject: [PATCH 3/9] use -1 for not-changing dim when to use exapnd --- src/transformers/models/deepseek_v3/modeling_deepseek_v3.py | 3 +-- src/transformers/models/deepseek_v3/modular_deepseek_v3.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 0c5c510d2277..af55c5f668a4 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -358,7 +358,6 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, self.num_heads, -1) - batch_size, seq_length = input_shape q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) @@ -373,7 +372,7 @@ def forward( cos, sin = position_embeddings q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) - k_rot = k_rot.expand(batch_size, self.num_heads, seq_length, self.qk_rope_head_dim) + k_rot = k_rot.expand(-1, self.num_heads, -1, -1) query_states = torch.cat((q_pass, q_rot), dim=-1) key_states = torch.cat((k_pass, k_rot), dim=-1) diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 5ac409e3b7f0..6966a0de2132 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -206,7 +206,6 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, self.num_heads, -1) - batch_size, seq_length = input_shape q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) @@ -221,7 +220,7 @@ def forward( cos, sin = position_embeddings q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) - k_rot = k_rot.expand(batch_size, self.num_heads, seq_length, self.qk_rope_head_dim) + k_rot = k_rot.expand(-1, self.num_heads, -1, -1) query_states = torch.cat((q_pass, q_rot), dim=-1) key_states = torch.cat((k_pass, k_rot), dim=-1) From 7405a95f71f26ad68fe7000367b77507f621ee66 Mon Sep 17 00:00:00 2001 From: ryan u Date: Sat, 1 Feb 2025 17:08:00 +0900 Subject: [PATCH 4/9] refactor DeepseekV3TopkRouter --- .../deepseek_v3/modeling_deepseek_v3.py | 35 +++++++++++-------- .../models/deepseek_v3/modular_deepseek_v3.py | 35 +++++++++++-------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index af55c5f668a4..781d7756de19 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -142,37 +142,42 @@ def __init__(self, config): self.routed_scaling_factor = config.routed_scaling_factor self.n_group = config.n_group self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) def forward(self, hidden_states): - batch_size, seq_length = hidden_states.shape[:-1] hidden_states = hidden_states.view(-1, self.config.hidden_size) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights, router_logits + + @torch.no_grad() + def get_topk_indices(self, scores): scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) .topk(2, dim=-1)[0] .sum(dim=-1) - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) - .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) .reshape(-1, self.n_routed_experts) - ) # [n, e] - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False) - topk_weights = scores.gather(1, topk_indices) - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor - return topk_indices, topk_weights, router_logits + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices class DeepseekV3MoE(nn.Module): diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 6966a0de2132..83ad25684814 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -68,37 +68,42 @@ def __init__(self, config): self.routed_scaling_factor = config.routed_scaling_factor self.n_group = config.n_group self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) def forward(self, hidden_states): - batch_size, seq_length = hidden_states.shape[:-1] hidden_states = hidden_states.view(-1, self.config.hidden_size) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights, router_logits + + @torch.no_grad() + def get_topk_indices(self, scores): scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) .topk(2, dim=-1)[0] .sum(dim=-1) - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) - .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) .reshape(-1, self.n_routed_experts) - ) # [n, e] - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False) - topk_weights = scores.gather(1, topk_indices) - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor - return topk_indices, topk_weights, router_logits + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices class DeepseekV3MoE(nn.Module): From ea3c9225546c2ac08622487e6b397cceacc93ff0 Mon Sep 17 00:00:00 2001 From: ryan u Date: Mon, 3 Feb 2025 16:59:37 +0900 Subject: [PATCH 5/9] use reshape_for_rope instead of load_hook; revise attention forward for TP; rename q_head_dim with qk_head_dim --- .../deepseek_v3/configuration_deepseek_v3.py | 5 +- .../deepseek_v3/modeling_deepseek_v3.py | 83 ++++++------------ .../models/deepseek_v3/modular_deepseek_v3.py | 85 ++++++------------- 3 files changed, 52 insertions(+), 121 deletions(-) diff --git a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py index 8d1c56343638..5943aa12648e 100644 --- a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py @@ -138,6 +138,9 @@ class DeepseekV3Config(PretrainedConfig): "layers.*.gate_proj": "colwise", "layers.*.up_proj": "colwise", "layers.*.down_proj": "rowwise", + "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.kv_b_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", } def __init__( @@ -194,7 +197,7 @@ def __init__( self.qk_rope_head_dim = qk_rope_head_dim self.v_head_dim = v_head_dim self.qk_nope_head_dim = qk_nope_head_dim - self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim self.head_dim = qk_rope_head_dim self.n_group = n_group self.topk_group = topk_group diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 781d7756de19..ef6be2f6add5 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -319,12 +319,12 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.q_head_dim + self.qk_head_dim = config.qk_head_dim self.is_causal = True self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( config.hidden_size, @@ -334,7 +334,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) self.kv_b_proj = nn.Linear( self.kv_lora_rank, - self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, ) @@ -344,7 +344,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): bias=config.attention_bias, ) - self.scaling = self.q_head_dim ** (-0.5) + self.scaling = self.qk_head_dim ** (-0.5) if self.config.rope_scaling is not None: mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) scaling_factor = self.config.rope_scaling["factor"] @@ -361,29 +361,31 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, self.num_heads, -1) + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) - q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2) q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2) + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2) + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) cos, sin = position_embeddings + q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot) q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) - k_rot = k_rot.expand(-1, self.num_heads, -1, -1) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) query_states = torch.cat((q_pass, q_rot), dim=-1) key_states = torch.cat((k_pass, k_rot), dim=-1) - if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: - value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -411,13 +413,20 @@ def forward( **kwargs, ) - if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights + def reshape_for_rope(self, q, k): + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d) + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d) + return q, k + class DeepseekV3DecoderLayer(nn.Module): def __init__(self, config: DeepseekV3Config, layer_idx: int): @@ -620,7 +629,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): config: DeepseekV3Config """ - def __init__(self, config): + def __init__(self, config: DeepseekV3Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -632,7 +641,6 @@ def __init__(self, config): self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) self.gradient_checkpointing = False - self._register_load_state_dict_pre_hook(self.load_hook) # Initialize weights and apply final processing self.post_init() @@ -871,49 +879,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask - def load_hook(self, state_dict, prefix, *args): - """ - Weights have to be permuted for correct rope formulation. We can't do this in the weights - as every other framework already uses the `Llama` original function (which is copyrighted btw). - And I am not even sure it's better.... anyways end of my rant - """ - - def permute_for_rope(input_tensor): - """ - When you go from the complex ROPE formulation to sin and cos one, you need - to permute the query and key weights (to avoid doing it on the fly) - """ - n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2] - input_tensor = input_tensor.reshape(n_heads * dim1, dim2) - input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2) - input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2) - return input_tensor - - def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): - weight = state_dict[key] - weight = weight.view(num_heads, head_dim, -1) - weight_rot = weight[:, -rope_dim:] - weight_rot = permute_for_rope(weight_rot) - weight[:, -rope_dim:] = weight_rot - weight = weight.view(-1, weight.shape[-1]) - state_dict[key] = weight - - for k in state_dict: - if "q_b_proj." in k: - permute_layer_for_rope( - k, - num_heads=self.config.num_attention_heads, - head_dim=self.config.q_head_dim, - rope_dim=self.config.qk_rope_head_dim, - ) - if "kv_a_proj_with_mqa." in k: - permute_layer_for_rope( - k, - num_heads=1, - head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, - rope_dim=self.config.qk_rope_head_dim, - ) - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 83ad25684814..55d1ac5c691f 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -167,12 +167,12 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.q_head_dim + self.qk_head_dim = config.qk_head_dim self.is_causal = True self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( config.hidden_size, @@ -182,7 +182,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) self.kv_b_proj = nn.Linear( self.kv_lora_rank, - self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, ) @@ -192,7 +192,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): bias=config.attention_bias, ) - self.scaling = self.q_head_dim ** (-0.5) + self.scaling = self.qk_head_dim ** (-0.5) if self.config.rope_scaling is not None: mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) scaling_factor = self.config.rope_scaling["factor"] @@ -209,29 +209,31 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, self.num_heads, -1) + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) - q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2) q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2) + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2) + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) cos, sin = position_embeddings + q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot) q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) - k_rot = k_rot.expand(-1, self.num_heads, -1, -1) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) query_states = torch.cat((q_pass, q_rot), dim=-1) key_states = torch.cat((k_pass, k_rot), dim=-1) - if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: - value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -259,13 +261,20 @@ def forward( **kwargs, ) - if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights + def reshape_for_rope(self, q, k): + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d) + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d) + return q, k + class DeepseekV3DecoderLayer(nn.Module): def __init__(self, config: DeepseekV3Config, layer_idx: int): @@ -339,53 +348,7 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel): class DeepseekV3Model(LlamaModel): - def __init__(self, config): - super().__init__(config) - self._register_load_state_dict_pre_hook(self.load_hook) - self.post_init() - - def load_hook(self, state_dict, prefix, *args): - """ - Weights have to be permuted for correct rope formulation. We can't do this in the weights - as every other framework already uses the `Llama` original function (which is copyrighted btw). - And I am not even sure it's better.... anyways end of my rant - """ - - def permute_for_rope(input_tensor): - """ - When you go from the complex ROPE formulation to sin and cos one, you need - to permute the query and key weights (to avoid doing it on the fly) - """ - n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2] - input_tensor = input_tensor.reshape(n_heads * dim1, dim2) - input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2) - input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2) - return input_tensor - - def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): - weight = state_dict[key] - weight = weight.view(num_heads, head_dim, -1) - weight_rot = weight[:, -rope_dim:] - weight_rot = permute_for_rope(weight_rot) - weight[:, -rope_dim:] = weight_rot - weight = weight.view(-1, weight.shape[-1]) - state_dict[key] = weight - - for k in state_dict: - if "q_b_proj." in k: - permute_layer_for_rope( - k, - num_heads=self.config.num_attention_heads, - head_dim=self.config.q_head_dim, - rope_dim=self.config.qk_rope_head_dim, - ) - if "kv_a_proj_with_mqa." in k: - permute_layer_for_rope( - k, - num_heads=1, - head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, - rope_dim=self.config.qk_rope_head_dim, - ) + pass class DeepseekV3ForCausalLM(LlamaForCausalLM): From c8132687dfa51416e4f84f017034948bee79260c Mon Sep 17 00:00:00 2001 From: ryan u Date: Mon, 3 Feb 2025 19:05:15 +0900 Subject: [PATCH 6/9] register pre_hook and hook both --- .../deepseek_v3/configuration_deepseek_v3.py | 6 +- .../deepseek_v3/modeling_deepseek_v3.py | 58 +++++++++++++++--- .../models/deepseek_v3/modular_deepseek_v3.py | 60 ++++++++++++++++--- 3 files changed, 103 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py index 5943aa12648e..c0b412dde023 100644 --- a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py @@ -135,12 +135,12 @@ class DeepseekV3Config(PretrainedConfig): keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `DeepseekV3Model` base_model_tp_plan = { - "layers.*.gate_proj": "colwise", - "layers.*.up_proj": "colwise", - "layers.*.down_proj": "rowwise", "layers.*.self_attn.q_b_proj": "colwise", "layers.*.self_attn.kv_b_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", + "layers.*.gate_proj": "colwise", + "layers.*.up_proj": "colwise", + "layers.*.down_proj": "rowwise", } def __init__( diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index ef6be2f6add5..3757a26cacbc 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -377,7 +377,6 @@ def forward( k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) cos, sin = position_embeddings - q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot) q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) k_rot = k_rot.expand(*k_pass.shape[:-1], -1) @@ -420,13 +419,6 @@ def forward( attn_output = self.o_proj(attn_output) return attn_output, attn_weights - def reshape_for_rope(self, q, k): - b, h, s, d = q.shape - q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d) - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d) - return q, k - class DeepseekV3DecoderLayer(nn.Module): def __init__(self, config: DeepseekV3Config, layer_idx: int): @@ -629,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): config: DeepseekV3Config """ - def __init__(self, config: DeepseekV3Config): + def __init__(self, config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -641,6 +633,8 @@ def __init__(self, config: DeepseekV3Config): self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) self.gradient_checkpointing = False + self._register_load_state_dict_pre_hook(self.load_pre_hook) + self._register_state_dict_hook(self.load_hook) # Initialize weights and apply final processing self.post_init() @@ -879,6 +873,52 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask + def load_pre_hook(self, state_dict, prefix, *args): + """ + Weights have to be permuted for correct rope formulation. We can't do this in the weights + as every other framework already uses the `Llama` original function (which is copyrighted btw). + And I am not even sure it's better.... anyways end of my rant + """ + + def permute_for_rope(input_tensor): + """ + When you go from the complex ROPE formulation to sin and cos one, you need + to permute the query and key weights (to avoid doing it on the fly) + """ + n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2] + input_tensor = input_tensor.reshape(n_heads * dim1, dim2) + input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2) + input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2) + return input_tensor + + def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): + weight = state_dict[key] + weight = weight.view(num_heads, head_dim, -1) + weight_rot = weight[:, -rope_dim:] + weight_rot = permute_for_rope(weight_rot) + weight[:, -rope_dim:] = weight_rot + weight = weight.view(-1, weight.shape[-1]) + state_dict[key] = weight + + for k in state_dict: + if "q_b_proj." in k: + permute_layer_for_rope( + k, + num_heads=self.config.num_attention_heads, + head_dim=self.config.qk_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + if "kv_a_proj_with_mqa." in k: + permute_layer_for_rope( + k, + num_heads=1, + head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + + def load_hook(self, module, state_dict, prefix, *args): + self.load_pre_hook(state_dict, prefix, *args) + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 55d1ac5c691f..e41ded87f2bd 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -225,7 +225,6 @@ def forward( k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) cos, sin = position_embeddings - q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot) q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) k_rot = k_rot.expand(*k_pass.shape[:-1], -1) @@ -268,13 +267,6 @@ def forward( attn_output = self.o_proj(attn_output) return attn_output, attn_weights - def reshape_for_rope(self, q, k): - b, h, s, d = q.shape - q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d) - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d) - return q, k - class DeepseekV3DecoderLayer(nn.Module): def __init__(self, config: DeepseekV3Config, layer_idx: int): @@ -348,7 +340,57 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel): class DeepseekV3Model(LlamaModel): - pass + def __init__(self, config): + super().__init__(config) + self._register_load_state_dict_pre_hook(self.load_pre_hook) + self._register_state_dict_hook(self.load_hook) + self.post_init() + + def load_pre_hook(self, state_dict, prefix, *args): + """ + Weights have to be permuted for correct rope formulation. We can't do this in the weights + as every other framework already uses the `Llama` original function (which is copyrighted btw). + And I am not even sure it's better.... anyways end of my rant + """ + + def permute_for_rope(input_tensor): + """ + When you go from the complex ROPE formulation to sin and cos one, you need + to permute the query and key weights (to avoid doing it on the fly) + """ + n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2] + input_tensor = input_tensor.reshape(n_heads * dim1, dim2) + input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2) + input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2) + return input_tensor + + def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): + weight = state_dict[key] + weight = weight.view(num_heads, head_dim, -1) + weight_rot = weight[:, -rope_dim:] + weight_rot = permute_for_rope(weight_rot) + weight[:, -rope_dim:] = weight_rot + weight = weight.view(-1, weight.shape[-1]) + state_dict[key] = weight + + for k in state_dict: + if "q_b_proj." in k: + permute_layer_for_rope( + k, + num_heads=self.config.num_attention_heads, + head_dim=self.config.qk_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + if "kv_a_proj_with_mqa." in k: + permute_layer_for_rope( + k, + num_heads=1, + head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + + def load_hook(self, module, state_dict, prefix, *args): + self.load_pre_hook(state_dict, prefix, *args) class DeepseekV3ForCausalLM(LlamaForCausalLM): From 4ab2f9e8f1a2922da699535afbd53be9593fb418 Mon Sep 17 00:00:00 2001 From: ryan u Date: Mon, 3 Feb 2025 19:06:07 +0900 Subject: [PATCH 7/9] make style --- src/transformers/models/deepseek_v3/modular_deepseek_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index e41ded87f2bd..98173ac9f237 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -388,7 +388,7 @@ def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, rope_dim=self.config.qk_rope_head_dim, ) - + def load_hook(self, module, state_dict, prefix, *args): self.load_pre_hook(state_dict, prefix, *args) From c5429ec7e58ec7478cf2e1a3b5d17cd5dff5dd3a Mon Sep 17 00:00:00 2001 From: ryan u Date: Mon, 10 Feb 2025 11:05:49 +0900 Subject: [PATCH 8/9] use n_shared_experts --- src/transformers/models/deepseek_v3/modeling_deepseek_v3.py | 4 +++- src/transformers/models/deepseek_v3/modular_deepseek_v3.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 3757a26cacbc..b7c34226a0dc 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -195,7 +195,9 @@ def __init__(self, config): ] ) self.gate = DeepseekV3TopkRouter(config) - self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size) + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) def forward(self, hidden_states): residuals = hidden_states diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 98173ac9f237..af83524adca2 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -121,7 +121,9 @@ def __init__(self, config): ] ) self.gate = DeepseekV3TopkRouter(config) - self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size) + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) def forward(self, hidden_states): residuals = hidden_states From 4df42f0b59619b63d693f272dc04d1c00b0be1b0 Mon Sep 17 00:00:00 2001 From: Minho Ryu Date: Fri, 14 Feb 2025 16:47:42 +0900 Subject: [PATCH 9/9] Update src/transformers/models/deepseek_v3/configuration_deepseek_v3.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../models/deepseek_v3/configuration_deepseek_v3.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py index c0b412dde023..e459b47b4bc4 100644 --- a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py @@ -142,7 +142,11 @@ class DeepseekV3Config(PretrainedConfig): "layers.*.up_proj": "colwise", "layers.*.down_proj": "rowwise", } - + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, vocab_size=129280,