[Core] refactor model loading #10013

Open
sayakpaul opened this issue Nov 25, 2024 · 1 comment
Labels: refactor, roadmap (Add to current release roadmap)

Comments

sayakpaul (Member) commented Nov 25, 2024

Currently, we have two code paths:

  1. For non-sharded checkpoints, we call:
    unexpected_keys = load_model_dict_into_meta(
  2. For sharded checkpoints, we call:
    accelerate.load_checkpoint_and_dispatch(

Then, for bitsandbytes (bnb) quantized checkpoints, we first merge the shards into a single checkpoint:

model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)

Ideally, we shouldn't have to merge sharded checkpoints, even when they are quantized.

This would also let us support keep_module_in_fp32 for sharded checkpoints more generally. Currently, we have the following logic for casting a model (which is thoroughly tested):

elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:

When using load_model_dict_into_meta(), we do take keep_module_in_fp32 into account:

keep_in_fp32_modules=None,

However, because sharded checkpoints go through load_checkpoint_and_dispatch(), there is no way to pass keep_module_in_fp32 there:
https://huggingface.co/docs/accelerate/main/en/package_reference/big_modeling#accelerate.load_checkpoint_and_dispatch
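To make the casting rule concrete, here is a minimal sketch of a per-parameter dtype decision that honors a keep-in-fp32 list. It is not the diffusers implementation: the function name target_dtype is hypothetical, dtypes are plain strings so the sketch runs without torch, and matching a module name against the dot-separated components of a parameter name is an assumption about how such a list would be applied.

```python
from typing import Optional, Sequence


def target_dtype(
    param_name: str,
    torch_dtype: Optional[str],
    keep_in_fp32_modules: Optional[Sequence[str]],
) -> Optional[str]:
    """Hypothetical helper: pick the dtype a parameter should be cast to while loading.

    Dtypes are plain strings (e.g. "float16") so this sketch runs without torch.
    Returns None when no cast was requested (keep the dtype stored in the checkpoint).
    """
    if torch_dtype is None:
        return None
    # Assumption: a module is matched against the dot-separated parts of the
    # parameter name, so "proj_out" matches "proj_out.weight" but not "proj_out_extra.weight".
    if keep_in_fp32_modules and any(
        module in param_name.split(".") for module in keep_in_fp32_modules
    ):
        return "float32"
    return torch_dtype
```

Because the decision depends only on the parameter name, it can be applied to each shard's state dict just as well as to a single merged file, which is what a single load_model_dict_into_meta()-style code path would need.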

As discussed with @SunMarc, it's better to unify this so that we don't have to maintain two different code paths and can rely entirely on load_model_dict_into_meta(). Marc has kindly agreed to open a PR to attempt this (possibly as a series of PRs if needed). I'm happy to join if any help is needed.
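A minimal sketch of what a single code path could look like, assuming the standard Hugging Face sharded layout in which an index JSON's "weight_map" maps each parameter name to the shard file that stores it. The shard reader and the per-state-dict loader are passed in as callables; both are hypothetical stand-ins (the latter for a load_model_dict_into_meta()-style function that returns unexpected keys), so this only illustrates loading shard by shard without merging, not the actual refactor.

```python
from typing import Callable, Dict, List

StateDict = Dict[str, object]


def shards_from_index(index: dict) -> List[str]:
    """Return the distinct shard file names referenced by an index's weight_map, sorted."""
    # weight_map maps each parameter name to the shard file that stores it.
    return sorted(set(index["weight_map"].values()))


def load_sharded(
    index: dict,
    read_shard: Callable[[str], StateDict],
    load_into_model: Callable[[StateDict], List[str]],
) -> List[str]:
    """Feed every shard through the same per-state-dict loader instead of merging first.

    read_shard: hypothetical callable that reads one shard file into a state dict.
    load_into_model: hypothetical stand-in for a load_model_dict_into_meta()-style
    function; it loads the given state dict and returns its unexpected keys.
    """
    unexpected_keys: List[str] = []
    for shard_file in shards_from_index(index):
        state_dict = read_shard(shard_file)  # only this shard is read in this iteration
        unexpected_keys.extend(load_into_model(state_dict))
        del state_dict  # drop our reference before reading the next shard
    return unexpected_keys
```

Since every shard, and a non-sharded file treated as a single shard, goes through the same loader, options such as keep_module_in_fp32 or quantization handling would only need to live in that one function.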

sayakpaul (Member, Author) commented:

@huggingface/diffusers Marc has started working on this 🥳
