
[WIP] add deepseek-v3 #35926

Open · wants to merge 60 commits into base: main

Conversation

@bzantium (Contributor) commented Jan 28, 2025

What does this PR do?

This PR adds the code for DeepSeek-V3. The code relies heavily on the original remote code.

Resolves: #35425

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

to: @ArthurZucker

@Rocketknight1 (Member)

Hi @bzantium, this looks great so far! We'll need tests added for the model plus a green CI, and then feel free to ping me to assign a reviewer, or if you have any problems with the port.

@bzantium changed the title from "[WIP] add deepseekv3" to "[WIP] add deepseek-v3" on Jan 29, 2025
@ArthurZucker (Collaborator) left a comment:

Ultra kudos! It's super nice.
Mostly missing tests; here you can use a similar approach to the gemma2 tests, which use inheritance!

@cuichenx

@bzantium Thanks for the amazing work! I was wondering if you were able to train V3 with FSDP? If so, how many GPUs did you need? Thanks!

@ArthurZucker (Collaborator) commented Jan 29, 2025

One big thing would be TP support; the base_model_tp_plan would probably need to be updated to make sure each MLP's gate/up/down projections have the correct order, unless direct usage of dist removes this need.

@casper-hansen

This is great work and I'm looking forward to trying it out. For multi-token prediction, is this planned to be implemented in this PR via the num_nextn_predict_layers attribute in the config?

@bzantium (Contributor, Author) commented Jan 30, 2025

Thanks for the detailed comments; following them, I revised the code quite a lot and fixed some mismatches between the original code and this PR. I checked that the outputs from both are the same, so I think I can now add test code. For TP support, I think it can be applied only to the MLP layers but not to self_attn, because the attention functions split on the hidden_dim. I added the following:

    base_model_tp_plan = {
        "layers.*.gate_proj": "colwise",
        "layers.*.up_proj": "colwise",
        "layers.*.down_proj": "rowwise",
    }

to: @ArthurZucker

@mseeger commented Feb 18, 2025

OK, I sent a PR against the branch with some fixes.

One reason for the remaining failures could be that the head size of the K and V tensors differs from the usual case, because you have these additional qk_rope_head_dim entries. That may mean some of the common tests have to be generalized.
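
For concreteness, a tiny standalone sketch of the mismatch (the numbers below are the DeepSeek-V3 config values as I understand them, used here only for illustration):

    # Query/key heads carry both the "nope" and the rope part, while value heads do not,
    # so common tests that assume a single config.head_dim break.
    qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim   # 192
    assert qk_head_dim != v_head_dim                     # K and V head sizes differ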

@fungaren mentioned this pull request on Feb 18, 2025
@bzantium (Contributor, Author) commented Feb 18, 2025

Based on the test logs, I found two reasons:

  1. As @mseeger said, some tests are not compatible with multi-latent attention (they check head_dim), so I skipped some tests (maybe more need to be skipped).
  2. Because of the load_pre_hook, the loaded and saved tensors end up different, so we need to add a save_hook or remove the load_hook; see the sketch below.
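
For illustration, a minimal self-contained sketch of the load/save pairing this implies (not the PR's code; it only uses PyTorch's private module state-dict hooks, and the key names are made up):

    import torch
    import torch.nn as nn

    class RenameOnLoad(nn.Module):
        """Rename a checkpoint key on load and rename it back on save,
        so that save/load round-trips instead of diverging."""

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4, bias=False)
            self._register_load_state_dict_pre_hook(self._load_hook)
            self._register_state_dict_hook(self._save_hook)

        @staticmethod
        def _load_hook(state_dict, prefix, *args):
            # checkpoints store the tensor under "old_proj.weight"
            old_key = prefix + "old_proj.weight"
            if old_key in state_dict:
                state_dict[prefix + "proj.weight"] = state_dict.pop(old_key)

        @staticmethod
        def _save_hook(module, state_dict, prefix, local_metadata):
            # invert the load-time rename so saved checkpoints keep the original format
            new_key = prefix + "proj.weight"
            if new_key in state_dict:
                state_dict[prefix + "old_proj.weight"] = state_dict.pop(new_key)

    m = RenameOnLoad()
    m.load_state_dict({"old_proj.weight": torch.randn(4, 4)})   # load hook remaps the key
    assert "old_proj.weight" in m.state_dict()                  # save hook maps it back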

@ArthurZucker (Collaborator)

Ah interesting, that is indeed an issue (load / save) to be careful about.

@ArthurZucker (Collaborator)

Okay, I'll give it another shot in a bit!

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

Replace this line with:

        if self.rope_type != "yarn":
            # If we pass `config` to `rope_init_fn`, the dimension is set to a wrong
            # value (for standard multi-head attention):
            self._rope_kwargs = dict(
                base=config.rope_theta, dim=config.qk_rope_head_dim,
            )
            if config.rope_scaling is not None:
                self._rope_kwargs["factor"] = config.rope_scaling["factor"]
            if config.max_position_embeddings is not None:
                self._rope_kwargs["max_position_embeddings"] = config.max_position_embeddings
        else:
            # TODO: `_compute_yarn_parameters` requires `config` and will lead
            # to wrong dimension in this case
            self._rope_kwargs = {"config": config}
        inv_freq, self.attention_scaling = self.rope_init_fn(device=device, **self._rope_kwargs)

@bzantium (Contributor, Author):

I revised the src/transformers/modeling_rope_utils.py script as well for RoPE; could you check this file?


Revise again:

        if self.rope_type != "yarn":
            # If we pass `config` to `rope_init_fn`, the dimension is set to a wrong
            # value (for standard multi-head attention):
            self._rope_kwargs = dict(
                base=config.rope_theta, dim=config.qk_rope_head_dim,
            )
            if config.rope_scaling is not None:
                self._rope_kwargs["factor"] = config.rope_scaling["factor"]
            if config.max_position_embeddings is not None:
                self._rope_kwargs["max_position_embeddings"] = config.max_position_embeddings
        else:
            # We can pass `config` and `dim` in this case:
            self._rope_kwargs = {"config": config, "dim": dim=config.qk_rope_head_dim}
        inv_freq, self.attention_scaling = self.rope_init_fn(device=device, **self._rope_kwargs)

"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)

Replace this line with:

inv_freq, self.attention_scaling = self.rope_init_fn(device=device, seq_len=seq_len, **self._rope_kwargs)

@bzantium (Contributor, Author):

Same as my comment above.

causal_mask,
position_ids,
past_key_values,
output_attentions,

Insert this line after output_attentions:

False,  # output_router_logits

@bzantium (Contributor, Author):

I removed output_router_logits for now, to first fix the essential modeling problems. Thanks for pointing out what I missed.

use_cache,
cache_position,
position_embeddings,
)

Insert this line:

**flash_attn_kwargs,

@mseeger left a comment:

Left some comments with fixes

@bzantium (Contributor, Author) commented Feb 19, 2025

> Left some comments with fixes

Thanks for the comments! You can give suggestions directly, like below (not just as text): click the "Add a suggestion" button after selecting the lines to fix, and replace the original code with your code.

[screenshot: the "Add a suggestion" button in the GitHub review toolbar]

Also, you should give suggestions on modular_deepseek_v3.py, because modeling_deepseek_v3.py is automatically generated from the modular file.

@bzantium (Contributor, Author)

I found more reasons for the failures:

    1 failed because `AssertionError: nan not found in [0.0, 1.0]` -> Parameter layers.2.mlp.gate.weight of model <class 'transformers.models.deepseek_v3.modeling_deepseek_v3.DeepseekV3Model'> seems not properly initialized
    1 failed because `AssertionError: False is not true` -> model.layers.2.mlp.experts.0.gate_proj.weight in DeepseekV3ForCausalLM has no gradient!
    2 failed because `AssertionError: DeepseekV3Model: Tensor layers.2.mlp.gate.weight` -> Tensor-likes are not close!
    2 failed because `AssertionError: False is not true` -> model.layers.2.mlp.experts.4.gate_proj.weight in DeepseekV3ForCausalLM has no gradient!

  1. The first one is because the gate weight is initialized with torch.empty:
     self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
  2. The "has no gradient" problem is probably due to the top-k selection, which is the key part of the MoE. I think this is mainly caused by how I implemented get_topk_indices for the router; see the sketch below.

@@ -189,13 +189,31 @@ def _compute_yarn_parameters(
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)

A problem is that dim is wrong for the DeepSeek model; it has to be config.qk_rope_head_dim. I'd suggest the following (sorry, I cannot comment on parts of the code that were not changed):

def _compute_yarn_parameters(
    config: PretrainedConfig,
    device: "torch.device",
    seq_len: Optional[int] = None,
    dim: Optional[int] = None,
    **rope_kwargs
) -> Tuple["torch.Tensor", float]:

Then, replace line 191 with:

if dim is None:
    dim = int(head_dim * partial_rotary_factor)

Now, we can call it with the correct dim.
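
As a tiny standalone illustration of that fallback (the function name below is made up; the real change would go into _compute_yarn_parameters in modeling_rope_utils.py):

    from typing import Optional

    def resolve_rope_dim(head_dim: int, partial_rotary_factor: float = 1.0, dim: Optional[int] = None) -> int:
        # derive dim from head_dim only when the caller did not pass it explicitly
        if dim is None:
            dim = int(head_dim * partial_rotary_factor)
        return dim

    assert resolve_rope_dim(head_dim=128) == 128         # existing models keep their behaviour
    assert resolve_rope_dim(head_dim=192, dim=64) == 64  # DeepSeek would pass dim=config.qk_rope_head_dim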

@bzantium (Contributor, Author):

I take care of this here, in the configuration file:

self.head_dim = qk_rope_head_dim


I would not recommend that. self.head_dim is used in other places and should remain independent of RoPE. In fact, the correct head_dim for DeepSeek models would be qk_nope_head_dim + qk_rope_head_dim.


The code in modeling_rope_utils.py mostly allows passing dim as an input, so we should just use that, I think.

factor = config.rope_scaling["factor"]
attention_factor = config.rope_scaling.get("attention_factor")

I don't know the intricacies of RoPE for DeepSeek, would trust you here.

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

Revise again:

        if self.rope_type != "yarn":
            # If we pass `config` to `rope_init_fn`, the dimension is set to a wrong
            # value (for standard multi-head attention):
            self._rope_kwargs = dict(
                base=config.rope_theta, dim=config.qk_rope_head_dim,
            )
            if config.rope_scaling is not None:
                self._rope_kwargs["factor"] = config.rope_scaling["factor"]
            if config.max_position_embeddings is not None:
                self._rope_kwargs["max_position_embeddings"] = config.max_position_embeddings
        else:
            # We can pass `config` and `dim` in this case:
            self._rope_kwargs = {"config": config, "dim": dim=config.qk_rope_head_dim}
        inv_freq, self.attention_scaling = self.rope_init_fn(device=device, **self._rope_kwargs)

pass


class DeepseekV3Model(LlamaModel):

OK, so IF modeling_deepseek_v3.py is indeed created from modular_deepseek_v3.py (and I am not sure about this), then this here will not work at all, right? This would simply be the LlamaModel.


One would at least have to copy the __init__ and make sure that DeepseekV3DecoderLayer is used. But we also need to use DeepseekV3RotaryEmbedding, etc.

@bzantium (Contributor, Author):

This is not just Python inheritance; you can check what the generated modeling_deepseek_v3.py looks like.
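
For readers unfamiliar with the mechanism, a rough sketch of the modular pattern (illustrative only; the class bodies below are stand-ins, and the expansion is done by utils/modular_model_converter.py in transformers):

    # modular_deepseek_v3.py (sketch): inherit from Llama and override only what changed.
    # The converter expands this into modeling_deepseek_v3.py, copying the inherited
    # Llama code, renaming "Llama" to "DeepseekV3", and applying the overrides.
    from transformers.models.llama.modeling_llama import LlamaModel, LlamaRotaryEmbedding

    class DeepseekV3RotaryEmbedding(LlamaRotaryEmbedding):
        pass  # overrides (e.g. yarn/mscale handling, dim=qk_rope_head_dim) would go here

    class DeepseekV3Model(LlamaModel):
        pass  # the converter emits a full DeepseekV3Model; whether the DeepseekV3*
              # submodules are picked up automatically is what this thread is discussing

The generated modeling file is what ships, and a repo-consistency check verifies that it matches what the converter produces from the modular file.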


I am still learning. Can you point me to where modeling is created from modular? My impression is that one can change them independently, and there is some automatic check for differences.

I also understand the role of the # Copied from ... comments; one can even specify some transformation rules at the end. But this is not the case for DeepseekV3Model. @ArthurZucker?


OK, and even if this really works, so that the classes called DeepseekXYZ in modeling_deepseek_v3.py are generated automatically from modular_deepseek_v3.py by taking the corresponding code for LlamaXYZ and simply replacing "Llama" with "DeepseekV3" everywhere (and I doubt that),

we still need to copy DeepseekV3RotaryEmbedding from modeling to modular, because that code has genuinely changed.

Collaborator:

> One would at least have to copy the __init__ and make sure that DeepseekV3DecoderLayer is used. But we also need to use DeepseekV3RotaryEmbedding, etc.

That is valid; this is what users should expect. We still need some tests to make sure that we raise an error when the layers being used are not changed, as it should be just inheritance!

pass


class DeepseekV3RotaryEmbedding(LlamaRotaryEmbedding):

Need to copy the changed code over from modeling_deepseek_v3.py, because the RoPE embeddings have changed from Llama. First, there are your changes with mscale, etc. Second, my changes to pass dim=config.qk_rope_head_dim.

@mseeger commented Feb 25, 2025

Is there something I can help with?

@ArthurZucker (Collaborator)

Sorry, I have been working on #36335 to give us the tools to run the model, as it was very, very slow just loading the full checkpoint!

@casper-hansen

This is a smaller, trained model using the DeepSeek-V3 architecture, in BF16 rather than FP8. It might be helpful :)
https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct
