
Unable to use MimiModel with DeepSpeed ZeRO-3 #34735

Merged 2 commits into huggingface:main on Jan 17, 2025

Conversation

@anferico (Contributor) commented on Nov 14, 2024

What does this PR do?

Allow using MimiModel with DeepSpeed ZeRO-3.

Fixes the following error:

[rank0]: Traceback (most recent call last):                                                                                                                                                                                                  
[rank0]:   File "/net/tscratch/people/plgfcariaggi/transformers/./mimi_mre.py", line 10, in <module>                                                                                                                                         
[rank0]:     main()                                                                                                                                                                                                                          
[rank0]:   File "/net/tscratch/people/plgfcariaggi/transformers/./mimi_mre.py", line 7, in main                                                                                                                                              
[rank0]:     model = MimiModel.from_pretrained("kyutai/mimi", config=config)                                                                                                                                                                 
[rank0]:   File "/net/tscratch/people/plgfcariaggi/transformers/src/transformers/modeling_utils.py", line 4110, in from_pretrained                        
[rank0]:     model = cls(config, *model_args, **model_kwargs)                                                                                             
[rank0]:   File "/net/tscratch/people/plgfcariaggi/envs/transformers/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 511, in wrapper
[rank0]:     f(module, *args, **kwargs)                                                                                                                   
[rank0]:   File "/net/tscratch/people/plgfcariaggi/transformers/src/transformers/models/mimi/modeling_mimi.py", line 1515, in __init__                    
[rank0]:     self.quantizer = MimiSplitResidualVectorQuantizer(config)                                                                                    
[rank0]:   File "/net/tscratch/people/plgfcariaggi/envs/transformers/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 511, in wrapper
[rank0]:     f(module, *args, **kwargs)                                                                                                                               
[rank0]:   File "/net/tscratch/people/plgfcariaggi/transformers/src/transformers/models/mimi/modeling_mimi.py", line 1340, in __init__                    
[rank0]:     self.semantic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_semantic_quantizers)                                  
[rank0]:   File "/net/tscratch/people/plgfcariaggi/envs/transformers/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 511, in wrapper
[rank0]:     f(module, *args, **kwargs)                                                                                                                   
[rank0]:   File "/net/tscratch/people/plgfcariaggi/transformers/src/transformers/models/mimi/modeling_mimi.py", line 1282, in __init__                                
[rank0]:     self.layers = nn.ModuleList([MimiVectorQuantization(config) for _ in range(self.num_quantizers)])                                            
[rank0]:   File "/net/tscratch/people/plgfcariaggi/transformers/src/transformers/models/mimi/modeling_mimi.py", line 1282, in <listcomp>                  
[rank0]:     self.layers = nn.ModuleList([MimiVectorQuantization(config) for _ in range(self.num_quantizers)])                                            
[rank0]:   File "/net/tscratch/people/plgfcariaggi/envs/transformers/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 511, in wrapper
[rank0]:     f(module, *args, **kwargs)                                                          
[rank0]:   File "/net/tscratch/people/plgfcariaggi/transformers/src/transformers/models/mimi/modeling_mimi.py", line 1261, in __init__                                                            
[rank0]:     self.codebook = MimiEuclideanCodebook(config)                                       
[rank0]:   File "/net/tscratch/people/plgfcariaggi/envs/transformers/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 511, in wrapper                           
[rank0]:     f(module, *args, **kwargs)                                                          
[rank0]:   File "/net/tscratch/people/plgfcariaggi/transformers/src/transformers/models/mimi/modeling_mimi.py", line 1217, in __init__                                                            
[rank0]:     self.register_buffer("initialized", torch.Tensor([True]))                           
[rank0]:   File "/net/tscratch/people/plgfcariaggi/envs/transformers/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 255, in new_tensor
[rank0]:     tensor = _orig_torch_empty(0, device=device).new_empty(*args, **kwargs)                                                                                                                                                         
[rank0]: TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type bool at pos 0 

Explanation: first, the use of torch.Tensor(...) is discouraged/deprecated (cf. https://pytorch.org/docs/stable/tensors.html#torch.Tensor) and should be dropped in favor of torch.tensor(...). Second, here is an excerpt of DeepSpeed's source code (deepspeed/runtime/zero/partition_parameters.py) that explains why we get this error:

def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:

    def wrapped_fn(*args, **kwargs) -> Tensor:
        if kwargs.get("device", None) is None:
            kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        tensor: Tensor = fn(*args, **kwargs)
        if tensor.is_floating_point():
            tensor.data = tensor.data.to(target_fp_dtype)

        return tensor

    return wrapped_fn


def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable:

    def new_tensor(cls, *args, **kwargs) -> Tensor:
        device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        if not args:
            args = (0, )
        tensor = _orig_torch_empty(0, device=device).new_empty(*args, **kwargs)
        if tensor.is_floating_point():
            tensor = tensor.to(dtype)

        return tensor

    return new_tensor

[...]

    def _add_tensor_creation_wrappers(self):
        torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
        torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype)
        torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype)
        torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype)
        torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
        torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)
        torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype)
        torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype)
        torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, self.dtype)

As can be seen, DeepSpeed ZeRO-3 patches torch.Tensor.__new__ with get_new_tensor_fn_for_dtype(self.dtype). When that function is called with [True] as an argument, it raises an error because the argument is forwarded directly to torch.Tensor.new_empty, which expects a size (a tuple of ints) as its first argument. In contrast, torch.tensor is patched with zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype), which wraps a value-style constructor, so it accepts [True] without error.
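The failure mode can be seen without torch or DeepSpeed at all. Here is a minimal, hypothetical pure-Python sketch (all names made up) of what happens when a value-style call hits a size-style constructor like new_empty:

```python
# Hypothetical sketch: a size-style constructor, in the spirit of
# Tensor.new_empty, treats every positional argument as a dimension
# and rejects anything that is not an int.
class FakeTensor:
    def __init__(self, *sizes):
        if not all(isinstance(s, int) for s in sizes):
            raise TypeError(
                "new_empty(): argument 'size' (position 1) must be tuple of ints"
            )
        self.shape = sizes

def patched_new(*args, **kwargs):
    # Mimics get_new_tensor_fn_for_dtype: forwards the caller's args
    # straight to the size-style constructor, whatever they are.
    return FakeTensor(*args, **kwargs)

patched_new(2, 3)        # size-style call: fine, allocates shape (2, 3)
try:
    patched_new([True])  # value-style call, like torch.Tensor([True])
except TypeError as exc:
    print(f"TypeError: {exc}")
```

A value-style call like torch.Tensor([True]) hands a list where the patched constructor expects integer sizes, which reproduces the TypeError from the traceback above.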

Reproducing the error

mimi_mre.sh:

OUTPUT_DIR=$HOME/mimi_mre

deepspeed \
    --num_gpus 1 \
    --master_port 60000 \
    ./mimi_mre.py \
    --output_dir $OUTPUT_DIR \
    --deepspeed zero3.json

mimi_mre.py:

from transformers import AutoConfig, MimiModel, TrainingArguments, HfArgumentParser

def main():
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_args_into_dataclasses()[0]
    config = AutoConfig.from_pretrained("kyutai/mimi")
    model = MimiModel.from_pretrained("kyutai/mimi", config=config)

if __name__ == "__main__":
    main()

zero3.json:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ylacombe, @eustlb

@ylacombe (Contributor) left a comment

LGTM! Thanks for opening this!

Let's push an empty commit to run slow tests before merging:
`git commit --allow-empty -m "[run-slow] mimi"`


@anferico (Author)

@ylacombe all done!

@anferico (Author)

@ylacombe 1 test is failing (tests.models.mimi.test_modeling_mimi.MimiIntegrationTest.test_integration), which is kinda weird because I don't think my change has any effect on the model's behavior. Any thoughts?

@anferico (Author) commented on Jan 2, 2025

@ylacombe ping

@anferico anferico requested a review from eustlb as a code owner January 13, 2025 10:36
@anferico (Author)

@eustlb ? @ylacombe ? Any idea why some tests are failing?

@eustlb (Contributor) commented on Jan 16, 2025

Hey @anferico, the failing slow test is unrelated indeed (see #35696).
I am of course okay with the provided fix 😊
Does it fix the originally reported issue? Can you please update the PR comment with a brief explanation if it does? 🙏
(from my understanding, it is a dtype issue entailed by torch.Tensor that fails in getting the right dtype - torch.bool -, contrary to torch.tensor, can you confirm?)

@eustlb (Contributor) left a comment
LGTM! thanks

@anferico (Author)

@eustlb yes I can confirm it fixes the issue, meaning I can use DeepSpeed ZeRO-3 in conjunction with MimiModel with the changes introduced in this PR

@anferico (Author)

@eustlb just updated the PR description too!

@eustlb (Contributor) commented on Jan 17, 2025

Great! thanks a lot 🤗

@eustlb eustlb merged commit 54fd7e9 into huggingface:main Jan 17, 2025
9 of 10 checks passed
eustlb added a commit that referenced this pull request Jan 17, 2025
eustlb added a commit that referenced this pull request Jan 17, 2025
Revert "Unable to use `MimiModel` with DeepSpeed ZeRO-3 (#34735)"

This reverts commit 54fd7e9.
@eustlb (Contributor) commented on Jan 17, 2025

Reverted because it caused a failing test on Moshi that was not caught by CI, looking into it

@eustlb (Contributor) commented on Jan 17, 2025

Found the issue: we need to keep the registered buffer as torch.float32, sorry for missing that!
Would you like to open a new PR by simply adding this one as a reference and changing:

- self.register_buffer("initialized", torch.Tensor([True]))
+ self.register_buffer("initialized", torch.tensor([True], dtype=torch.float32))

?
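For illustration only, here is a hypothetical pure-Python sketch (toy names, no torch required) of the value-style wrapper from the DeepSpeed excerpt quoted earlier, showing why torch.tensor([True]) keeps a bool dtype while an explicit float32 buffer gets recast to ZeRO's configured floating-point dtype:

```python
# Toy model of zero_wrapper_for_fp_tensor_constructor: wrap a
# value-style constructor and recast only floating-point results
# to the ZeRO-configured dtype.
def zero_wrapper(fn, target_fp_dtype):
    def wrapped(data, dtype=None):
        values, out_dtype = fn(data, dtype=dtype)
        if out_dtype.startswith("float"):  # stands in for is_floating_point()
            out_dtype = target_fp_dtype
        return values, out_dtype
    return wrapped

def fake_torch_tensor(data, dtype=None):
    # Value-style constructor: infers dtype from the data when not
    # given, like torch.tensor; never interprets data as a size.
    if dtype is None:
        dtype = "bool" if all(isinstance(x, bool) for x in data) else "float32"
    if dtype.startswith("float"):
        data = [float(x) for x in data]
    return data, dtype

tensor = zero_wrapper(fake_torch_tensor, "float16")

print(tensor([True]))                   # bool result: untouched by the wrapper
print(tensor([True], dtype="float32"))  # float result: recast to ZeRO's dtype
```

Under these toy assumptions, the value-style call never crashes on [True], and passing an explicit floating-point dtype is what lets the buffer participate in ZeRO's dtype handling, which matches the fix proposed above.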

@anferico (Author)

@eustlb done! #35759

bursteratom pushed a commit to bursteratom/transformers that referenced this pull request Jan 31, 2025
use torch.tensor(), not torch.Tensor()

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
bursteratom pushed a commit to bursteratom/transformers that referenced this pull request Jan 31, 2025
…#35755)

Revert "Unable to use `MimiModel` with DeepSpeed ZeRO-3 (huggingface#34735)"

This reverts commit 54fd7e9.
elvircrn pushed a commit to elvircrn/transformers that referenced this pull request Feb 13, 2025
use torch.tensor(), not torch.Tensor()

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
elvircrn pushed a commit to elvircrn/transformers that referenced this pull request Feb 13, 2025
…#35755)

Revert "Unable to use `MimiModel` with DeepSpeed ZeRO-3 (huggingface#34735)"

This reverts commit 54fd7e9.