[Feature][kernel] tensor parallelism with bitsandbytes quantization #8434

Merged · merged 5 commits · Sep 17, 2024
Changes from 3 commits
26 changes: 21 additions & 5 deletions tests/quantization/test_bitsandbytes.py
@@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                             model_name)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='Test requires at least 2 GPUs.')
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@fork_new_process_for_each_test
def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                model_name, description) -> None:

    hf_model_kwargs = {"load_in_4bit": True}
    validate_generated_texts(hf_runner,
                             vllm_runner,
                             example_prompts[:1],
                             model_name,
                             hf_model_kwargs,
                             vllm_tp_size=2)


def log_generated_texts(prompts, outputs, runner_name):
    logged_texts = []
    for i, (_, generated_text) in enumerate(outputs):
@@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner,
                             vllm_runner,
                             prompts,
                             model_name,
                             hf_model_kwargs=None):
                             hf_model_kwargs=None,
                             vllm_tp_size=1):

    # NOTE: run vLLM first, as it requires a clean process
    # when using distributed inference

    #Run with vLLM runner
    with vllm_runner(model_name,
                     quantization='bitsandbytes',
                     load_format='bitsandbytes',
                     tensor_parallel_size=vllm_tp_size,
                     enforce_eager=True,
                     gpu_memory_utilization=0.8) as llm:
        vllm_outputs = llm.generate_greedy(prompts, 8)
        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

    # Clean up the GPU memory for the next test
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()

@@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner,
    hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")

    # Clean up the GPU memory for the next test
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()

6 changes: 0 additions & 6 deletions vllm/config.py
@@ -418,12 +418,6 @@ def verify_with_parallel_config(
                "Pipeline parallelism is only supported for the following "
                f" architectures: {_PP_SUPPORTED_MODELS}.")

        if self.quantization == "bitsandbytes" and (
                parallel_config.tensor_parallel_size > 1
                or parallel_config.pipeline_parallel_size > 1):
            raise ValueError(
                "BitAndBytes quantization with TP or PP is not supported yet.")

        # Remove the constraint after the bitsandbytes issue is fixed:
        # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
        if self.quantization == "bitsandbytes" and self.enforce_eager is False:
21 changes: 16 additions & 5 deletions vllm/model_executor/layers/linear.py
@@ -530,8 +530,11 @@ def weight_loader(self,
                    param_data = param_data.narrow(output_dim, shard_offset,
                                                   shard_size)
                start_idx = tp_rank * shard_size
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                # bitsandbytes loads the weights of the specific portion
                # no need to narrow here
                if not use_bitsandbytes_4bit:
                    loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                         shard_size)
            # Special case for AQLM codebooks.
            elif is_metadata:
                # metadata indicates fixed size concatenated along dim 0
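As a side note on the change above: a minimal, self-contained sketch of the narrow-or-skip decision, using toy tensors (the `load_shard` helper and the shapes are illustrative only, not vLLM API):

```python
import torch

def load_shard(param_data: torch.Tensor,
               loaded_weight: torch.Tensor,
               output_dim: int,
               tp_rank: int,
               use_bitsandbytes_4bit: bool = False) -> torch.Tensor:
    """Toy version of the narrow-or-skip logic in weight_loader."""
    shard_size = param_data.shape[output_dim]
    if not use_bitsandbytes_4bit:
        # Ordinary checkpoints contain the full tensor, so each rank
        # slices out its own shard along the output dimension.
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
    # With bitsandbytes 4-bit, the model loader already produced a
    # per-rank slice, so the tensor is used as-is.
    assert loaded_weight.shape[output_dim] == shard_size
    return loaded_weight

# Full 8x4 weight, tp_size=2: rank 1 owns rows 4..7 (output_dim=0).
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
param = torch.empty(4, 4)
assert torch.equal(load_shard(param, full, output_dim=0, tp_rank=1), full[4:8])
```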
@@ -899,8 +902,13 @@ def weight_loader(self,
                else:
                    shard_id = tp_rank // self.num_kv_head_replicas
                start_idx = shard_id * shard_size
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

                # bitsandbytes loads the weights of the specific portion
                # no need to narrow here
                if not use_bitsandbytes_4bit:
                    loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                         shard_size)

            # Special case for AQLM codebooks.
            elif is_metadata:
                # metadata indicates fixed size concatenated along dim 0
@@ -1000,6 +1008,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1015,7 +1024,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None:
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
        if input_dim is not None and not use_bitsandbytes_4bit:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
39 changes: 38 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -22,6 +22,8 @@
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                         LoRAConfig, ModelConfig, MultiModalConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
@@ -911,13 +913,42 @@ def _parse_quant_state(param_name: str,
    def _unquantized_generator(self, hf_weights_files, use_safetensors,
                               quant_state_dict) -> Generator:
        from bitsandbytes.functional import quantize_4bit
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if any(target_module in weight_name
                   for target_module in self.target_modules):
                weight_name = weight_name.replace(".weight", ".qweight")

                # weight partitions of different modules occur at
                # different dimensions
                if 'down_proj' in weight_name or 'o_proj' in weight_name:
                    total_size = weight_tensor.size(-1)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[...,
                                                      start_index:end_index]

                else:
                    total_size = weight_tensor.size(0)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[start_index:end_index,
                                                      ...]

                # bitsandbytes requires data in GPU
                loaded_weight = weight_tensor.cuda().data
                if weight_sub_tensor.is_cuda:
                    loaded_weight = weight_sub_tensor
                else:
                    loaded_weight = weight_sub_tensor.cuda()

                # remove the following after the issue is fixed:
                # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
                if loaded_weight.is_contiguous() is False:
                    loaded_weight = loaded_weight.contiguous()

                with set_default_torch_dtype(torch.float32):
                    processed_weight, quant_state = quantize_4bit(
                        loaded_weight,
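A hedged, CPU-runnable restatement of the slicing rule introduced here: each rank takes an equal 1/tp_size slice, along the last (input) dimension for the row-parallel down_proj/o_proj weights and along dim 0 for everything else, and makes it contiguous before quantizing (helper name and toy shapes are illustrative only):

```python
import torch

def shard_for_rank(weight_name: str, weight_tensor: torch.Tensor,
                   tp_size: int, tp_rank: int) -> torch.Tensor:
    """Toy restatement of the per-rank slicing in _unquantized_generator."""
    if 'down_proj' in weight_name or 'o_proj' in weight_name:
        # Row-parallel layers are split along the input (last) dimension.
        total_size = weight_tensor.size(-1)
        start = total_size // tp_size * tp_rank
        end = total_size // tp_size * (tp_rank + 1)
        shard = weight_tensor[..., start:end]
    else:
        # Column-parallel layers are split along the output (first) dimension.
        total_size = weight_tensor.size(0)
        start = total_size // tp_size * tp_rank
        end = total_size // tp_size * (tp_rank + 1)
        shard = weight_tensor[start:end, ...]
    # Slicing the last dimension yields a non-contiguous view; quantize_4bit
    # wants contiguous CUDA data (bitsandbytes issue #1342), hence .contiguous().
    return shard.contiguous()

w = torch.randn(8, 6)
assert shard_for_rank("model.layers.0.mlp.down_proj.weight", w, 2, 1).shape == (8, 3)
assert shard_for_rank("model.layers.0.mlp.gate_proj.weight", w, 2, 1).shape == (4, 6)
```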
@@ -958,6 +989,12 @@ def _load_weights(self, model_config: ModelConfig,
                f"BitsAndBytes loader does not support {quant_method} "
                "quantization")

        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with TP is not supported.")
Member commented:

This seems to be a big issue still as most people would be using prequantized models afaik

@chenqianfzh (Contributor Author) commented on Sep 15, 2024:

In bnb, quantization turns a tensor into a quantized tensor plus a quant_state struct. bnb does not support de-quantizing a split quantized tensor with a full-size quant_state, which is why I cannot implement TP directly on pre-quant models using the weights and quant_states out of the box.

I have been considering three solutions:

  1. Request the bnb community to support the above scenario. I am not sure how long it might take. :-(
  2. Unpack the prequant models when loading and then treat them as non-quantized models. But model loading would be terribly slow, which defeats the purpose of pre-quantization.
  3. Suggest that users turn to PP instead, as most popular models in vLLM now support PP. I will update the error message accordingly.

@mgoin, what is your opinion?
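To make the limitation concrete, a small hedged illustration (requires a CUDA GPU with bitsandbytes installed; the shapes in the comments are approximate) of why a full-tensor quant_state cannot simply be paired with a per-rank slice of the packed weights:

```python
import torch
from bitsandbytes.functional import quantize_4bit

w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
packed, quant_state = quantize_4bit(w, quant_type="nf4")

# The packed payload is a flat uint8 buffer (two 4-bit values per byte),
# while quant_state describes the *full* tensor: its original shape plus
# one absmax scale per quantization block.
print(packed.shape, packed.dtype)   # roughly (4096 * 4096 // 2, 1), torch.uint8
print(quant_state.shape)            # torch.Size([4096, 4096])
print(quant_state.absmax.numel())   # number of blocks in the full tensor

# A tensor-parallel slice of `packed` has no matching quant_state, so this PR
# quantizes each rank's shard from the unquantized weights instead of
# splitting a prequantized checkpoint.
```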

Member replied:

I see, this is unfortunate. I think you have chosen fairly in this case, and the recommendation for PP in the error message (option 3) would be the best and simplest.

@chenqianfzh (Contributor Author) replied:

Yep.


        load_8bit = False
        if pre_quant:
            load_8bit = quant_config.get('load_in_8bit', False)