support bitsandbytes 8-bit and FP4 quantized models #7445

Merged on Aug 29, 2024 (8 commits). The diff below shows the changes from all commits.
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -205,8 +205,14 @@ class HfRunner:

     def wrap_device(self, input: _T) -> _T:
         if not is_cpu():
+            # Check if the input is already on the GPU
+            if hasattr(input, 'device') and input.device.type == "cuda":
+                return input  # Already on GPU, no need to move
             return input.to("cuda")
         else:
+            # Check if the input is already on the CPU
+            if hasattr(input, 'device') and input.device.type == "cpu":
+                return input  # Already on CPU, no need to move
             return input.to("cpu")
 
     def __init__(
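The early return added here matters for bitsandbytes models: transformers dispatches quantized weights to the GPU during loading, and calling .to() on such a model raises. A minimal illustration of the case the check handles (the model name is an assumption, not from this PR):

from transformers import AutoModelForCausalLM

# bitsandbytes places the quantized model on the GPU at load time.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",
                                             load_in_8bit=True)
print(model.device)  # e.g. cuda:0, already on the GPU
# model.to("cuda")   # raises for 4-bit/8-bit bitsandbytes models;
#                    # wrap_device now returns the model unchanged instead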
166 changes: 98 additions & 68 deletions tests/quantization/test_bitsandbytes.py
@@ -2,85 +2,115 @@

 Run `pytest tests/quantization/test_bitsandbytes.py`.
 '''
+
+import gc
 
 import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from vllm import SamplingParams
 
-models_to_test = [
+models_4bit_to_test = [
     ('huggyllama/llama-7b', 'quantize model inflight'),
-    ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
 ]
 
+models_pre_quant_4bit_to_test = [
+    ('lllyasviel/omost-llama-3-8b-4bits',
+     'read pre-quantized 4-bit NF4 model'),
+    ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
+     'read pre-quantized 4-bit FP4 model'),
+]
+
+models_pre_quant_8bit_to_test = [
+    ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
+]
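For background, pre-quantized checkpoints such as the two 4-bit entries above are typically produced with transformers' BitsAndBytesConfig, where bnb_4bit_quant_type selects NF4 or FP4. A hedged sketch (the base model and the exact settings used for those checkpoints are assumptions):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# "nf4" or "fp4" picks the 4-bit data type; after this PR vLLM reads both
# kinds of checkpoints through the same bitsandbytes path.
quant_config = BitsAndBytesConfig(load_in_4bit=True,
                                  bnb_4bit_quant_type="fp4")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b",
                                             quantization_config=quant_config)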


+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                             model_name, description) -> None:
+
+    hf_model_kwargs = {"load_in_4bit": True}
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name, hf_model_kwargs)
+
+
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description",
+                         models_pre_quant_4bit_to_test)
+def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                                       model_name, description) -> None:
+
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name)
+
+
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
-@pytest.mark.parametrize("model_name, description", models_to_test)
-def test_load_bnb_model(vllm_runner, model_name, description) -> None:
+@pytest.mark.parametrize("model_name, description",
+                         models_pre_quant_8bit_to_test)
+def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                             model_name, description) -> None:
+
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name)
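From the user's side, the new 8-bit path is exercised the same way as the 4-bit one. A minimal sketch, assuming the Llama-Guard checkpoint from the list above and the same engine arguments the vLLM runner passes in these tests:

from vllm import LLM, SamplingParams

# Pre-quantized 8-bit checkpoint; weights are loaded in bitsandbytes INT8 form.
llm = LLM(model="meta-llama/Llama-Guard-3-8B-INT8",
          quantization="bitsandbytes",
          load_format="bitsandbytes",
          enforce_eager=True)
outputs = llm.generate(["To be or not to be,"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)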


+def log_generated_texts(prompts, outputs, runner_name):
+    logged_texts = []
+    for i, (_, generated_text) in enumerate(outputs):
+        log_entry = {
+            "prompt": prompts[i],
+            "runner_name": runner_name,
+            "generated_text": generated_text,
+        }
+        logged_texts.append(log_entry)
+    return logged_texts
+
+
+def validate_generated_texts(hf_runner,
+                             vllm_runner,
+                             prompts,
+                             model_name,
+                             hf_model_kwargs=None):
+
+    if hf_model_kwargs is None:
+        hf_model_kwargs = {}
+
+    # Run with HF runner
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
+
+    # Clean up the GPU memory for the next test
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Run with vLLM runner
     with vllm_runner(model_name,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
-                     enforce_eager=True) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-
-        # check the weights in MLP & SelfAttention are quantized to torch.uint8
-        qweight = model.model.layers[0].mlp.gate_up_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].mlp.down_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].self_attn.o_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].self_attn.qkv_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        # some weights should not be quantized
-        weight = model.lm_head.weight
-        assert weight.dtype != torch.uint8, (
-            'lm_head weight dtype should not be torch.uint8')
-
-        weight = model.model.embed_tokens.weight
-        assert weight.dtype != torch.uint8, (
-            'embed_tokens weight dtype should not be torch.uint8')
-
-        weight = model.model.layers[0].input_layernorm.weight
-        assert weight.dtype != torch.uint8, (
-            'input_layernorm weight dtype should not be torch.uint8')
-
-        weight = model.model.layers[0].post_attention_layernorm.weight
-        assert weight.dtype != torch.uint8, (
-            'input_layernorm weight dtype should not be torch.uint8')
-
-        # check the output of the model is expected
-        sampling_params = SamplingParams(temperature=0.0,
-                                         logprobs=1,
-                                         prompt_logprobs=1,
-                                         max_tokens=8)
-
-        prompts = ['That which does not kill us', 'To be or not to be,']
-        expected_outputs = [
-            'That which does not kill us makes us stronger.',
-            'To be or not to be, that is the question.'
-        ]
-        outputs = llm.generate(prompts, sampling_params=sampling_params)
-        assert len(outputs) == len(prompts)
-
-        for index in range(len(outputs)):
-            # compare the first line of the output
-            actual_output = outputs[index][1][0].split('\n', 1)[0]
-            expected_output = expected_outputs[index].split('\n', 1)[0]
-
-            assert len(actual_output) >= len(expected_output), (
-                f'Actual {actual_output} should be larger than or equal to '
-                f'expected {expected_output}')
-            actual_output = actual_output[:len(expected_output)]
-
-            assert actual_output == expected_output, (
-                f'Expected: {expected_output}, but got: {actual_output}')
+                     enforce_eager=True,
Contributor:

bitsandbytes-foundation/bitsandbytes#1330 has been merged. Regarding these tests, can we now use cudagraph?

Contributor Author:

Hi @jeejeelee, thanks for your fix in bnb. However, the latest release of the bnb package went out three weeks ago and does not include your fix yet. I will update the code after the next bnb release is out.

Collaborator:

Agreed, I think it is worth doing the package upgrade in another PR.

Contributor:

It seems that this PR introduces a new BNB kernel. I'm not sure whether the previous modifications to BNB support this kernel; perhaps we should verify that first (by building BNB from source). If the kernel is still not supported, we may need to continue refining the relevant BNB code.

If you're not available, I can verify it next week.

Contributor Author (chenqianfzh, Aug 24, 2024):

@jeejeelee I ran the above tests (and some others) under graph mode with the new bnb kernel and your fix, and it worked perfectly! :-)

+                     gpu_memory_utilization=0.8) as llm:
+        vllm_outputs = llm.generate_greedy(prompts, 8)
+        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
+
+    # Clean up the GPU memory for the next test
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Compare the generated strings
+    for hf_log, vllm_log in zip(hf_logs, vllm_logs):
+        hf_str = hf_log["generated_text"]
+        vllm_str = vllm_log["generated_text"]
+        prompt = hf_log["prompt"]
+        assert hf_str == vllm_str, (f"Model: {model_name}\n"
+                                    f"Mismatch between HF and vLLM outputs:\n"
+                                    f"Prompt: {prompt}\n"
+                                    f"HF Output: '{hf_str}'\n"
+                                    f"vLLM Output: '{vllm_str}'")
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -353,6 +353,8 @@ def verify_with_parallel_config(
             raise ValueError(
                 "BitAndBytes quantization with TP or PP is not supported yet.")
 
+        # Remove the constraint after the bitsandbytes issue is fixed:
+        # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
         if self.quantization == "bitsandbytes" and self.enforce_eager is False:
             logger.warning("CUDA graph is not supported on BitAndBytes yet, "
                            "fallback to the eager mode.")
18 changes: 10 additions & 8 deletions vllm/model_executor/layers/linear.py
@@ -35,9 +35,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-def adjust_bitsandbytes_shard(param: Parameter,
-                              qkv_offsets: Dict[str, Tuple[int, int]],
-                              loaded_shard_id: str) -> Tuple[int, int]:
+def adjust_bitsandbytes_4bit_shard(param: Parameter,
+                                   qkv_offsets: Dict[str, Tuple[int, int]],
+                                   loaded_shard_id: str) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
     total, _ = qkv_offsets["total"]
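The body of the renamed helper is collapsed in this view. As a readability aid, here is a sketch of what it plausibly computes, reconstructed from the call sites below rather than quoted from the diff: each shard's nominal (offset, size) is rescaled by the ratio of the quantized parameter's stored size to the nominal total, since 4-bit packing shrinks the tensor.

def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   qkv_offsets: Dict[str, Tuple[int, int]],
                                   loaded_shard_id: str) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
    total, _ = qkv_offsets["total"]
    orig_offset, orig_size = qkv_offsets[loaded_shard_id]

    # 4-bit packing stores the tensor in fewer elements, so scale the nominal
    # offset/size by the ratio of the stored size to the nominal total.
    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset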
@@ -506,8 +506,9 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 
-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
+            if use_bitsandbytes_4bit:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * \
                     loaded_shard_id
@@ -859,8 +860,9 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 
-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
+            if use_bitsandbytes_4bit:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
                     "k": (self.num_heads * self.head_size,
@@ -872,7 +874,7 @@
                     ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                      0)
                 }
-                shard_size, shard_offset = adjust_bitsandbytes_shard(
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id)
 
         if is_gguf_weight:
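As a worked example of the offsets dict (the head counts and head size are assumed values, not from this PR): with num_heads=32, num_kv_heads=8, and head_size=128, the nominal (offset, size) pairs come out as follows.

num_heads, num_kv_heads, head_size = 32, 8, 128  # assumed example sizes

orig_qkv_offsets = {
    "q": (0, num_heads * head_size),                           # (0, 4096)
    "k": (num_heads * head_size, num_kv_heads * head_size),    # (4096, 1024)
    "v": ((num_heads + num_kv_heads) * head_size,
          num_kv_heads * head_size),                           # (5120, 1024)
    "total": ((num_heads + 2 * num_kv_heads) * head_size, 0),  # (6144, 0)
}
# adjust_bitsandbytes_4bit_shard maps these nominal offsets into the packed
# 4-bit tensor's coordinates for whichever shard is being loaded.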