[Misc] Refactor linear layer weight loading; introduce BasevLLMParameter and weight_loader_v2 #5874

Merged · 22 commits · Aug 7, 2024
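For orientation, here is a minimal, hedged sketch of the pattern this PR introduces: parameters become objects that know how to load themselves, and the linear layers pass `create_weights` either the old `weight_loader` or the new `weight_loader_v2` depending on membership in `WEIGHT_LOADER_V2_SUPPORTED`. The class and loader names mirror the diff below; the method bodies and the `select_weight_loader` helper are simplified assumptions, not vLLM's implementation.

```python
# Illustrative sketch only: a simplified version of the pattern in the diff
# below, not the actual vLLM implementation.
import torch
from torch.nn import Parameter

WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]


class BasevLLMParameter(Parameter):
    """Parameter that carries its own weight-loading logic."""

    def __new__(cls, data: torch.Tensor, weight_loader, **kwargs):
        return super().__new__(cls, data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader, **kwargs):
        self.weight_loader = weight_loader

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        # Simplified: the real loader first narrows along the output dim
        # for the current tensor-parallel rank.
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


def select_weight_loader(layer):
    # Hypothetical helper mirroring the dispatch added to the linear layers:
    # only migrated quant methods receive weight_loader_v2.
    if type(layer.quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED:
        return layer.weight_loader_v2
    return layer.weight_loader
```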
20 changes: 12 additions & 8 deletions tests/quantization/test_compressed_tensors.py
@@ -9,7 +9,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType)

@@ -109,7 +109,7 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):

assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.weight_packed.pack_factor == pack_factor
assert qkv_proj.scheme.pack_factor == pack_factor

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@@ -140,13 +140,17 @@ def test_compressed_tensors_fp8(vllm_runner):
qkv_proj = layer.self_attn.qkv_proj

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert isinstance(
qkv_proj.scheme,
(CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))

assert qkv_proj.input_scale.dtype is torch.float32
assert qkv_proj.weight_scale.dtype is torch.float32
# should be scalars after processing
assert len(qkv_proj.input_scale.shape) == 0
assert len(qkv_proj.weight_scale.shape) == 0

if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight_scale.shape) == 0

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
4 changes: 4 additions & 0 deletions vllm/model_executor/__init__.py
@@ -1,7 +1,11 @@
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed

__all__ = [
"SamplingMetadata",
"set_random_seed",
"BasevLLMParameter",
"PackedvLLMParameter",
]
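A quick usage note: after the re-export above, the new parameter classes can be pulled straight from the package root.

```python
# Importable from the package root thanks to the re-export above.
from vllm.model_executor import BasevLLMParameter, PackedvLLMParameter
```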
153 changes: 151 additions & 2 deletions vllm/model_executor/layers/linear.py
@@ -13,10 +13,14 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)

WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]


def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -288,14 +292,17 @@ def __init__(self,

if output_sizes is None:
output_sizes = [output_size]

self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
if bias:
self.bias = Parameter(
@@ -326,6 +333,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
param.load_column_parallel_weight(loaded_weight=loaded_weight)

def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None

@@ -488,6 +498,62 @@ def weight_loader(self,
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where MLP layers are already
fused on disk. In this case, we have no shard id. This function
determines the shard id by splitting these layers and then calls
the weight loader using the shard id.

An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""

current_shard_offset = 0
shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size

for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)

loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)

def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = param.data
if loaded_shard_id is None:
if param.output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return

assert loaded_shard_id < len(self.output_sizes)

tp_size = get_tensor_model_parallel_world_size()
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size

param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
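
To make the fused-checkpoint path concrete, here is a hedged walk-through with made-up sizes (roughly a Llama-style fused gate/up projection). The offset bookkeeping and `narrow` call mirror `_load_fused_module_from_checkpoint` above; the tensor shapes and the trailing dispatch comment are illustrative assumptions.

```python
# Hypothetical walk-through of splitting a fused MLP weight with no shard id.
import torch

output_sizes = [11008, 11008]                         # assumed gate_proj / up_proj
loaded_weight = torch.empty(sum(output_sizes), 4096)  # fused on disk: [22016, 4096]

current_shard_offset = 0
shard_offsets = []
for shard_id, output_size in enumerate(output_sizes):
    shard_offsets.append((shard_id, current_shard_offset, output_size))
    current_shard_offset += output_size
# shard_offsets == [(0, 0, 11008), (1, 11008, 11008)]

output_dim = 0
for shard_id, shard_offset, shard_size in shard_offsets:
    shard = loaded_weight.narrow(output_dim, shard_offset, shard_size)
    # weight_loader_v2(param, shard, shard_id) is then called per shard,
    # which slices again for the current tensor-parallel rank.
    assert shard.shape == (shard_size, 4096)
```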


class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
@@ -559,6 +625,82 @@ def __init__(self,
quant_config=quant_config,
prefix=prefix)

def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
}
return shard_offset_mapping.get(loaded_shard_id)

def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
}
return shard_size_mapping.get(loaded_shard_id)

def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where QKV layers are already
fused on disk. In this case, we have no shard id. This function
determines the shard id by splitting these layers and then calls
the weight loader using the shard id.

An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size),
]

for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)

loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)

def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = param.data
if loaded_shard_id is None: # special case for certain models
if param.output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return

assert loaded_shard_id in ["q", "k", "v"]

shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)

param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
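
A worked example of the q/k/v offset and size mappings used by `weight_loader_v2` above, using assumed per-rank head counts (not values taken from this PR):

```python
# Assumed per-rank head configuration, for illustration only.
num_heads, num_kv_heads, head_size = 32, 8, 128

shard_offset = {
    "q": 0,                                               # 0
    "k": num_heads * head_size,                           # 4096
    "v": (num_heads + num_kv_heads) * head_size,          # 5120
    "total": (num_heads + 2 * num_kv_heads) * head_size,  # 6144
}
shard_size = {
    "q": num_heads * head_size,     # 4096
    "k": num_kv_heads * head_size,  # 1024
    "v": num_kv_heads * head_size,  # 1024
}
# load_qkv_weight then copies the checkpoint slice for a given shard id into
# param.data[offset : offset + size] along the output dimension.
```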

def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
@@ -729,14 +871,17 @@ def __init__(self,
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None

self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
@@ -770,6 +915,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
param.load_row_parallel_weight(loaded_weight=loaded_weight)
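
`load_row_parallel_weight` itself lives in `vllm/model_executor/parameter.py`, which is not part of this diff view; the sketch below shows the slicing a row-parallel load is expected to perform (shard along the input dimension for the current TP rank), as an assumption based on standard row parallelism rather than this PR's code.

```python
# Assumed behaviour of a row-parallel load: shard along the input dimension.
import torch

tp_size, tp_rank = 4, 1                 # hypothetical parallel config
full_weight = torch.randn(4096, 11008)  # [output_size, input_size] on disk

input_dim = 1
shard_size = full_weight.shape[input_dim] // tp_size  # 2752
start_idx = tp_rank * shard_size
local_shard = full_weight.narrow(input_dim, start_idx, shard_size)
assert local_shard.shape == (4096, 2752)
# param.data.copy_(local_shard) would then populate this rank's slice.
```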

def forward(self, input_):
if self.input_is_parallel:
input_parallel = input_
@@ -19,6 +19,8 @@
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform

__all__ = ["CompressedTensorsLinearMethod"]


class CompressedTensorsConfig(QuantizationConfig):

@@ -146,18 +148,15 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
if weight_quant is None or input_quant is None:
return False

# Confirm we have floating points.
if not (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT):
return False

# Confirm weight scheme is supported.
is_floating_point = (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT)
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight
if not (is_floating_point and is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight):
return False

@@ -169,11 +168,7 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel,
is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR)
if not (is_symmetric_activation and is_per_tensor_activation):
return False

# All conditions satisfied.
return True
return is_symmetric_activation and is_per_tensor_activation
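
The refactor above folds the fp8 W8A8 checks into single boolean expressions. Below is a self-contained sketch of the same predicate over a simplified stand-in config; the dataclass and string strategy values are placeholders for the real compressed-tensors `BaseModel`/`QuantizationStrategy` types, and the early `return True` for dynamic activation quantization is an assumption about the context lines collapsed between the two hunks.

```python
# Simplified stand-in for the quantization-argument objects used above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantArgs:
    type: str         # "float" or "int"
    symmetric: bool
    dynamic: bool
    strategy: str     # "tensor", "channel", "group", ...


def is_fp8_w8a8(weight_quant: Optional[QuantArgs],
                input_quant: Optional[QuantArgs]) -> bool:
    if weight_quant is None or input_quant is None:
        return False

    is_floating_point = (weight_quant.type == "float"
                         and input_quant.type == "float")
    is_symmetric_weight = weight_quant.symmetric
    is_static_weight = not weight_quant.dynamic
    is_per_tensor_or_channel_weight = weight_quant.strategy in ("tensor",
                                                                "channel")
    if not (is_floating_point and is_symmetric_weight and is_static_weight
            and is_per_tensor_or_channel_weight):
        return False

    # Assumption: dynamic activation quantization is accepted without further
    # checks (this branch corresponds to lines collapsed in the diff view).
    if input_quant.dynamic:
        return True

    is_symmetric_activation = input_quant.symmetric
    is_per_tensor_activation = input_quant.strategy == "tensor"
    return is_symmetric_activation and is_per_tensor_activation
```

For example, `is_fp8_w8a8(QuantArgs("float", True, False, "tensor"), QuantArgs("float", True, False, "tensor"))` evaluates to `True` under this sketch.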

def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -230,14 +225,16 @@ def _get_scheme_from_parts(
group_size=weight_quant.group_size)

# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
[Review comment from Collaborator]: Is this still a TODO?
[Reply from Contributor Author]: Yes. General follow-up on the state of these conditions.

if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
is_static_input_scheme=(input_quant
and not input_quant.dynamic))

[Review comment from Collaborator]: Why is this changing? Won't we always have input_quant if is_activation_quantization_format?
[Reply from Contributor Author]: Just an extra check that the activation config details aren't None and were parsed correctly.

else:
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
@@ -2,11 +2,10 @@

import torch
import torch.nn.functional as F
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import ModelWeightParameter

__all__ = ["CompressedTensorsUnquantized"]

@@ -24,22 +23,25 @@ def get_min_capability(cls) -> int:
return 70

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
# required by torch.compile to be torch.nn.Parameter
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"weight_loader": weight_loader})

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: