refactor quant api
jcaip committed Jul 25, 2024
1 parent a2af83d commit ce567c3
Showing 1 changed file with 39 additions and 45 deletions: torchao/quantization/quant_api.py
@@ -21,7 +21,7 @@
import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional

-from torchao.dtypes.utils import LayoutType
+from torchao.dtypes import PlainLayoutType
from torchao.utils import (
    TORCH_VERSION_AFTER_2_4,
    unwrap_tensor_subclass,
@@ -412,53 +412,48 @@ def apply_int8wo_quant(weight):

    return _get_linear_subclass_inserter(apply_int8wo_quant)

-def _apply_int8_dynamic_activation_int8_weight_quant(weight : torch.Tensor, layout_type : LayoutType) -> torch.Tensor:
-    """
-    Helper function to specify layout_type for int8 dynamic activation int8 dynamic weight quantization.
-    Used to compose with semi-structured sparsity.
-    """
-    in_features = weight.shape[1]
-    # int8 dynamic quantization only has benefit when in_feature > 16
-    if in_features <= 16:
-        return weight
-
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-    # weight settings
-    mapping_type = MappingType.SYMMETRIC
-    def get_weight_block_size(x):
-        return (1, x.shape[1])
-    target_dtype = torch.int8
-    eps = torch.finfo(torch.float32).eps
-    zero_point_dtype = torch.int64
-
-    # input settings
-    def get_per_token_block_size(x):
-        block_size = list(x.shape)
-        for i in range(len(block_size)-1):
-            block_size[i] = 1
-        return block_size
-
-    input_mapping_type = MappingType.SYMMETRIC
-    input_target_dtype = torch.int8
-    input_eps = 1e-5
-    input_quant_min = -127
-    input_quant_max = 127
-    input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
-
-    block_size = get_weight_block_size(weight)
-    weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
-    weight = to_linear_act_quantized(weight, input_quant_func)
-    return weight
-

-def int8_dynamic_activation_int8_weight():
+def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
"""
Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
quantization to linear layers
"""
-    from torchao.dtypes import PlainLayoutType
-    _apply_int8_dynamic_activation_int8_weight_quant_layout = partial(_apply_int8_dynamic_activation_int8_weight_quant, layout_type=PlainLayoutType())
-    return _get_linear_subclass_inserter(_apply_int8_dynamic_activation_int8_weight_quant_layout)
+    def apply_int8_dynamic_activation_int8_weight_quant(weight):
+        in_features = weight.shape[1]
+        # int8 dynamic quantization only has benefit when in_feature > 16
+        if in_features <= 16:
+            return weight
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
+        weight = to_linear_act_quantized(weight, input_quant_func)
+        return weight
+
+    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)


def int8_dynamic_activation_int8_semi_sparse_weight():
@@ -467,5 +462,4 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
    quantization + 2:4 sparsity to linear layers.
    """
    from torchao.dtypes import SemiSparseLayoutType
-    _apply_int8_dynamic_activation_int8_weight_quant_layout = partial(_apply_int8_dynamic_activation_int8_weight_quant, layout_type=SemiSparseLayoutType())
-    return _get_linear_subclass_inserter(_apply_int8_dynamic_activation_int8_weight_quant_layout)
+    return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
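
For orientation, here is a minimal usage sketch of the API after this refactor. It assumes the torchao quantize_ entry point and a toy model; names outside the two constructors shown in the diff are illustrative and may differ in this version of the library.

import torch
from torchao.quantization import quantize_  # assumed entry point; older versions expose a similar helper
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int8_weight,
    int8_dynamic_activation_int8_semi_sparse_weight,
)

# Toy model; in_features must exceed 16 or the quantization path is skipped.
model = torch.nn.Sequential(torch.nn.Linear(128, 256))

# Dense variant: layout_type now defaults to PlainLayoutType().
quantize_(model, int8_dynamic_activation_int8_weight())

# Semi-sparse variant: after the refactor this simply delegates to the call
# above with layout_type=SemiSparseLayoutType() (2:4 kernels need a supported GPU).
sparse_model = torch.nn.Sequential(torch.nn.Linear(128, 256))
quantize_(sparse_model, int8_dynamic_activation_int8_semi_sparse_weight())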

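A side note on the granularity encoded by the two block-size helpers in the diff: get_weight_block_size yields one quantization group per output channel of the weight, while get_per_token_block_size collapses every dimension except the last, giving one group per token of the activation. A small self-contained sketch with hypothetical shapes:

import torch

def get_weight_block_size(x):
    # one block per row of the weight => per-output-channel scales
    return (1, x.shape[1])

def get_per_token_block_size(x):
    # keep only the last dim => per-token scales for the activation
    block_size = list(x.shape)
    for i in range(len(block_size) - 1):
        block_size[i] = 1
    return block_size

weight = torch.randn(256, 128)        # (out_features, in_features)
activation = torch.randn(4, 64, 128)  # (batch, seq_len, in_features)

print(get_weight_block_size(weight))         # (1, 128)   -> 256 groups, one per output channel
print(get_per_token_block_size(activation))  # [1, 1, 128] -> 4*64 groups, one per token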