Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Update fbgemmfp8 to use vLLMParameters #7972

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod"
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod"
]


Expand Down
34 changes: 21 additions & 13 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs
apply_fp8_linear)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.platforms import current_platform

logger = init_logger(__name__)
Expand Down Expand Up @@ -85,6 +86,7 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)

Expand All @@ -95,20 +97,21 @@ def create_weights(
layer.orig_dtype = params_dtype

# WEIGHT
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
**extra_weight_attrs,
})

# WEIGHT SCALE
weight_scale = create_per_channel_scale_param(output_partition_sizes,
**extra_weight_attrs)
weight_scale = ChannelQuantScaleParameter(data=torch.empty(
(sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)

# INPUT SCALE UPPER BOUND
Expand All @@ -118,6 +121,11 @@ def create_weights(
layer.input_scale_ub = input_scale_ub

def process_weights_after_loading(self, layer: Module) -> None:
# required by torch.compile
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)

weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)

Expand Down
27 changes: 0 additions & 27 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_hip

Expand Down Expand Up @@ -38,31 +36,6 @@ def all_close_1d(x: torch.Tensor) -> bool:
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


def create_per_tensor_scale_param(
output_partition_sizes: List[int],
**extra_weight_attrs,
) -> Parameter:
scale = Parameter(torch.empty(len(output_partition_sizes),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**extra_weight_attrs
})
return scale


def create_per_channel_scale_param(output_partition_sizes: List[int],
**extra_weight_attrs) -> Parameter:
scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
return scale


def convert_to_channelwise(
weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down
Loading