Revert "Add layout option to woq int4 api (#670)"
This reverts commit 009f55f.
msaroufim authored Aug 14, 2024
1 parent 009f55f commit 8b7b538
Showing 2 changed files with 30 additions and 16 deletions.
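
For context, a minimal sketch of the user-facing change this revert makes to the int4 weight-only API, inferred from the diff below; the model construction is illustrative and not part of the commit, and the diff guards this code path with TORCH_VERSION_AFTER_2_4:

    import torch
    from torchao.quantization.quant_api import quantize_, int4_weight_only

    # Illustrative bfloat16 CUDA model (tinygemm int4 kernels target this setup).
    model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).to("cuda")

    # Before this revert (#670): the tiling choice was wrapped in a layout object
    # (TensorCoreTiledLayoutType imported from torchao.dtypes):
    # quantize_(model, int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)))

    # After this revert: inner_k_tiles is a plain keyword argument again.
    quantize_(model, int4_weight_only(group_size=128, inner_k_tiles=8))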
9 changes: 3 additions & 6 deletions test/integration/test_integration.py
@@ -19,7 +19,6 @@
from torchao.quantization.dynamic_quant import (
    DynamicallyPerAxisQuantizedLinear,
)
-from torchao.dtypes import TensorCoreTiledLayoutType
from torchao.quantization.quant_api import (
    int4_weight_only,
    int8_weight_only,
@@ -853,20 +852,18 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
        for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
            for groupsize in [64, 32]:
                for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}
+                    kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}

                    def api(mod):
-                        kwargs_copy = kwargs.copy()
                        if TORCH_VERSION_AFTER_2_4:
+                            kwargs_copy = kwargs.copy()
                            kwargs_copy["group_size"] = groupsize
                            del kwargs_copy["groupsize"]
                            quantize_(mod, int4_weight_only(**kwargs_copy))
                            if not TORCH_VERSION_AFTER_2_5:
                                unwrap_tensor_subclass(mod)
                        else:
-                            kwargs_copy["inner_k_tiles"] = inner_k_tiles
-                            del kwargs_copy["layout_type"]
-                            change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)
+                            change_linear_weights_to_int4_woqtensors(mod, **kwargs)

                    self._test_lin_weight_subclass_api_impl(
                        api,
37 changes: 27 additions & 10 deletions torchao/quantization/quant_api.py
@@ -21,14 +21,7 @@
import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional

-from torchao.dtypes.uintx.Uintx import UintxLayoutType
-from torchao.dtypes import (
-    to_affine_quantized,
-    TensorCoreTiledLayoutType,
-    PlainLayoutType,
-    AffineQuantizedTensor,
-    SemiSparseLayoutType
-)
+from torchao.dtypes import PlainLayoutType
from torchao.utils import (
    TORCH_VERSION_AFTER_2_4,
    unwrap_tensor_subclass,
@@ -189,6 +182,9 @@ def _replace_with_custom_fn_if_matches_filter(


def _is_linear(mod, *args):
+    # avoid circular dep
+    from torchao.dtypes import AffineQuantizedTensor
+
    # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
    # when it is shared by multiple linear modules
    return (
@@ -332,6 +328,9 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
    )

def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+
    mapping_type = MappingType.ASYMMETRIC
    target_dtype = torch.int8
    return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
@@ -340,6 +339,9 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
    if weight.shape[-1] % group_size != 0:
        return weight

+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+
    # weight settings
    mapping_type = MappingType.SYMMETRIC
    block_size = (1, group_size)
@@ -371,7 +373,7 @@ def insert_subclass(lin):
    return insert_subclass


-def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)):
+def int4_weight_only(group_size=128, inner_k_tiles=8):
    """
    Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
    "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -387,12 +389,16 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner
    Args:
        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
        size is more fine grained, choices are [256, 128, 64, 32]
-        `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
+        `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
    """
    def apply_int4_weight_only_quant(weight):
        if weight.shape[-1] % group_size != 0:
            return weight

+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+        from torchao.dtypes import TensorCoreTiledLayoutType
+
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.int32
@@ -402,6 +408,7 @@ def apply_int4_weight_only_quant(weight):
        preserve_zero = False
        zero_point_dtype = torch.bfloat16
        zero_point_domain = ZeroPointDomain.FLOAT
+        layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

    return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
@@ -412,6 +419,9 @@ def int8_weight_only():
    Applies int8 weight-only symmetric per-channel quantization to linear layers.
    """
    def apply_int8wo_quant(weight):
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+
        mapping_type = MappingType.SYMMETRIC
        target_dtype = torch.int8
        eps = torch.finfo(torch.float32).eps
@@ -422,6 +432,8 @@ def apply_int8wo_quant(weight):
    return _get_linear_subclass_inserter(apply_int8wo_quant)

def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
    mapping_type = MappingType.SYMMETRIC
    target_dtype = torch.int8
    eps = 1e-5
@@ -441,6 +453,8 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
        if in_features <= 16:
            return weight

+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
        # weight settings
        mapping_type = MappingType.SYMMETRIC
        def get_weight_block_size(x):
@@ -465,6 +479,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
    Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
    quantization + 2:4 sparsity to linear layers.
    """
+    from torchao.dtypes import SemiSparseLayoutType
    return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())

@@ -480,6 +495,8 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
        quantize_affine,
        dequantize_affine,
    )
+    from torchao.dtypes.uintx.Uintx import UintxLayoutType
+    from torchao.dtypes import to_affine_quantized
    from torchao.quantization.quant_api import _get_linear_subclass_inserter
    def apply_uintx_weight_only_quant(weight):

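The repeated "# avoid circular dep" hunks above all apply the same deferred-import pattern: the torchao.dtypes imports move from module level into the function bodies, so they are only resolved when the function is first called rather than when quant_api.py is imported. A generic sketch of the pattern, with illustrative module names that are not from this repository:

    # module_a.py (illustrative)
    # Importing module_b at the top of this file would create an import cycle
    # if module_b also imports module_a at its top level.

    def make_thing(x):
        # Deferred import: resolved on the first call, after both modules
        # have finished importing, so no cycle occurs at import time.
        from module_b import helper
        return helper(x)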
