Skip to content

Commit 8deb14d

Browse files
add CSR tensor subclass and tests
Co-authored-by: Akash Agrawal <Akash.Agrawal@fujitsu.com>
1 parent 236c615 commit 8deb14d

File tree

11 files changed

+474
-280
lines changed

11 files changed

+474
-280
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import tempfile
8+
import unittest
9+
10+
import torch
11+
from torch.testing._internal.common_utils import (
12+
TestCase,
13+
instantiate_parametrized_tests,
14+
parametrize,
15+
run_tests,
16+
)
17+
18+
from torchao.quantization import (
19+
Int8DynamicActivationInt8WeightConfig,
20+
Int8PackingFormat,
21+
quantize_,
22+
)
23+
from torchao.quantization.quant_primitives import MappingType
24+
from torchao.utils import torch_version_at_least
25+
26+
27+
def _make_cfg(act: str, target_sparsity: float = 0.90):
28+
"""
29+
Helper to build the v2 CSR config:
30+
- act == "sym" -> dynamic int8 symmetric per-token
31+
- act == "asym" -> dynamic uint8 asymmetric per-token
32+
- act == "noop" -> weight-only decode (no activation quant)
33+
"""
34+
if act == "noop":
35+
return Int8DynamicActivationInt8WeightConfig(
36+
act_mapping_type=MappingType.SYMMETRIC, # ignored when weight_only_decode=True
37+
weight_only_decode=True,
38+
version=2,
39+
int8_packing_format=Int8PackingFormat.CSR_SPARSE,
40+
target_sparsity=target_sparsity,
41+
)
42+
elif act == "sym":
43+
return Int8DynamicActivationInt8WeightConfig(
44+
act_mapping_type=MappingType.SYMMETRIC,
45+
weight_only_decode=False,
46+
version=2,
47+
int8_packing_format=Int8PackingFormat.CSR_SPARSE,
48+
target_sparsity=target_sparsity,
49+
)
50+
elif act == "asym":
51+
return Int8DynamicActivationInt8WeightConfig(
52+
act_mapping_type=MappingType.ASYMMETRIC,
53+
weight_only_decode=False,
54+
version=2,
55+
int8_packing_format=Int8PackingFormat.CSR_SPARSE,
56+
target_sparsity=target_sparsity,
57+
)
58+
else:
59+
raise ValueError(f"Unknown act mode: {act}")
60+
61+
62+
# Dtypes exercised by the CPU forward test. Only float32 is listed; per the
# original note, the CSR fallback path runs on CPU.
CPU_DTYPES = [torch.float32]
63+
64+
65+
@unittest.skipIf(not torch_version_at_least("2.7.0"), "Need PyTorch 2.7+")
class TestInt8CsrSparseTensor(TestCase):
    """CPU smoke tests for ``torch.nn.Linear`` quantized with the configs from ``_make_cfg``."""

    def _assert_csr_weight(self, tensor):
        """Check that ``type(tensor)`` prints as the CSR-sparse int8 tensor subclass."""
        self.assertEqual(
            str(type(tensor)),
            "<class 'torchao.quantization.Int8CsrSparseTensor'>",
        )

    @parametrize("act_mode", ["sym", "asym", "noop"])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),  # one leading dim, N, K
            ((32, 64), 512, 256),  # two leading dims, N, K
            ((2, 8, 16), 384, 192),  # three leading dims, N, K
        ],
    )
    @parametrize("dtype", CPU_DTYPES)
    def test_linear_forward_cpu(self, act_mode, sizes, dtype):
        """
        The quantized forward pass runs, keeps the reference output shape,
        and yields only finite values.
        """
        lead_dims, out_features, in_features = sizes
        # The input is drawn before the layer is built, so RNG consumption
        # order matches the original test.
        inputs = torch.randn((*lead_dims, in_features), dtype=dtype, device="cpu")
        layer = torch.nn.Linear(
            in_features, out_features, bias=True, dtype=dtype, device="cpu"
        )

        # Unquantized reference output, taken before the module is rewritten.
        reference = layer(inputs)

        quantize_(layer, _make_cfg(act_mode, target_sparsity=0.90))

        # The weight must now be the CSR tensor subclass.
        self._assert_csr_weight(layer.weight)

        quantized = layer(inputs)
        self.assertEqual(quantized.shape, reference.shape)
        self.assertTrue(
            torch.isfinite(quantized).all(), "Quantized output has NaN/Inf"
        )

        # Sanity only: the mean absolute error must be a finite, non-negative
        # number; no accuracy bound is enforced.
        mean_abs_err = torch.mean(torch.abs(quantized - reference))
        self.assertTrue(torch.isfinite(mean_abs_err))
        self.assertGreaterEqual(mean_abs_err.item(), 0.0)

    @parametrize("act_mode", ["sym", "asym", "noop"])
    def test_module_path_state_dict(self, act_mode):
        """
        A state_dict written with ``torch.save`` and read back with
        ``torch.load`` still carries the subclass type for ``"weight"``.
        """
        in_features, out_features = 128, 256
        layer = torch.nn.Linear(
            in_features, out_features, bias=True, dtype=torch.float32, device="cpu"
        )
        quantize_(layer, _make_cfg(act_mode, target_sparsity=0.85))

        self._assert_csr_weight(layer.weight)

        with tempfile.NamedTemporaryFile() as tmp:
            torch.save(layer.state_dict(), tmp)
            tmp.seek(0)  # rewind so the load starts at the beginning
            restored = torch.load(tmp)
            self._assert_csr_weight(restored["weight"])

    def test_guard_small_in_features(self):
        """
        Quantize and run a small layer (in_features=32).

        The original note mentions a v1 guard on ``in_features <= 16`` that
        may still exist somewhere on the path; 32 is chosen to stay clear of
        it. Update this test if that guard changes.
        """
        in_features, out_features = 32, 64
        inputs = torch.randn(4, in_features)
        layer = torch.nn.Linear(in_features, out_features)
        quantize_(layer, _make_cfg("sym", target_sparsity=0.9))

        output = layer(inputs)
        self.assertEqual(output.shape, (4, out_features))
        self.assertTrue(torch.isfinite(output).all())
