diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 3ff82c9b256bb..58ef5dd97c149 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -19,6 +19,8 @@
 
 logger = init_logger(__name__)
 
+WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]
+
 
 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -287,9 +289,6 @@ def __init__(self,
         if output_sizes is None:
             output_sizes = [output_size]
 
-        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501; TODO: temporary fix; have to fix circular import
-            CompressedTensorsLinearMethod)
-
         self.quant_method.create_weights(
             layer=self,
             input_size_per_partition=self.input_size,
@@ -297,9 +296,9 @@ def __init__(self,
             input_size=self.input_size,
             output_size=self.output_size,
             params_dtype=self.params_dtype,
-            weight_loader=(self.weight_loader_v2 if isinstance(
-                self.quant_method, CompressedTensorsLinearMethod) else
-                self.weight_loader),
+            weight_loader=(
+                self.weight_loader_v2 if self.quant_method.__class__.__name__
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
             prefix=prefix)
         if bias:
             self.bias = Parameter(
@@ -868,8 +867,6 @@ def __init__(self,
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
-        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501; TODO: fix cir import
-            CompressedTensorsLinearMethod)
 
         self.quant_method.create_weights(
             layer=self,
@@ -878,9 +875,9 @@ def __init__(self,
             input_size=self.input_size,
             output_size=self.output_size,
             params_dtype=self.params_dtype,
-            weight_loader=(self.weight_loader_v2 if isinstance(
-                self.quant_method, CompressedTensorsLinearMethod) else
-                self.weight_loader),
+            weight_loader=(
+                self.weight_loader_v2 if self.quant_method.__class__.__name__
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
             prefix=prefix)
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
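
The diff replaces an inline `isinstance()` check, which forced `linear.py` to import `CompressedTensorsLinearMethod` and created a circular import, with a comparison of `self.quant_method.__class__.__name__` against the module-level `WEIGHT_LOADER_V2_SUPPORTED` list. Below is a minimal, self-contained sketch of that dispatch pattern; the stub classes and the `pick_weight_loader` helper are hypothetical stand-ins for illustration, not vLLM APIs.

```python
# Sketch of the pattern introduced by this diff: select the weight loader by
# class *name* rather than isinstance(), so this module never has to import
# CompressedTensorsLinearMethod and the circular import goes away.
# The stub classes and helper below are hypothetical stand-ins.

WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]


class UnquantizedLinearMethod:
    """Stand-in for a quant method that only supports the v1 loader."""


class CompressedTensorsLinearMethod:
    """Stand-in for the compressed-tensors quant method."""


def weight_loader(param, loaded_weight):
    """v1 loader (placeholder body)."""


def weight_loader_v2(param, loaded_weight):
    """v2 loader (placeholder body)."""


def pick_weight_loader(quant_method):
    # Comparing the class name as a string needs no import of the concrete
    # class, which is the whole point; the trade-off is that renaming the
    # class silently breaks the check, so WEIGHT_LOADER_V2_SUPPORTED must
    # track class names exactly.
    if quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
        return weight_loader_v2
    return weight_loader


assert pick_weight_loader(CompressedTensorsLinearMethod()) is weight_loader_v2
assert pick_weight_loader(UnquantizedLinearMethod()) is weight_loader
```

A side benefit of the list is extensibility: other quant methods can opt into the v2 loader by appending their class name to `WEIGHT_LOADER_V2_SUPPORTED`, with no new imports in `linear.py`.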