Expose FakeQuantizeConfigs in QAT quantizers (#1214)
Summary: This commit exposes the activation and weight
FakeQuantizeConfigs in the existing QAT quantizers. These are
helpful for implementing advanced functionality based on the
quantization schemes represented by these quantizers, such as
composing QAT + LoRA.

Test Plan:
python test/quantization/test_qat.py
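
Usage (illustrative sketch, not taken from the commit itself): the new getters can be queried directly on a quantizer instance, assuming the quantizers' existing default constructors.

# Illustrative sketch: querying the newly exposed FakeQuantizeConfigs.
# Assumes the quantizers' default constructor arguments.
from torchao.quantization.qat.linear import (
    Int4WeightOnlyQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

quantizer = Int8DynActInt4WeightQATQuantizer()
act_config = quantizer.get_activation_fake_quantize_config()   # int8, per-token, asymmetric, dynamic
weight_config = quantizer.get_weight_fake_quantize_config()    # int4, grouped, symmetric

# The weight-only quantizer overrides only the weight getter; the activation
# getter falls back to the _LegacyQATQuantizer base and returns None.
wo_quantizer = Int4WeightOnlyQATQuantizer()
assert wo_quantizer.get_activation_fake_quantize_config() is None
print(act_config, weight_config, wo_quantizer.get_weight_fake_quantize_config())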
andrewor14 authored Nov 4, 2024
1 parent 1fbf788 commit 88d604f
Showing 1 changed file with 74 additions and 27 deletions.
torchao/quantization/qat/linear.py
@@ -107,12 +107,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.linear(x, w)
 
 
+class _LegacyQATQuantizer(TwoStepQuantizer):
+    """
+    Base class for sharing common methods across legacy QAT quantizers.
+    """
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return None
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return None
+
+
 # =========================================================
 # | Linear int8 dynamic activations + int4 weight QAT |
 # =========================================================
 
 
-class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
+class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have int8
     dynamic per token fake quantized activations and int4 fake quantized
@@ -189,6 +200,12 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
             else:
                 self._convert_qat_linear_8da4w(child)
 
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_8da4w_activation_config(self.scales_precision)
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_8da4w_weight_config(self.groupsize, self.scales_precision)
+
 
 class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
     """
@@ -211,22 +228,8 @@ def __init__(
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
-        activation_config = FakeQuantizeConfig(
-            dtype=torch.int8,
-            granularity="per_token",
-            is_symmetric=False,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-        )
-        weight_config = FakeQuantizeConfig(
-            dtype=TorchAODType.INT4,
-            group_size=groupsize,
-            is_symmetric=True,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-        )
+        activation_config = _get_8da4w_activation_config(scales_precision)
+        weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,
             out_features,
@@ -261,12 +264,43 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
         mod.disable_fake_quant()
 
 
+def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig:
+    """
+    Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=torch.int8,
+        granularity="per_token",
+        is_symmetric=False,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+    )
+
+
+def _get_8da4w_weight_config(
+    group_size: int,
+    qparams_precision: torch.dtype,
+) -> FakeQuantizeConfig:
+    """
+    Return the weight `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=TorchAODType.INT4,
+        group_size=group_size,
+        is_symmetric=True,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+    )
+
+
 # ===================================
 # | Linear int4 weight-only QAT |
 # ===================================
 
 
-class Int4WeightOnlyQATQuantizer(TwoStepQuantizer):
+class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have
     int4 fake quantized grouped per channel weights.
@@ -348,6 +382,9 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
             else:
                 self._convert_qat_linear_4w(child)
 
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_4w_weight_config(self.groupsize, self.scales_precision)
+
 
 class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
     """
@@ -376,15 +413,7 @@ def __init__(
         if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
             raise ValueError("Padding for QAT 4w is not supported yet")
         self.inner_k_tiles = inner_k_tiles
-        weight_config = FakeQuantizeConfig(
-            dtype=torch.uint4,
-            group_size=groupsize,
-            is_symmetric=False,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-            zero_point_domain=ZeroPointDomain.FLOAT,
-        )
+        weight_config = _get_4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,
             out_features,
@@ -417,3 +446,21 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
     """
     if isinstance(mod, Int4WeightOnlyQATLinear):
         mod.disable_fake_quant()
+
+
+def _get_4w_weight_config(
+    group_size: int,
+    qparams_precision: torch.dtype,
+) -> FakeQuantizeConfig:
+    """
+    Return the weight `FakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=torch.uint4,
+        group_size=group_size,
+        is_symmetric=False,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+        zero_point_domain=ZeroPointDomain.FLOAT,
+    )
