Add offload for 8-bit model #1699

Merged (7 commits) on Jul 11, 2023

Conversation

@SunMarc (Member) commented Jul 10, 2023

What does this PR do?

This PR makes offload to cpu/disk possible with 8-bit models, saving even more memory. Previously, we did not quantize the modules placed on cpu/disk, so their weights stayed at full precision. With cpu/disk offload, we offload the quantized weights to cpu/disk and move them back to the gpu when needed using hooks. This should work out of the box with device_map="auto", but we require the user to specify enable_offload=True to make sure they know what they are doing. Furthermore, no modification is needed in the bitsandbytes library.

The input weights (weights_location) can be quantized or not. If the weights are not quantized, we first quantize them before offloading them to cpu/disk. If a module should not be quantized, the user should add it to the skip_modules arg (see the sketch after the example below).

PS: 4-bit model offload will be added once we are able to serialize 4-bit models.
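
To make the hook mechanism concrete, here is a minimal, schematic sketch of the idea (not accelerate's actual implementation; the attach_offload_hooks helper and device names are purely illustrative): a forward pre-hook moves a module's weights to the execution device just before its forward pass, and a forward hook moves them back to cpu right after.

import torch

def attach_offload_hooks(module: torch.nn.Module, execution_device="cuda:0"):
    # Illustrative only: accelerate implements this with its own hook classes.
    def pre_hook(mod, args):
        # Move the (quantized) weights onto the execution device right before forward;
        # inputs are assumed to already live on that device.
        mod.to(execution_device)

    def post_hook(mod, args, output):
        # Offload the weights back to cpu once this module's forward pass is done.
        mod.to("cpu")
        return output

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)

The usage example below shows the actual API added in this PR.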

# Imports assumed by this example (transformers, huggingface_hub, accelerate):
from accelerate import init_empty_weights
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

input_text = "Hello my name is"
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b7")
encoded_input = tokenizer(input_text, return_tensors="pt")

model_name = "marcsun13/bloom-1b7_with_lm_head"
weights_location = hf_hub_download(model_name, "pytorch_model.bin")

# Instantiate the model on the meta device, then tie the weights before dispatch
with init_empty_weights():
    model_8bit = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name))
model_8bit.tie_weights()

device_map = {'transformer.word_embeddings': 'cpu',
              'transformer.word_embeddings_layernorm': 'cpu',
              'transformer.h.0': 0,
              'transformer.h.1': 0,
              'transformer.h.2': 0,
              'transformer.h.3': 0,
              'transformer.h.4': 0,
              'transformer.h.5': 0,
              'transformer.h.6': 0,
              'transformer.h.7': 0,
              'transformer.h.8': 0,
              'transformer.h.9': 0,
              'transformer.h.10': 0,
              'transformer.h.11': 0,
              'transformer.h.12': 0, 
              'transformer.h.13': 'cpu', 
              'transformer.h.14': 'cpu',
              'transformer.h.15': 'cpu', 
              'transformer.h.16': 'cpu',
              'transformer.h.17': 'cpu',
              'transformer.h.18': 'cpu',
              'transformer.h.19': 'cpu',
              'transformer.h.20': 'cpu', 
              'transformer.h.21': 'cpu', 
              'transformer.h.22': 'cpu',
              'transformer.h.23': 'disk', 
              'transformer.ln_f': 'cpu',
              'lm_head':"cpu"}

bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, enable_offload=True)

model_8bit = load_and_quantize_model(model_8bit,
                            bnb_quantization_config,
                            weights_location=weights_location,
                            device_map = device_map,
                            no_split_module_classes=["BloomBlock"],
                            offload_state_dict=True,
                            offload_folder="tmp"
                            )
output_parallel = model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
output_text = tokenizer.decode(output_parallel[0], skip_special_tokens=True)
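
As mentioned above, modules that should not be quantized can be listed in the skip_modules arg of BnbQuantizationConfig. A minimal sketch, assuming (purely for illustration) that the lm_head should stay in full precision:

bnb_quantization_config = BnbQuantizationConfig(
    load_in_8bit=True,
    enable_offload=True,
    skip_modules=["lm_head"],  # illustrative choice: keep lm_head un-quantized
)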

@HuggingFaceDocBuilderDev commented Jul 10, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada (Contributor) left a comment

Looks great on my side!
As a small comment, it might be worth making it clear to users on the relevant documentation page that the computation is still done on the GPU, to avoid any confusion. Also, to be on the safe side, can you run the transformers slow tests of the bnb integration and make sure they pass?

@sgugger (Collaborator) left a comment

Thanks for working on this. Quick question on my side: why do we need the user to set enable_offload=True in their config? They are already indicating their intent to offload weights with the device_map, so this is asking the same thing twice. Is there any downside to removing that flag?

Review comments on src/accelerate/utils/modeling.py (outdated, resolved)
SunMarc and others added 4 commits on July 11, 2023 at 09:09 (co-authored by Sylvain Gugger).
@SunMarc (Member, Author) commented Jul 11, 2023

> Thanks for working on this. Quick question on my side: why do we need the user to set enable_offload=True in their config? They are already indicating their intent to offload weights with the device_map, so this is asking the same thing twice. Is there any downside to removing that flag?

No, there should not be any downside to removing that flag; I just removed it. It was something used in the transformers integration for 8-bit models, so I had kept it initially (a minimal sketch of the resulting call is below).
cc @younesbelkada
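
For reference, a minimal sketch of what the call from the example above would presumably look like once the flag is removed, with the offload intent carried entirely by the device_map (all names reused from the example above):

bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)

model_8bit = load_and_quantize_model(model_8bit,
                                     bnb_quantization_config,
                                     weights_location=weights_location,
                                     device_map=device_map,  # 'cpu'/'disk' entries already express the offload intent
                                     no_split_module_classes=["BloomBlock"],
                                     offload_state_dict=True,
                                     offload_folder="tmp")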

Review comment on docs/source/usage_guides/quantization.md (outdated, resolved)
@SunMarc (Member, Author) commented Jul 11, 2023

> Looks great on my side! As a small comment, it might be worth making it clear to users on the relevant documentation page that the computation is still done on the GPU, to avoid any confusion. Also, to be on the safe side, can you run the transformers slow tests of the bnb integration and make sure they pass?

I added a section in the doc for offload, and the transformers slow tests pass (61 in total).

@younesbelkada (Contributor) left a comment

Very cool work! Thanks for confirming that the transformers tests pass.

@SunMarc merged commit 27d2908 into huggingface:main on Jul 11, 2023; 24 checks passed.
@SunMarc deleted the offload_8_bit branch on July 11, 2023 at 17:46.