@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8PackingFormat,
    quantize_,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import torch_version_at_least


def _make_cfg(act: str, target_sparsity: float = 0.90):
    """
    Helper to build the v2 CSR config:
    - act == "sym"  -> dynamic int8 symmetric per-token
    - act == "asym" -> dynamic uint8 asymmetric per-token
    - act == "noop" -> weight-only decode (no activation quant)
    """
    if act == "noop":
        return Int8DynamicActivationInt8WeightConfig(
            act_mapping_type=MappingType.SYMMETRIC,  # ignored when weight_only_decode=True
            weight_only_decode=True,
            version=2,
            int8_packing_format=Int8PackingFormat.CSR_SPARSE,
            target_sparsity=target_sparsity,
        )
    elif act == "sym":
        return Int8DynamicActivationInt8WeightConfig(
            act_mapping_type=MappingType.SYMMETRIC,
            weight_only_decode=False,
            version=2,
            int8_packing_format=Int8PackingFormat.CSR_SPARSE,
            target_sparsity=target_sparsity,
        )
    elif act == "asym":
        return Int8DynamicActivationInt8WeightConfig(
            act_mapping_type=MappingType.ASYMMETRIC,
            weight_only_decode=False,
            version=2,
            int8_packing_format=Int8PackingFormat.CSR_SPARSE,
            target_sparsity=target_sparsity,
        )
    else:
        raise ValueError(f"Unknown act mode: {act}")


CPU_DTYPES = [torch.float32]  # the CSR fallback path currently runs on CPU


@unittest.skipIf(not torch_version_at_least("2.7.0"), "Need PyTorch 2.7+")
class TestInt8CsrSparseTensor(TestCase):
    @parametrize("act_mode", ["sym", "asym", "noop"])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),  # (M,), N, K
            ((32, 64), 512, 256),  # (B, T), N, K
            ((2, 8, 16), 384, 192),  # (B, T, ?), N, K
        ],
    )
    @parametrize("dtype", CPU_DTYPES)
    def test_linear_forward_cpu(self, act_mode, sizes, dtype):
        """
        Forward should run, produce finite values, and keep shapes consistent.
        """
        M, N, K = sizes
        x = torch.randn(*M, K, dtype=dtype, device="cpu")
        lin = torch.nn.Linear(K, N, bias=True, dtype=dtype, device="cpu")

        # fp32 reference
        y_ref = lin(x)

        cfg = _make_cfg(act_mode, target_sparsity=0.90)
        quantize_(lin, cfg)

        # weight must be our subclass
        self.assertEqual(
            str(type(lin.weight)),
            "<class 'torchao.quantization.Int8CsrSparseTensor'>",
        )

        y_q = lin(x)
        self.assertEqual(y_q.shape, y_ref.shape)
        self.assertTrue(torch.isfinite(y_q).all(), "Quantized output has NaN/Inf")

        # Sanity: expect some difference from fp32 (not required to be large)
        diff = (y_q - y_ref).abs().mean()
        self.assertTrue(torch.isfinite(diff))
        self.assertGreaterEqual(diff.item(), 0.0)

@parametrize("act_mode", ["sym", "asym", "noop"])
def test_module_path_state_dict(self, act_mode):
"""
Saving state_dict and loading it back preserves the subclass type
of the weight tensor.
"""
K, N = 128, 256
lin = torch.nn.Linear(K, N, bias=True, dtype=torch.float32, device="cpu")
cfg = _make_cfg(act_mode, target_sparsity=0.85)
quantize_(lin, cfg)

self.assertEqual(
str(type(lin.weight)),
"<class 'torchao.quantization.Int8CsrSparseTensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(lin.state_dict(), f)
f.seek(0)
sd = torch.load(f)
self.assertEqual(
str(type(sd["weight"])),
"<class 'torchao.quantization.Int8CsrSparseTensor'>",
)

    def test_guard_small_in_features(self):
        """
        The quantize path keeps the v1 guard that skips layers with
        in_features <= 16, so use K=32 here to ensure the v2 config
        actually quantizes instead of hitting the guard.
        """
        K, N = 32, 64
        x = torch.randn(4, K)
        lin = torch.nn.Linear(K, N)
        cfg = _make_cfg("sym", target_sparsity=0.9)
        quantize_(lin, cfg)
        y = lin(x)
        self.assertEqual(y.shape, (4, N))
        self.assertTrue(torch.isfinite(y).all())


instantiate_parametrized_tests(TestInt8CsrSparseTensor)


if __name__ == "__main__":
    run_tests()
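End to end, the flow these tests exercise reduces to config plus quantize_. A minimal sketch, assuming the names this PR adds (Int8PackingFormat.CSR_SPARSE, target_sparsity) are exported from torchao.quantization as in the __init__.py changes below:

import torch

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8PackingFormat,
    quantize_,
)

model = torch.nn.Sequential(
    torch.nn.Linear(256, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 128),
)
cfg = Int8DynamicActivationInt8WeightConfig(
    version=2,
    int8_packing_format=Int8PackingFormat.CSR_SPARSE,
    target_sparsity=0.9,
)
quantize_(model, cfg)  # swaps each Linear weight for an Int8CsrSparseTensor
out = model(torch.randn(4, 256))  # CPU forward through the CSR path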
4 changes: 4 additions & 0 deletions torchao/quantization/__init__.py
@@ -95,6 +95,8 @@
    Int4PreshuffledTensor,
    Int4Tensor,
    Int4TilePackedTo4dTensor,
    Int8CsrSparseTensor,
    Int8PackingFormat,
    IntxOpaqueTensor,
    IntxUnpackedToInt8Tensor,
)
@@ -165,6 +167,8 @@
"Int4PlainInt32Tensor",
"Int4PreshuffledTensor",
"Int4MarlinSparseTensor",
"Int8CsrSparseTensor",
"Int8PackingFormat",
"IntxOpaqueTensor",
"IntxUnpackedToInt8Tensor",
"Int4TilePackedTo4dTensor",
32 changes: 32 additions & 0 deletions torchao/quantization/quant_api.py
@@ -78,6 +78,8 @@
    Int4PreshuffledTensor,
    Int4Tensor,
    Int4TilePackedTo4dTensor,
    Int8CsrSparseTensor,
    Int8PackingFormat,
    IntxOpaqueTensor,
    IntxPackingFormat,
    IntxUnpackedToInt8Tensor,
@@ -1514,6 +1516,9 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
    act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
    weight_only_decode: bool = False
    set_inductor_config: bool = True
    version: int = 2
    int8_packing_format: Int8PackingFormat = Int8PackingFormat.CSR_SPARSE
    target_sparsity: float = 0.9  # fraction of weight entries pruned to zero, in [0, 1)

    def __post_init__(self):
        torch._C._log_api_usage_once(
@@ -1540,7 +1545,34 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
f" because `in_feature` is <= 16: {in_features}"
)
return weight
    if config.version == 2:
        block_size = [1, in_features]
        if config.int8_packing_format == Int8PackingFormat.CSR_SPARSE:
            if weight_only_decode:
                act_mode = "noop"
            else:
                act_mode = (
                    "int8_sym_per_token"
                    if act_mapping_type == MappingType.SYMMETRIC
                    else "int8_asym_per_token"
                )

            new_weight = Int8CsrSparseTensor.from_hp(
                weight,
                block_size=block_size,
                act_mode=act_mode,
                target_sparsity=config.target_sparsity,
            )
            return new_weight
        raise ValueError(
            f"Unsupported packing format {config.int8_packing_format} for version 2 of Int8DynamicActivationInt8WeightConfig, only Int8PackingFormat.CSR_SPARSE is supported"
        )

    # version 1
    assert config.version == 1
    warnings.warn(
        "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2948 for more details"
    )
    # weight settings
    mapping_type = MappingType.SYMMETRIC
    weight_zero_point_domain = ZeroPointDomain.NONE
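The new file int8_csr_sparse_tensor.py itself is not part of this diff, so from_hp's exact behavior is not visible here. A minimal sketch of what the call above implies — magnitude-prune the weight to target_sparsity, quantize the survivors with per-row symmetric int8 scales (block_size = [1, in_features]), then pack as CSR — written as a hypothetical free function rather than the PR's real classmethod:

import torch

def csr_from_hp_sketch(weight: torch.Tensor, target_sparsity: float = 0.9):
    # Hypothetical stand-in for Int8CsrSparseTensor.from_hp; the real
    # subclass also carries act_mode and block_size metadata.
    k = int(weight.numel() * target_sparsity)
    if k > 0:
        # Magnitude pruning: zero the k smallest-magnitude entries.
        threshold = weight.abs().flatten().kthvalue(k).values
        weight = torch.where(
            weight.abs() > threshold, weight, torch.zeros_like(weight)
        )
    # Per-row symmetric int8 scales, matching block_size = [1, in_features].
    scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(weight / scale), -127, 127).to(torch.int8)
    # CSR packing keeps only the non-zero int8 values plus index metadata.
    return q.to_sparse_csr(), scale.squeeze(1)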
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -20,6 +20,8 @@
    Int4Tensor,
)
from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor
from .int8.int8_csr_sparse_tensor import Int8CsrSparseTensor
from .int8.int8_packing_format import Int8PackingFormat
from .intx.intx_opaque_tensor import (
    IntxOpaqueTensor,
)
@@ -44,4 +46,6 @@
"IntxPackingFormat",
"IntxUnpackedToInt8Tensor",
"IntxOpaqueTensor",
"Int8CsrSparseTensor",
"Int8PackingFormat",
]
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/int8/__init__.py
@@ -0,0 +1,4 @@
from .int8_csr_sparse_tensor import Int8CsrSparseTensor
from .int8_packing_format import Int8PackingFormat

__all__ = ["Int8CsrSparseTensor", "Int8PackingFormat"]