[distributed][kernel] support tensor-parallelism in bitsandbytes quantization
chenqianfzh committed Jun 25, 2024
1 parent ba991d5 commit e8d5453
Showing 4 changed files with 82 additions and 12 deletions.
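
For context, a minimal usage sketch of what this commit enables: loading a bitsandbytes-quantized model across two GPUs. This is not part of the diff; the model name and engine arguments mirror the new test added below, and other setups may differ.

```python
# Minimal sketch of the newly supported combination (bitsandbytes + tensor parallelism).
# Mirrors the test below; assumes a host with at least 2 CUDA GPUs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="huggyllama/llama-7b",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    tensor_parallel_size=2,   # previously rejected for bitsandbytes models
    enforce_eager=True,
)

outputs = llm.generate(
    ["That which does not kill us"],
    SamplingParams(temperature=0.0, max_tokens=8),
)
print(outputs[0].outputs[0].text)
```
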
33 changes: 33 additions & 0 deletions tests/quantization/test_bitsandbytes.py
@@ -74,3 +74,36 @@ def test_load_bnb_model(vllm_runner) -> None:
            expected_output = expected_outputs[index].split('\n', 1)[0]
            assert actual_output == expected_output, (
                f'Expected: {expected_output}, but got: {actual_output}')


@pytest.mark.skipif(
    not is_quant_method_supported("bitsandbytes")
    or torch.cuda.device_count() < 2,
    reason='This test requires bitsandbytes support and at least 2 GPUs.')
def test_tp_load_bnb_model(vllm_runner) -> None:
    with vllm_runner('huggyllama/llama-7b',
                     quantization='bitsandbytes',
                     load_format='bitsandbytes',
                     tensor_parallel_size=2,
                     enforce_eager=True) as llm:

        sampling_params = SamplingParams(temperature=0.0,
                                         logprobs=1,
                                         prompt_logprobs=1,
                                         max_tokens=8)

        prompts = ['That which does not kill us', 'To be or not to be,']
        expected_outputs = [
            'That which does not kill us makes us stronger.',
            'To be or not to be, that is the question.'
        ]
        outputs = llm.generate(prompts, sampling_params=sampling_params)

        assert len(outputs) == len(prompts)

        for index in range(len(outputs)):
            # compare the first line of the output
            actual_output = outputs[index][1][0].split('\n', 1)[0]
            expected_output = expected_outputs[index].split('\n', 1)[0]
            assert actual_output == expected_output, (
                f'Expected: {expected_output}, but got: {actual_output}')
6 changes: 0 additions & 6 deletions vllm/config.py
@@ -248,12 +248,6 @@ def verify_with_parallel_config(
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

        if self.quantization == "bitsandbytes" and (
                parallel_config.tensor_parallel_size > 1
                or parallel_config.pipeline_parallel_size > 1):
            raise ValueError(
                "BitAndBytes quantization with TP or PP is not supported yet.")

    def get_hf_config_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled.
        """
22 changes: 17 additions & 5 deletions vllm/model_executor/layers/linear.py
@@ -108,6 +108,7 @@ def apply(self,
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)

        return F.linear(x, weight, bias)


@@ -440,8 +441,12 @@ def weight_loader(self,
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size

            # bitsandbytes loads the weights of the specific portion
            # no need to narrow here
            if not use_bitsandbytes:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
@@ -652,8 +657,12 @@ def weight_loader(self,
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            # bitsandbytes loads the weights of the specific portion
            # no need to narrow here
            if not use_bitsandbytes:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
@@ -757,8 +766,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):

        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
        param_data = param.data
        if input_dim is not None:
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
        if input_dim is not None and use_bitsandbytes is False:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
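
To make the narrow-skipping above concrete, here is a small standalone illustration (not vLLM code): when a parameter is flagged with `use_bitsandbytes`, the bitsandbytes weights iterator has already cut out this rank's slice, so narrowing the full checkpoint tensor and receiving the pre-sliced tensor yield the same shard.

```python
# Standalone illustration (not from the diff): per-rank slice of a
# row-parallel weight, computed two ways.
import torch

tp_rank, tp_size = 1, 2
full_weight = torch.arange(16.).reshape(4, 4)   # stand-in for a checkpoint tensor
input_dim = 1                                    # row-parallel shards the input dim
shard_size = full_weight.shape[input_dim] // tp_size

# Default loader path: narrow the full tensor to this rank's portion.
narrowed = full_weight.narrow(input_dim, tp_rank * shard_size, shard_size)

# bitsandbytes path: the weights iterator already handed over this rank's slice,
# so the weight_loader copies it as-is and narrow() is skipped.
pre_sliced = full_weight[:, tp_rank * shard_size:(tp_rank + 1) * shard_size]

assert torch.equal(narrowed, pre_sliced)
```
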
33 changes: 32 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -700,9 +700,15 @@ def _get_quantized_weights_iterator(
                              "`pip install bitsandbytes>=0.42.0` to use "
                              "bitsandbytes quantizer.") from err

        from vllm.distributed import (get_tensor_model_parallel_rank,
                                      get_tensor_model_parallel_world_size)

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        quant_state_dict = {}
        if use_safetensors:
            weight_iterator = safetensors_weights_iterator(hf_weights_files)
@@ -711,11 +717,36 @@

        def generator():
            for weight_name, weight_tensor in weight_iterator:

                if any(target_module in weight_name
                       for target_module in self.target_modules):
                    weight_name = weight_name.replace(".weight", ".qweight")

                    # weight partitions of different modules occur at
                    # different dimensions
                    if 'down_proj' in weight_name or 'o_proj' in weight_name:
                        total_size = weight_tensor.size(-1)
                        start_index = total_size // tp_size * tp_rank
                        end_index = total_size // tp_size * (tp_rank + 1)
                        weight_sub_tensor = weight_tensor[
                            ..., start_index:end_index]

                    else:
                        total_size = weight_tensor.size(0)
                        start_index = total_size // tp_size * tp_rank
                        end_index = total_size // tp_size * (tp_rank + 1)
                        weight_sub_tensor = weight_tensor[
                            start_index:end_index, ...]

                    # bitsandbytes requires data in GPU
                    loaded_weight = weight_tensor.cuda().data
                    if weight_sub_tensor.is_cuda:
                        loaded_weight = weight_sub_tensor
                    else:
                        loaded_weight = weight_sub_tensor.cuda()
                    # bitsandbytes requires a contiguous tensor
                    if loaded_weight.is_contiguous() is False:
                        loaded_weight = loaded_weight.contiguous()

                    with set_default_torch_dtype(torch.float32):
                        processed_weight, quant_state = quantize_4bit(
                            loaded_weight,
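
As a sanity check of the slicing arithmetic in the generator above, a standalone sketch (illustrative, not the vLLM code itself): `down_proj` and `o_proj` are row-parallel, so each rank takes a contiguous block of the last (input) dimension, while every other target module is split along dim 0. The helper name below is hypothetical, and the GPU move is omitted so the example runs on CPU.

```python
# Illustrative sketch of the per-rank slicing used by the quantized-weights
# iterator; shard_for_rank is a hypothetical helper, not a vLLM function.
import torch


def shard_for_rank(weight_name: str, weight_tensor: torch.Tensor,
                   tp_rank: int, tp_size: int) -> torch.Tensor:
    if 'down_proj' in weight_name or 'o_proj' in weight_name:
        # Row-parallel modules: split along the last (input) dimension.
        total_size = weight_tensor.size(-1)
        start = total_size // tp_size * tp_rank
        end = total_size // tp_size * (tp_rank + 1)
        shard = weight_tensor[..., start:end]
    else:
        # Column-parallel modules: split along dim 0 (the output dimension).
        total_size = weight_tensor.size(0)
        start = total_size // tp_size * tp_rank
        end = total_size // tp_size * (tp_rank + 1)
        shard = weight_tensor[start:end, ...]
    # Slicing a trailing dimension yields a non-contiguous view, and
    # bitsandbytes' quantize_4bit needs a contiguous tensor.
    return shard.contiguous()


w = torch.randn(8, 4)
print(shard_for_rank('model.layers.0.mlp.down_proj.weight', w, 1, 2).shape)  # torch.Size([8, 2])
print(shard_for_rank('model.layers.0.mlp.gate_proj.weight', w, 1, 2).shape)  # torch.Size([4, 4])
```
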
