[Misc] Refactor linear layer weight loading; introduce `BasevLLMParameter` and `weight_loader_v2` (vllm-project#5874)
dsikka authored and fialhocoelho committed Aug 22, 2024
1 parent df3bd12 commit 9684235
Showing 11 changed files with 655 additions and 203 deletions.
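
The diffs below all follow one pattern: loading metadata that used to be attached to bare `torch.nn.Parameter` objects via `set_weight_attrs` now lives on `BasevLLMParameter` subclasses that know how to load themselves, and each linear layer gains a `weight_loader_v2` that delegates the actual copying to the parameter. The sketch below illustrates that pattern only; the class, method, and constant names mirror the diff, but the bodies are simplified stand-ins (tensor-parallel size 1, no sharding), not the actual vLLM implementation.

```python
import torch
from torch.nn import Parameter


class BasevLLMParameter(Parameter):
    """Simplified stand-in: a parameter that knows how to load itself."""

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader=None):
        self.weight_loader = weight_loader

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        # The real code slices along the output dim for this TP rank; with
        # TP size 1 the shapes match and a straight copy suffices.
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


# Layer-side dispatch, as in linear.py below: quant methods that have been
# migrated opt in by class name; everything else keeps the legacy loader.
WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]


class ToyColumnParallelLinear:  # hypothetical layer, for illustration only
    def __init__(self, out_features: int, in_features: int, quant_method):
        loader = (self.weight_loader_v2 if quant_method.__class__.__name__
                  in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)
        self.weight = BasevLLMParameter(data=torch.empty(out_features,
                                                         in_features),
                                        weight_loader=loader)

    def weight_loader(self, param, loaded_weight):  # legacy v1 path
        param.data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        # The layer no longer inspects attributes bolted onto the parameter;
        # it delegates to the parameter's own loading logic.
        param.load_column_parallel_weight(loaded_weight=loaded_weight)
```

In this commit only `CompressedTensorsLinearMethod` is listed in `WEIGHT_LOADER_V2_SUPPORTED`, so every other quantization method continues through the existing `weight_loader` path unchanged.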
20 changes: 12 additions & 8 deletions tests/quantization/test_compressed_tensors.py
@@ -9,7 +9,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType)

@@ -109,7 +109,7 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):

assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.weight_packed.pack_factor == pack_factor
assert qkv_proj.scheme.pack_factor == pack_factor

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@@ -140,13 +140,17 @@ def test_compressed_tensors_fp8(vllm_runner):
qkv_proj = layer.self_attn.qkv_proj

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert isinstance(
qkv_proj.scheme,
(CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))

assert qkv_proj.input_scale.dtype is torch.float32
assert qkv_proj.weight_scale.dtype is torch.float32
# should be scalars after processing
assert len(qkv_proj.input_scale.shape) == 0
assert len(qkv_proj.weight_scale.shape) == 0

if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight_scale.shape) == 0

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
4 changes: 4 additions & 0 deletions vllm/model_executor/__init__.py
@@ -1,7 +1,11 @@
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed

__all__ = [
"SamplingMetadata",
"set_random_seed",
"BasevLLMParameter",
"PackedvLLMParameter",
]
153 changes: 151 additions & 2 deletions vllm/model_executor/layers/linear.py
@@ -13,10 +13,14 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)

WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]


def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -288,14 +292,17 @@ def __init__(self,

if output_sizes is None:
output_sizes = [output_size]

self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
if bias:
self.bias = Parameter(
@@ -337,6 +344,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
param.load_column_parallel_weight(loaded_weight=loaded_weight)

def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None

@@ -527,6 +537,62 @@ def weight_loader(self,
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where MLP layers are already
fused on disk. In this case, we have no shard id. This function
determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""

current_shard_offset = 0
shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size

for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)

loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)

def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = param.data
if loaded_shard_id is None:
if param.output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return

assert loaded_shard_id < len(self.output_sizes)

tp_size = get_tensor_model_parallel_world_size()
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size

param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)


class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
@@ -598,6 +664,82 @@ def __init__(self,
quant_config=quant_config,
prefix=prefix)

def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
}
return shard_offset_mapping.get(loaded_shard_id)

def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
}
return shard_size_mapping.get(loaded_shard_id)

def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where QKV layers are already
fused on disk. In this case, we have no shard id. This function
determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size),
]

for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)

loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)

def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = param.data
if loaded_shard_id is None: # special case for certain models
if param.output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return

assert loaded_shard_id in ["q", "k", "v"]

shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)

param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)

def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
@@ -798,14 +940,17 @@ def __init__(self,
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None

self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
@@ -850,6 +995,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
param.load_row_parallel_weight(loaded_weight=loaded_weight)

def forward(self, input_):
if self.input_is_parallel:
input_parallel = input_
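
For checkpoints that store the QKV (or gate/up) projections as one fused tensor, the new `_load_fused_module_from_checkpoint` helpers above recover per-shard slices by offset and hand each slice back to `weight_loader_v2` together with its shard id. Below is a small runnable sketch of that offset arithmetic; the head counts and hidden size are made-up example values, not taken from the commit.

```python
import torch

# Hypothetical sizes for illustration only.
total_num_heads, total_num_kv_heads, head_size = 32, 8, 128
hidden_size = 4096

# Offsets mirror QKVParallelLinear._load_fused_module_from_checkpoint:
shard_offsets = [
    # (shard_id, shard_offset, shard_size)
    ("q", 0, total_num_heads * head_size),
    ("k", total_num_heads * head_size, total_num_kv_heads * head_size),
    ("v", (total_num_heads + total_num_kv_heads) * head_size,
     total_num_kv_heads * head_size),
]

# A fused QKV weight as stored on disk (Phi-3-style checkpoints).
fused_qkv = torch.randn(
    (total_num_heads + 2 * total_num_kv_heads) * head_size, hidden_size)

# Slicing along the output dim recovers the per-shard weights that the real
# code then passes to weight_loader_v2(param, shard, shard_id).
for shard_id, shard_offset, shard_size in shard_offsets:
    shard = fused_qkv.narrow(0, shard_offset, shard_size)
    print(shard_id, tuple(shard.shape))
# q (4096, 4096), k (1024, 4096), v (1024, 4096)
```

For packed parameters (e.g. int32-packed low-bit weights), the diff additionally adjusts these offsets and sizes for the pack factor via `adjust_shard_indexes_for_packing` before slicing.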
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -19,6 +19,8 @@
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform

__all__ = ["CompressedTensorsLinearMethod"]


class CompressedTensorsConfig(QuantizationConfig):

@@ -146,18 +148,15 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
if weight_quant is None or input_quant is None:
return False

# Confirm we have floating points.
if not (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT):
return False

# Confirm weight scheme is supported.
is_floating_point = (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT)
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight
if not (is_floating_point and is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight):
return False

@@ -169,11 +168,7 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR)
if not (is_symmetric_activation and is_per_tensor_activation):
return False

# All conditions satisfied.
return True
return is_symmetric_activation and is_per_tensor_activation

def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -230,14 +225,16 @@ def _get_scheme_from_parts(
group_size=weight_quant.group_size)

# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
else:
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
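
The reworked `_get_scheme_from_parts` logic above also spells out how the FP8 path degrades: when the running GPU does not meet the minimum capability for `CompressedTensorsW8A8Fp8`, the config falls back to `CompressedTensorsW8A16Fp8` (FP8 weights, 16-bit activations), which is why the test at the top of this commit now accepts either scheme. A condensed sketch of that decision follows; the capability threshold of 89 is an assumed example value, not necessarily what `get_min_capability()` returns.

```python
def select_fp8_scheme(weight_strategy: str, input_dynamic: bool,
                      gpu_capability: int) -> tuple:
    """Return the scheme name and whether the input scale is static."""
    is_static_input_scheme = not input_dynamic
    # The real check calls _check_scheme_supported(
    #     CompressedTensorsW8A8Fp8.get_min_capability(), error=False).
    if gpu_capability >= 89:  # assumed threshold for native FP8 support
        return "CompressedTensorsW8A8Fp8", is_static_input_scheme
    # Fallback: keep FP8 weights but run activations in 16-bit.
    return "CompressedTensorsW8A16Fp8", is_static_input_scheme


# Example: an A100 (compute capability 80) would take the W8A16 fallback.
print(select_fp8_scheme("tensor", input_dynamic=False, gpu_capability=80))
```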
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
@@ -2,11 +2,10 @@

import torch
import torch.nn.functional as F
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import ModelWeightParameter

__all__ = ["CompressedTensorsUnquantized"]

@@ -24,22 +23,25 @@ def get_min_capability(cls) -> int:
return 70

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
# required by torch.compile to be torch.nn.Parameter
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"weight_loader": weight_loader})

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
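
The unquantized scheme shows the per-scheme side of the migration: the metadata that `set_weight_attrs` used to attach after construction now travels in the parameter's constructor, so `weight_loader_v2` can rely on it being present. A condensed before/after sketch, with shapes and a no-op loader made up for illustration:

```python
import torch
from torch.nn import Parameter

output_partition_sizes, input_size_per_partition = [4096, 4096], 4096
params_dtype = torch.float16


def noop_loader(param, loaded_weight):  # stand-in for the layer's loader
    ...


# Before: a bare Parameter with loading metadata bolted on afterwards
# (this is what set_weight_attrs did).
old_weight = Parameter(torch.empty(sum(output_partition_sizes),
                                   input_size_per_partition,
                                   dtype=params_dtype),
                       requires_grad=False)
old_weight.input_dim, old_weight.output_dim = 1, 0
old_weight.weight_loader = noop_loader

# After: the same information is passed to the parameter itself, mirroring
# the ModelWeightParameter call in the diff keyword for keyword.
new_weight_kwargs = dict(data=torch.empty(sum(output_partition_sizes),
                                          input_size_per_partition,
                                          dtype=params_dtype),
                         input_dim=1,
                         output_dim=0,
                         weight_loader=noop_loader)
# ModelWeightParameter(**new_weight_kwargs) is then registered on the layer
# exactly as before, via layer.register_parameter("weight", weight).
```

The added `process_weights_after_loading` step then re-wraps the loaded data as a plain `torch.nn.Parameter`, which the diff notes is required by `torch.compile`.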