146+
147+
# Register the parametrized methods of the test class; this must run after the
# class body has been defined.
instantiate_parametrized_tests(TestInt8CsrSparseTensor)


if __name__ == "__main__":
    # Allow running this file directly as a script.
    run_tests()

torchao/dtypes/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from .nf4tensor import NF4Tensor, to_nf4
1616
from .uintx import (
1717
BlockSparseLayout,
18-
CSRLayout,
1918
CutlassInt4PackedLayout,
2019
Int4CPULayout,
2120
Int4XPULayout,
@@ -48,7 +47,6 @@
4847
"Layout",
4948
"PlainLayout",
5049
"SemiSparseLayout",
51-
"CSRLayout",
5250
"TensorCoreTiledLayout",
5351
"Float8Layout",
5452
"MarlinSparseLayout",

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@
2929
_linear_int8_act_int8_weight_block_sparse_check,
3030
_linear_int8_act_int8_weight_block_sparse_impl,
3131
)
32-
from torchao.dtypes.uintx.csr_layout import (
33-
_linear_int8_act_int8_weight_csr_sparse_check,
34-
_linear_int8_act_int8_weight_csr_sparse_impl,
35-
)
3632
from torchao.dtypes.uintx.cutlass_int4_packed_layout import (
3733
_linear_int4_act_int4_weight_cutlass_check,
3834
_linear_int4_act_int4_weight_cutlass_impl,
@@ -195,10 +191,6 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
195191
def _register_aqt_quantized_linear_dispatches():
196192
for dispatch_condition, impl in [
197193
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
198-
(
199-
_linear_int8_act_int8_weight_csr_sparse_check,
200-
_linear_int8_act_int8_weight_csr_sparse_impl,
201-
),
202194
(
203195
_linear_int8_act_int8_weight_semi_structured_sparse_check,
204196
_linear_int8_act_int8_weight_semi_structured_sparse_impl,

torchao/dtypes/uintx/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from .block_sparse_layout import (
22
BlockSparseLayout,
33
)
4-
from .csr_layout import (
5-
CSRLayout,
6-
)
74
from .cutlass_int4_packed_layout import (
85
CutlassInt4PackedLayout,
96
)
@@ -45,7 +42,6 @@
4542
"BlockSparseLayout",
4643
"MarlinSparseLayout",
4744
"SemiSparseLayout",
48-
"CSRLayout",
4945
"TensorCoreTiledLayout",
5046
"Int4CPULayout",
5147
"MarlinQQQLayout",

0 commit comments

Comments
 (0)