[BUG] Unused params lead to "still have inflight params" error #4094
Comments
Have you solved the problem? My situation is exactly the same as yours.
Hi @alexwangmac I haven't really solved this problem, just worked around it with a setting. Hoping the DeepSpeed team can help with this soon.
Same
I ran into the same problem and your fix worked!
Any update on this? Running into the same issue when I have unused parameters for a given forward pass!
In the config json, set
While this might "work", it still does not solve the problem, for example with
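The exact config option referenced in the comments above is cut off in this thread. Purely as an illustration (an assumption, not necessarily the commenters' workaround), one commonly tried mitigation is disabling ZeRO-3 parameter prefetching so nothing is fetched ahead of a module that may end up unused:

```python
# Illustrative sketch only: `stage3_prefetch_bucket_size` is a real ZeRO-3 option,
# but whether it is the setting the commenters above used is an assumption.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "stage3_prefetch_bucket_size": 0,  # do not prefetch parameters ahead of use
    },
    "train_micro_batch_size_per_gpu": 1,
}
```

As the reply above notes, this at best sidesteps the error; it does not address the underlying issue with modules whose execution order changes between passes.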
I have exactly the same issue. When will Mixtral support be added to
(I posted a similar comment on #4808)
…oks (#4966)
ZeRO3 does not work with MoE models because the order of executing modules can change at every forward/backward pass (#4094, #4808). This PR adds an API to stop breaking down a module for parameter fetching. The following shows an example of the usage:

```python
import torch
import deepspeed
import deepspeed.comm as dist
from transformers.deepspeed import HfDeepSpeedConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model_id = "mistralai/Mixtral-8x7B-v0.1"
ds_config = {
    "bf16": {
        "enabled": True,
    },
    "zero_optimization": {
        "stage": 3,
    },
    "train_micro_batch_size_per_gpu": 1,
}
hfdsc = HfDeepSpeedConfig(ds_config)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
model.eval()

ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()
model = ds_engine.module

inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to("cuda")
outputs = model.generate(inputs, max_new_tokens=200)
output_str = tokenizer.decode(outputs[0])
if dist.get_rank() == 0:
    print(f"output: {output_str}")
```

By passing module classes to `set_z3_leaf_modules`, the DeepSpeed engine stops breaking down those modules. In this example, `MixtralSparseMoeBlock` has multiple experts as its submodules. With `set_z3_leaf_modules`, the DeepSpeed engine fetches the parameters of all the submodules when prefetching the parameters of `MixtralSparseMoeBlock`.
Hi everyone,
Hi, I also found this problem in my experiments. It seems that during generation some parameters are not used.
Bug description
Context: Running inference on a multi-modal LLM, where the parts of the network used at each decoding step depend on the input modality at that step. In my second decoding step, DeepSpeed goes ahead and prefetches a part of the network that ends up not being used. The code does assume that this can happen and correctly invalidates the trace. However, the params that were prefetched but never used are detected as in-flight at the end of the step, resulting in
RuntimeError(f"still have inflight params").
To Reproduce
My setup is a bit involved, and I think the issue is clear from the description above. However, if the team feels a simple reproduction would help, I can work on creating one. Please let me know.
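For reference, a minimal reproduction along these lines might look like the sketch below (launched with the deepspeed launcher). The model, branch condition, and config values are illustrative assumptions, not the reporter's actual setup: one submodule is used on some iterations and skipped on others, so ZeRO-3 may prefetch parameters that are never consumed in that step.

```python
# Hypothetical minimal reproduction (not the reporter's actual setup): a model in
# which one branch is skipped on some steps, trained under ZeRO stage 3.
import torch
import torch.nn as nn
import deepspeed

class ConditionalNet(nn.Module):
    def __init__(self, hidden=128):
        super().__init__()
        self.shared = nn.Linear(hidden, hidden)
        self.branch_a = nn.Linear(hidden, hidden)
        self.branch_b = nn.Linear(hidden, hidden)  # unused on some steps

    def forward(self, x, use_b):
        h = self.shared(x)
        # The branch that runs changes between steps, so the module execution
        # order recorded by the ZeRO-3 trace is not stable across iterations.
        h = self.branch_b(h) if use_b else self.branch_a(h)
        return h.pow(2).mean()

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

model = ConditionalNet()
engine, _, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

for step in range(4):
    x = torch.randn(1, 128, device=engine.device)
    # Alternate the active branch; parameters prefetched for the inactive branch
    # may then be reported as still in flight at the end of the step.
    loss = engine(x, use_b=(step % 2 == 0))
    engine.backward(loss)
    engine.step()
```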
Expected behavior
I would have expected that once we notice the order of params isn't the same as before, it would be reasonable to also not demand that all the parameters be used. Right now, we tolerate a different ordering but still require that every param previously used (and hence prefetched) be used at some point.
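To make the suggestion concrete, here is a hypothetical sketch of the proposed relaxation. The class and method names are illustrative and do not correspond to DeepSpeed's actual parameter coordinator; the point is only to show prefetched-but-unused params being released at the end of a step once the trace has been invalidated, instead of raising.

```python
# Hypothetical sketch of the proposed behaviour; InflightTracker and release()
# are illustrative names, not DeepSpeed's real classes or APIs.
class InflightTracker:
    def __init__(self):
        self.inflight = set()   # params fetched but not yet consumed this step
        self.trace_valid = True

    def on_prefetch(self, param):
        self.inflight.add(param)

    def on_consume(self, param):
        self.inflight.discard(param)

    def invalidate_trace(self):
        # Called when the observed module order differs from the recorded trace.
        self.trace_valid = False

    def end_of_step(self):
        if self.inflight and self.trace_valid:
            # Unchanged behaviour: with a stable trace, everything prefetched
            # should have been consumed.
            raise RuntimeError("still have inflight params")
        # Proposed behaviour: with an invalidated trace, quietly release the
        # parameters that were prefetched but never used this step.
        for param in list(self.inflight):
            self.release(param)
        self.inflight.clear()

    def release(self, param):
        # Placeholder for re-partitioning the gathered parameter.
        pass
```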
ds_report output
System info (please complete the following information):
AL2 (Amazon Linux) 5.10.149-133.644.amzn2.x86_64 #1 SMP Tue Oct 18 16:52:42 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux
p3.16xlarge instance from aws, 8 V100 with 16 GB per device
0.10.0
4.29.1
accelerate0.21.0
3.9.15