add list for v2 supported weight loading
dsikka committed Jul 23, 2024
1 parent 4a59faa commit 5cbd8b6
Showing 1 changed file with 8 additions and 11 deletions.
vllm/model_executor/layers/linear.py (8 additions, 11 deletions)

@@ -19,6 +19,8 @@
 
 logger = init_logger(__name__)
 
+WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]
+
 
 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -287,19 +289,16 @@ def __init__(self,
         if output_sizes is None:
             output_sizes = [output_size]
 
-        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501; TODO: temporary fix; have to fix circular import
-            CompressedTensorsLinearMethod)
-
         self.quant_method.create_weights(
             layer=self,
             input_size_per_partition=self.input_size,
             output_partition_sizes=self.output_partition_sizes,
             input_size=self.input_size,
             output_size=self.output_size,
             params_dtype=self.params_dtype,
-            weight_loader=(self.weight_loader_v2 if isinstance(
-                self.quant_method, CompressedTensorsLinearMethod) else
-                           self.weight_loader),
+            weight_loader=(
+                self.weight_loader_v2 if self.quant_method.__class__.__name__
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
             prefix=prefix)
         if bias:
             self.bias = Parameter(
@@ -868,8 +867,6 @@ def __init__(self,
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
-        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501; TODO: fix cir import
-            CompressedTensorsLinearMethod)
 
         self.quant_method.create_weights(
             layer=self,
@@ -878,9 +875,9 @@
             input_size=self.input_size,
             output_size=self.output_size,
             params_dtype=self.params_dtype,
-            weight_loader=(self.weight_loader_v2 if isinstance(
-                self.quant_method, CompressedTensorsLinearMethod) else
-                           self.weight_loader),
+            weight_loader=(
+                self.weight_loader_v2 if self.quant_method.__class__.__name__
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
             prefix=prefix)
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
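The substance of the change: both `__init__` sites previously imported `CompressedTensorsLinearMethod` locally (the removed `TODO` comments flag this as a circular-import workaround) and picked the loader with `isinstance`. The commit replaces that with a comparison of the quant method's class name against the module-level `WEIGHT_LOADER_V2_SUPPORTED` list. Below is a minimal, self-contained sketch of that pattern; apart from `WEIGHT_LOADER_V2_SUPPORTED`, `weight_loader`, and `weight_loader_v2`, the names are hypothetical stand-ins rather than vLLM's real classes.

# Sketch of the name-based dispatch this commit introduces. Only
# WEIGHT_LOADER_V2_SUPPORTED and the weight_loader / weight_loader_v2
# names come from the diff; the classes here are hypothetical stand-ins.

WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]


class CompressedTensorsLinearMethod:
    """Stand-in for the one quant method on the v2 loading path."""


class UnquantizedLinearMethod:
    """Stand-in for any quant method still on the v1 path."""


class LinearLayer:
    def __init__(self, quant_method):
        self.quant_method = quant_method
        # Matching on the class *name* (a plain string) means this module
        # never has to import CompressedTensorsLinearMethod, so the circular
        # import the removed TODO comments complained about disappears.
        self.weight_loader_fn = (
            self.weight_loader_v2
            if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
            else self.weight_loader)

    def weight_loader(self, param, loaded_weight):
        print("v1 loader")

    def weight_loader_v2(self, param, loaded_weight):
        print("v2 loader")


LinearLayer(CompressedTensorsLinearMethod()).weight_loader_fn(None, None)  # -> v2 loader
LinearLayer(UnquantizedLinearMethod()).weight_loader_fn(None, None)        # -> v1 loader

The trade-off versus `isinstance` is that a string match ignores subclasses, and every quant method that gains v2 support must be appended to the list by hand; the payoff is that both local imports, and their circular-import TODOs, can simply be deleted.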
