Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use torch.inference_mode() for lower memory usage during calibration #20

Merged
merged 1 commit into from
Jun 17, 2024

Conversation

mgoin
Copy link
Member

@mgoin mgoin commented Jun 17, 2024

On an H100 80B the calibration of a Llama 3 8B with a ~8192 sequence length input would cause OOM issues. With the small addition of with torch.inference_mode(): to the calibration loop, I see only a peak usage of ~15GB.

Snippet used for testing:

from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
seq_len = 8192
examples = ["hello " * seq_len]
examples = tokenizer(examples, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head"],
)

model = AutoFP8ForCausalLM.from_pretrained(
    pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)

@mgoin mgoin linked an issue Jun 17, 2024 that may be closed by this pull request
@mgoin mgoin merged commit b1c6ad6 into main Jun 17, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Memory requirements for long sequences
1 participant