
Commit

formatting
LucasWilkinson committed Aug 7, 2024
1 parent 9696a55 commit 11a9dec
Showing 7 changed files with 46 additions and 52 deletions.
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wna16.py
@@ -4,13 +4,12 @@

 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.kernels import (
+    MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedvLLMParameter)
-from vllm.model_executor.layers.quantization.utils import replace_parameter
-from vllm.model_executor.layers.quantization.kernels import (
-    MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.scalar_type import scalar_types
 
 __all__ = ["CompressedTensorsWNA16"]
@@ -54,7 +53,7 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
                           input_size_per_partition: int,
                           params_dtype: torch.dtype, weight_loader: Callable,
                           **kwargs):
 
         output_size_per_partition = sum(output_partition_sizes)
 
         mp_linear_kernel_config = MPLinearLayerConfig(
@@ -136,7 +135,6 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         self.kernel.process_weights_after_loading(layer)
 
-
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         return self.kernel.apply_weights(layer, x, bias)
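
A note for readers following the kernel-selection refactor in this commit: choose_mp_linear_kernel pairs with MPLinearLayerConfig to pick a mixed-precision GEMM implementation at weight-creation time. Below is a minimal runnable sketch of that contract; only partition_weight_shape, weight_type, group_size and can_implement are visible in this diff, so everything else (field list, priority order, error handling) is an assumption, not vLLM's exact API.

from dataclasses import dataclass
from typing import List, Optional, Tuple, Type


@dataclass
class MPLinearLayerConfig:
    # Mirrors the attributes the kernels below read as c.partition_weight_shape,
    # c.weight_type.size_bits and c.group_size; any other field is assumed.
    partition_weight_shape: Tuple[int, int]  # (size_k, size_n)
    weight_type: object                      # e.g. an entry of vllm.scalar_type
    group_size: int


class _KernelStandIn:
    """Stand-in for MPLinearKernel; the real candidates are the Machete and
    Marlin kernels defined later in this commit."""

    def __init__(self, c: MPLinearLayerConfig) -> None:
        self.config = c

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        return True, None


_POSSIBLE_KERNELS: List[Type[_KernelStandIn]] = [_KernelStandIn]


def choose_mp_linear_kernel(c: MPLinearLayerConfig) -> _KernelStandIn:
    # Return an instance of the first candidate whose can_implement()
    # accepts the config, collecting reasons from any that decline.
    failures = []
    for kernel_cls in _POSSIBLE_KERNELS:
        ok, reason = kernel_cls.can_implement(c)
        if ok:
            return kernel_cls(c)
        failures.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("no kernel can implement this config: " + "; ".join(failures))
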
15 changes: 4 additions & 11 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -3,21 +3,16 @@
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-
-from vllm.model_executor.layers.quantization.kernels import (
-    MPLinearLayerConfig, choose_mp_linear_kernel)
-
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.kernels import (
+    MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
-    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx,
-    verify_marlin_supported, verify_marlin_supports_shape)
+    check_marlin_supported, marlin_repeat_scales_on_all_ranks,
+    verify_marlin_supported)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.scalar_type import scalar_types

@@ -289,8 +284,6 @@ def create_weights(
     # marlin format. This function is called after the weights are loaded.
     # Here, we handle the repacking, including the activation reordering case.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        device = layer.qweight.device
-
         # `qweight` and `scales` are already in the correct format. So we can
         # just call `process_weights_after_loading` right away
         self.kernel.process_weights_after_loading(layer)
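
The net effect of the two gptq_marlin.py hunks above: the linear method no longer does device or Marlin bookkeeping itself; it builds a kernel during weight creation and simply forwards the post-load hook. A self-contained sketch of that delegation pattern with stand-in names (the real pair is GPTQMarlinLinearMethod and the kernel instance returned by choose_mp_linear_kernel):

import torch


class _Kernel:
    # stand-in for the MPLinearKernel chosen at create_weights() time
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.repacked = True  # repacking / activation reordering happens here

    def apply_weights(self, layer, x, bias=None):
        return x  # placeholder for the real quantized GEMM


class _LinearMethod:

    def __init__(self) -> None:
        self.kernel = _Kernel()

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # qweight/scales are already in the loaded format, so the method
        # forwards the hook and lets the kernel decide what to repack
        self.kernel.process_weights_after_loading(layer)


layer = torch.nn.Module()
_LinearMethod().process_weights_after_loading(layer)
print(layer.repacked)  # True
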
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Tuple, Dict, Callable
+from typing import Callable, Optional, Tuple
 
 import torch
 
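
The hunk above only reorders imports, but this file defines the interface both kernels below implement. The following sketch is reconstructed from names used elsewhere in this commit (can_implement, process_weights_after_loading, apply_weights, _transform_param, w_q_name, w_s_name); the bodies and exact signatures are assumptions:

from abc import ABC, abstractmethod
from typing import Callable, Optional, Tuple

import torch


class MPLinearKernel(ABC):

    def __init__(self, config, w_q_param_name: str, w_s_param_name: str) -> None:
        self.config = config
        self.w_q_name = w_q_param_name  # packed quantized weight
        self.w_s_name = w_s_param_name  # per-group / per-channel scales

    @classmethod
    @abstractmethod
    def can_implement(cls, config) -> Tuple[bool, Optional[str]]:
        """Report whether this kernel supports the given layer config."""

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack loaded checkpoint weights into this kernel's layout."""

    @abstractmethod
    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the quantized GEMM."""

    def _transform_param(self, layer: torch.nn.Module, name: str,
                         fn: Callable[[torch.Tensor], torch.Tensor]) -> None:
        # Apply fn to the named parameter and re-register the result as a
        # frozen Parameter. (Assumption: the real helper also preserves
        # loader metadata, e.g. via replace_parameter.)
        new_data = fn(getattr(layer, name))
        setattr(layer, name,
                torch.nn.Parameter(new_data.detach(), requires_grad=False))
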
vllm/model_executor/layers/quantization/kernels/machete.py
@@ -4,6 +4,7 @@
     query_machete_supported_quant_types)
 from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PackedvLLMParameter)
+
 from .MPLinearKernel import *
 
 
@@ -43,20 +44,22 @@ def can_implement(cls,
     #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
     #  `weight_scale` is: {input_dim = 0, output_dim = 1}
     def process_weights_after_loading(self, layer: torch.nn.Module):
+
         def transform_w_q(x):
-            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
+            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
             # everything is migrated to using weight_loader_v2
             if isinstance(x, PackedvLLMParameter):
                 x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0)
             return ops.machete_prepack_B(x.t().contiguous().t(),
                                          self.config.weight_type)
 
         def transform_w_s(x):
-            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
+            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
             # everything is migrated to using weight_loader_v2
             if isinstance(x, ModelWeightParameter):
                 x = x.permute_layout(input_dim=0, output_dim=1)
             return x.contiguous()
+
         # Repack weights and scales for Machete
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
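
One detail worth calling out in transform_w_q above: x.t().contiguous().t() keeps the tensor's logical shape but rewrites its storage to column-major order, which is presumably the layout machete_prepack_B expects. A standalone demonstration:

import torch

x = torch.arange(12, dtype=torch.int32).reshape(3, 4)
col_major = x.t().contiguous().t()

print(x.shape, col_major.shape)        # torch.Size([3, 4]) torch.Size([3, 4])
print(x.stride(), col_major.stride())  # (4, 1) (1, 3)
print(torch.equal(x, col_major))       # True -- same values, new layout
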
vllm/model_executor/layers/quantization/kernels/marlin.py
@@ -1,14 +1,15 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
-    check_marlin_supports_shape, marlin_make_empty_g_idx,
+    check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
     marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
-    marlin_is_k_full, query_marlin_supported_quant_types)
-
-from .MPLinearKernel import *
+    query_marlin_supported_quant_types)
 from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PackedvLLMParameter)
 
+from .MPLinearKernel import *
+
 
 class MarlinLinearKernel(MPLinearKernel):
 
     @classmethod
Expand Down Expand Up @@ -82,27 +83,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
 
         def transform_w_q(x):
-            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
+            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
             # everything is migrated to using weight_loader_v2
             if isinstance(x, PackedvLLMParameter):
                 x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0)
-            return ops.gptq_marlin_repack(
-                x.contiguous(),
-                perm=layer.g_idx_sort_indices,
-                size_k=c.partition_weight_shape[0],
-                size_n=c.partition_weight_shape[1],
-                num_bits=c.weight_type.size_bits)
-
+            return ops.gptq_marlin_repack(x.contiguous(),
+                                          perm=layer.g_idx_sort_indices,
+                                          size_k=c.partition_weight_shape[0],
+                                          size_n=c.partition_weight_shape[1],
+                                          num_bits=c.weight_type.size_bits)
 
         def transform_w_s(x):
-            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
+            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
             # everything is migrated to using weight_loader_v2
             if isinstance(x, ModelWeightParameter):
                 x = x.permute_layout(input_dim=0, output_dim=1)
-            return marlin_permute_scales(
-                x.contiguous(),
-                size_k=c.partition_weight_shape[0],
-                size_n=c.partition_weight_shape[1],
-                group_size=c.group_size)
+            return marlin_permute_scales(x.contiguous(),
+                                         size_k=c.partition_weight_shape[0],
+                                         size_n=c.partition_weight_shape[1],
+                                         group_size=c.group_size)
 
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
@@ -1,4 +1,5 @@
 from typing import Union
+
 import torch
 

28 changes: 14 additions & 14 deletions vllm/model_executor/parameter.py
@@ -142,25 +142,25 @@ class ModelWeightParameter(_ColumnvLLMParameter):
     def __init__(self, input_dim: int, **kwargs):
         self._input_dim = input_dim
         super().__init__(**kwargs)
 
     def permute_layout(self, input_dim: int, output_dim: int, **kwargs) \
             -> 'ModelWeightParameter':
 
         # create permutation from the current layout to the layout with
         # self.input_dim at input_dim and self.output_dim at output_dim,
         # preserving other dimensions
-        perm = [i for i in range(self.data.dim())
-                if i not in [self.input_dim, self.output_dim]
+        perm = [
+            i for i in range(self.data.dim())
+            if i not in [self.input_dim, self.output_dim]
         ]
         perm.insert(input_dim, self.input_dim)
         perm.insert(output_dim, self.output_dim)
 
-        return ModelWeightParameter(
-            data=self.data.permute(*perm).contiguous(),
-            weight_loader=self.weight_loader,
-            input_dim=input_dim,
-            output_dim=output_dim,
-            **kwargs)
-
+        return ModelWeightParameter(data=self.data.permute(*perm).contiguous(),
+                                    weight_loader=self.weight_loader,
+                                    input_dim=input_dim,
+                                    output_dim=output_dim,
+                                    **kwargs)
+
 
     @property
     def input_dim(self):
@@ -272,18 +272,18 @@ def __init__(self,
         self._marlin_tile = marlin_tile_size
         super().__init__(**kwargs)
 
-    def permute_layout(self, input_dim: int, output_dim: int,
+    def permute_layout(self, input_dim: int, output_dim: int,
                        packed_dim: int = 0,
                        **kwargs)\
             -> 'ModelWeightParameter':
 
         assert packed_dim == self.packed_dim
 
         return PackedvLLMParameter(
             data=ModelWeightParameter\
                 .permute_layout(self, input_dim, output_dim).data,
             weight_loader=self.weight_loader,
-            input_dim=input_dim,
+            input_dim=input_dim,
             output_dim=output_dim,
             packed_dim=self.packed_dim,
             packed_factor=self.packed_factor,
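
To make the perm construction in the first parameter.py hunk concrete: the comprehension keeps every dim that is neither the current input nor output dim, then the two inserts place those dims at the requested positions. A standalone check of the same arithmetic on a plain tensor (the 3-D shape here is hypothetical):

import torch


def build_perm(ndim, cur_input_dim, cur_output_dim, input_dim, output_dim):
    # same construction as ModelWeightParameter.permute_layout above
    perm = [
        i for i in range(ndim)
        if i not in [cur_input_dim, cur_output_dim]
    ]
    perm.insert(input_dim, cur_input_dim)
    perm.insert(output_dim, cur_output_dim)
    return perm


# dims are (groups, out_features, in_features): output_dim=1, input_dim=2
data = torch.randn(4, 8, 16)
perm = build_perm(data.dim(), cur_input_dim=2, cur_output_dim=1,
                  input_dim=0, output_dim=1)
print(perm)                       # [2, 1, 0]
print(data.permute(*perm).shape)  # torch.Size([16, 8, 4])
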
