Skip to content

Commit

Permalink
[Misc] Update compressed tensors lifecycle to remove prefix from `c…
Browse files Browse the repository at this point in the history
…reate_weights` (vllm-project#7825)
  • Loading branch information
dsikka authored and triple-Mu committed Sep 4, 2024
1 parent 0a2e320 commit 92c0d5c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 75 deletions.
9 changes: 3 additions & 6 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ def __init__(self,
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader,
prefix=prefix)
weight_loader=self.weight_loader)

if bias:
self.bias = Parameter(
Expand Down Expand Up @@ -307,8 +306,7 @@ def __init__(self,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
Expand Down Expand Up @@ -976,8 +974,7 @@ def __init__(self,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import torch
from pydantic import BaseModel

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsUnquantized,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_matched_target, is_activation_quantization_format,
Expand Down Expand Up @@ -52,15 +52,20 @@ def get_min_capability(cls) -> int:
def get_name(self) -> str:
return "compressed_tensors"

# TODO (@robertgshaw2-neuralmagic): do layer skipping though here
# rather than though create_weights to match other methods
def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import

# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix, ignore=self.ignore):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
Expand Down Expand Up @@ -281,15 +286,11 @@ def get_scheme(
to select the CompressedTensorsScheme used for infernece.
"""

# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(layer_name, ignore=self.ignore):
return CompressedTensorsUnquantized()

# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
Expand Down Expand Up @@ -327,10 +328,7 @@ def create_weights(self, layer: torch.nn.Module,
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
layer_name = extra_weight_attrs.get("prefix")

scheme = self.quantization_config.get_scheme(layer, layer_name)
scheme.create_weights(
layer.scheme.create_weights(
layer=layer,
input_size=input_size,
input_size_per_partition=input_size_per_partition,
Expand All @@ -339,8 +337,6 @@ def create_weights(self, layer: torch.nn.Module,
params_dtype=params_dtype,
weight_loader=weight_loader)

layer.scheme = scheme

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_unquantized import CompressedTensorsUnquantized
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
Expand All @@ -10,7 +9,6 @@

__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsUnquantized",
"CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW4A16Sparse24",
Expand Down

This file was deleted.

0 comments on commit 92c0d5c

Please sign in to comment.