Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fp16 support of Qwen1.5MoE models (A2.7B) to DeepSpeed-FastGen #5403

Merged
merged 6 commits into from
Aug 1, 2024

Conversation

ZonePG
Copy link
Contributor

@ZonePG ZonePG commented Apr 12, 2024

This PR adds support for Qwen1.5MoE-A2.7B models.

support for microsoft/DeepSpeed-MII#457

Test Code

for mii pipeline:

import mii

pipe = mii.pipeline("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B")
responses = pipe("DeepSpeed is", max_new_tokens=128, do_sample=False)
if pipe.is_rank_0:
    print(responses[0])

for huggingface:

import mii

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch
tokenizer = AutoTokenizer.from_pretrained("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B")
model = AutoModelForCausalLM.from_pretrained("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True).eval()
print(model)
inputs = tokenizer('DeepSpeed is', return_tensors='pt')
inputs = inputs.to(model.device)
pred = model.generate(**inputs, max_new_tokens=128, do_sample=False, repetition_penalty=1.0)
test = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
print(test)

Qwen1.5-MoE-A2.7B

Huggingface output with prompt "DeepSpeed is":

 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the

DeepSpeed-FastGen output with prompt "DeepSpeed is":

 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the

DeepSpeed-FastGen output with prompt "DeepSpeed is" with 8-way sharding:

 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the

shared_expert_output = self.shared_expert_mlp_2(shared_expert_output, cur_params.shared_moe_mlp_2, b=None)
shared_expert_gate_output = self.shared_expert_gate(hidden_states, cur_params.shared_moe_gate, b=None)[..., :1]
# shared_expert_gate_output shape[-1] is 1
shared_expert_output.mul_(torch.sigmoid(shared_expert_gate_output))
Copy link
Contributor Author

@ZonePG ZonePG Apr 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if using torch.sigmoid directly will affect performance?

@heiseon
Copy link

heiseon commented Jul 5, 2024

When I use your source code to build DeepSpeed,and run “for mii pipeline” code, the process is blocked and no error。How should I identify the problem? I use 4090 GPU and transformer is 4.41.0.dev0, torch is 2.2.1, cuda version is 11.8.
BTW the transformer code runs well , just runs very slow.

@ZonePG
Copy link
Contributor Author

ZonePG commented Jul 5, 2024

Hi, @heiseon I just created a new conda environment and built it from my deepspeed code and deepspeed-mii offical source code, and it’s ok without any issues.

maybe you can delete ~/.cache/torch_extensions/pyxxx_cuxxx and try it again.

my path is /data/zonepg/.cache/torch_extensions/py311_cu121.

@heiseon
Copy link

heiseon commented Jul 10, 2024

Hi, @heiseon I just created a new conda environment and built it from my deepspeed code and deepspeed-mii offical source code, and it’s ok without any issues.

maybe you can delete ~/.cache/torch_extensions/pyxxx_cuxxx and try it again.

my path is /data/zonepg/.cache/torch_extensions/py311_cu121.

delete ~/.cache/torch_extensions/pyxxx_cuxxx is worked for me.
I have another question. When using the Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4 quantized version, an error occurs: 'Could not find a mapping for dependency "mlp.experts.18.gate_proj.bias"'. Does it mean that the GPTQ quantized version of the model is not supported?"
Would you like assistance with addressing this issue or understanding more about GPTQ quantization compatibility?

@ZonePG
Copy link
Contributor Author

ZonePG commented Jul 10, 2024

Hi, @heiseon It does not support quantized Qwen models currently. Supporting this would likely require a big effort, so it might not be considered in the short term.

@HeyangQin HeyangQin enabled auto-merge July 15, 2024 23:43
@loadams loadams disabled auto-merge July 16, 2024 17:26
@loadams loadams requested review from lekurile and removed request for mrwyattii July 16, 2024 17:27
@loadams loadams added this pull request to the merge queue Jul 16, 2024
xslingcn added a commit to xslingcn/DeepSpeed that referenced this pull request Jul 17, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jul 18, 2024
@loadams loadams merged commit 249c1db into microsoft:master Aug 1, 2024
7 checks passed
github-merge-queue bot pushed a commit that referenced this pull request Aug 22, 2024
based on PR #5403 (Qwen1.5-MOE) and #5219 (Qwen1.5), support Qwen2
series model.

including: 0.5B, 1.5B, 7B, 57B-A14B, and 72B models.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants