-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] add deepseek-v3 #35926
Open
bzantium
wants to merge
60
commits into
huggingface:main
Choose a base branch
from
bzantium:feature/#35425
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
[WIP] add deepseek-v3 #35926
Changes from all commits
Commits
Show all changes
60 commits
Select commit
Hold shift + click to select a range
704767e
add deepseekv3 modeling
bzantium 737ee3a
Merge branch 'main' into feature/#35425
bzantium fc3a4c7
Merge branch 'main' of https://github.com/bzantium/transformers into …
bzantium 244e793
remove redundant code
bzantium 0968df5
Merge branch 'feature/#35425' of https://github.com/bzantium/transfor…
bzantium 4fb2a80
apply make style
bzantium 6b002e5
apply fix-copies
bzantium 4ec1e88
make format
bzantium 114ab84
add init files
bzantium 779f8d2
rename deepseekv3 into deepseek_v3 based on its model_type
bzantium 22623a3
rename deepseekv3 into deepseek_v3 based on its model_type
bzantium 78b19b0
deepseek-v3 not deepseek_v3
bzantium eb0e3a4
set model_type as deepseek_v3
bzantium 57088cc
use default docs
bzantium 0ef561b
apply make
bzantium 9a75a56
fill type and docstring
bzantium cdf83e4
add rope_config_validation
bzantium 51990b9
use custom DeepseekV3MLP
bzantium f4f0ebd
hold code only for checkpoints congifuration; remove redundant
bzantium 4b72b30
revise rope yarn for DeepSeek variation
bzantium 96562c4
Merge branch 'main' into feature/#35425
bzantium 6792cb5
rename DeepSeek-V3
bzantium 3bf3b32
some refactoring
ArthurZucker 24bc8b2
revise load_hook to work properly; make moe func trainable; use llama…
bzantium 5c0cd91
fix attention forward
bzantium 8e994dd
use -1 for not-changing dim when to use exapnd
bzantium 7405a95
refactor DeepseekV3TopkRouter
bzantium ea3c922
use reshape_for_rope instead of load_hook; revise attention forward f…
bzantium c813268
register pre_hook and hook both
bzantium 4ab2f9e
make style
bzantium c5429ec
use n_shared_experts
bzantium 4df42f0
Update src/transformers/models/deepseek_v3/configuration_deepseek_v3.py
bzantium e0a49ac
Merge branch 'main' of github.com:huggingface/transformers into featu…
dfd9abc
Merge branch 'feature/#35425' of github.com:bzantium/transformers int…
ba21b7c
add test file
bzantium 2270173
Merge branch 'feature/#35425' of https://github.com/bzantium/transfor…
bzantium b5f420b
update modeling_file according to modular file
bzantium 6bd75a9
make style
bzantium 6ccbc66
add mapping for DeepseekV3ForSequenceClassification
bzantium a1c6274
remove aux_loss_alpha
bzantium a80462b
add deepseek_v3 for perf
bzantium dd78f48
add deepseek_v3
bzantium 54481ef
rename test as deepseekv3
bzantium e0f1c2d
use tiny-deepseek-v3
bzantium 23fb756
Merge branch 'main' into feature/#35425
bzantium 5214741
remove DeepseekV3ForSequenceClassification
bzantium 67f1f0c
cache before padding
bzantium f264f80
remote output_router_logits
bzantium d4c6a1b
Revert "remote output_router_logits"
bzantium c7c8d76
remove output_router_logits
bzantium 0b5ff07
Merge branch 'main' into feature/#35425
bzantium ba6f7d4
make e_score_correction_bias as buffer
bzantium d7931b3
skip tests not compatible
bzantium 92bd99c
make style
bzantium 7d81efe
make e_score_correction_bias as buffer
bzantium b33fdb5
use rope_interleave instead of load_hook
bzantium 7f859f8
skip tests not compatible with MLA
bzantium 397ecf3
add doc for rope_interleave
bzantium 2628438
fix typo
bzantium af3d328
remove torch.no_grad for selecting topk
bzantium File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
<!--Copyright 2025 The HuggingFace Team. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
|
||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | ||
rendered properly in your Markdown viewer. | ||
|
||
--> | ||
|
||
# DeepSeek-V3 | ||
|
||
## Overview | ||
|
||
The DeepSeek-V3 model was proposed in [DeepSeek-V3 Technical Report](https://arxiv.org/abs/2412.19437) by DeepSeek-AI Team. | ||
|
||
The abstract from the paper is the following: | ||
We present DeepSeek-V3, a strong Mixture-of-Experts (MoE) language model with 671B total parameters with 37B activated for each token. To achieve efficient inference and cost-effective training, DeepSeek-V3 adopts Multi-head Latent Attention (MLA) and DeepSeekMoE architectures, which were thoroughly validated in DeepSeek-V2. Furthermore, DeepSeek-V3 pioneers an auxiliary-loss-free strategy for load balancing and sets a multi-token prediction training objective for stronger performance. We pre-train DeepSeek-V3 on 14.8 trillion diverse and high-quality tokens, followed by Supervised Fine-Tuning and Reinforcement Learning stages to fully harness its capabilities. Comprehensive evaluations reveal that DeepSeek-V3 outperforms other open-source models and achieves performance comparable to leading closed-source models. Despite its excellent performance, DeepSeek-V3 requires only 2.788M H800 GPU hours for its full training. In addition, its training process is remarkably stable. Throughout the entire training process, we did not experience any irrecoverable loss spikes or perform any rollbacks. The model checkpoints are available at https://github.com/deepseek-ai/DeepSeek-V3. | ||
|
||
### Usage tips | ||
The model uses Multi-head Latent Attention (MLA) and DeepSeekMoE architectures for efficient inference and cost-effective training. It employs an auxiliary-loss-free strategy for load balancing and multi-token prediction training objective. The model can be used for various language tasks after being pre-trained on 14.8 trillion tokens and going through Supervised Fine-Tuning and Reinforcement Learning stages. | ||
|
||
## DeepseekV3Config | ||
|
||
[[autodoc]] DeepseekV3Config | ||
|
||
## DeepseekV3Model | ||
|
||
[[autodoc]] DeepseekV3Model | ||
- forward | ||
|
||
## DeepseekV3ForCausalLM | ||
|
||
[[autodoc]] DeepseekV3ForCausalLM | ||
- forward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -189,13 +189,31 @@ def _compute_yarn_parameters( | |
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 | ||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) | ||
dim = int(head_dim * partial_rotary_factor) | ||
max_position_embeddings = config.max_position_embeddings | ||
factor = config.rope_scaling["factor"] | ||
attention_factor = config.rope_scaling.get("attention_factor") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know the intricacies of RoPE for DeepSeek, would trust you here. |
||
mscale = config.rope_scaling.get("mscale") | ||
mscale_all_dim = config.rope_scaling.get("mscale_all_dim") | ||
|
||
# NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a | ||
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two | ||
# values to compute the default attention scaling factor, instead of using `factor`. | ||
if "original_max_position_embeddings" in config.rope_scaling: | ||
original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"] | ||
factor = config.max_position_embeddings / original_max_position_embeddings | ||
else: | ||
original_max_position_embeddings = config.max_position_embeddings | ||
|
||
def get_mscale(scale, mscale=1): | ||
if scale <= 1: | ||
return 1.0 | ||
return 0.1 * mscale * math.log(scale) + 1.0 | ||
|
||
# Sets the attention factor as suggested in the paper | ||
attention_factor = config.rope_scaling.get("attention_factor") | ||
if attention_factor is None: | ||
attention_factor = 0.1 * math.log(factor) + 1.0 | ||
if mscale and mscale_all_dim: | ||
attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) | ||
else: | ||
attention_factor = get_mscale(factor) | ||
|
||
# Optional config options | ||
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) | ||
|
@@ -227,15 +245,14 @@ def linear_ramp_factor(min, max, dim): | |
inv_freq_extrapolation = 1.0 / pos_freqs | ||
inv_freq_interpolation = 1.0 / (factor * pos_freqs) | ||
|
||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) | ||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) | ||
|
||
# Get n-dimensional rotational scaling corrected for extrapolation | ||
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) | ||
inv_freq = ( | ||
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) | ||
+ inv_freq_extrapolation * inv_freq_extrapolation_factor | ||
) | ||
|
||
return inv_freq, attention_factor | ||
|
||
|
||
|
@@ -425,7 +442,14 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se | |
rope_scaling = config.rope_scaling | ||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" | ||
required_keys = {"rope_type", "factor"} | ||
optional_keys = {"attention_factor", "beta_fast", "beta_slow"} | ||
optional_keys = { | ||
"attention_factor", | ||
"beta_fast", | ||
"beta_slow", | ||
"original_max_position_embeddings", | ||
"mscale", | ||
"mscale_all_dim", | ||
} | ||
received_keys = set(rope_scaling.keys()) | ||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,7 @@ | |
deberta, | ||
deberta_v2, | ||
decision_transformer, | ||
deepseek_v3, | ||
deformable_detr, | ||
deit, | ||
deprecated, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import TYPE_CHECKING | ||
|
||
from ...utils import _LazyModule | ||
from ...utils.import_utils import define_import_structure | ||
|
||
|
||
if TYPE_CHECKING: | ||
from .configuration_deepseek_v3 import * | ||
from .modeling_deepseek_v3 import * | ||
else: | ||
import sys | ||
|
||
_file = globals()["__file__"] | ||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A problem is that
dim
is wrong for the DeepSeek model, it has to beconfig.qk_rope_head_dim
. I'd suggest this (sorry, cannot comment on parts of the code not changed):Then, replace line 191 with:
Now, we can call it with the correct
dim
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I take care of this here!
in configuration file,
self.head_dim = qk_rope_head_dim
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would not recommend that.
self.head_dim
is used in other places, it should remain independent of RoPE. In fact, the correcthead_dim
for DeepSeek models would beqk_nope_head_dim + qk_rope_head_dim
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The stuff in
modeling_rope_utils.py
mostly allows to inputdim
, so we should just use that, I think.