
Conversation

@choudhary-devang (Collaborator) commented Sep 10, 2025

Introduces a CSR layout for affine-quantized INT8 weights so linear layers can run with sparse weights in TorchAO.

Adds per-layer sparsity control via CSRLayout(target_sparsity=...) (falls back to env var TORCHAO_CSR_TARGET_SPARSITY or 0.9). This makes experimentation deterministic and tunable per layer.

Key changes

  • New CSRLayout and CSR_AQTTensorImpl:

    • Stores weights as torch.sparse_csr_tensor(int8) + scale/zero-point.

    • pre_process() does magnitude pruning to the requested sparsity, then packs to CSR (a sketch of this step follows the list).

  • Minimal dispatch: intercepts aten.linear (and addmm/mm) when the weight is CSR-packed and runs a CSR path.

  • Reference CPU fallback uses sparse matmul + dequantization for correctness; can be replaced with vendor INT8 SpMM later.
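
For reference, a minimal sketch of the pruning-and-packing idea behind pre_process() (the function name prune_and_pack_csr and the exact thresholding below are illustrative, not the PR's implementation):

import torch

def prune_and_pack_csr(weight_int8: torch.Tensor, target_sparsity: float = 0.9) -> torch.Tensor:
    """Magnitude-prune an int8 weight to the requested sparsity, then pack it to CSR.

    Illustrative only: the PR's CSR_AQTTensorImpl also carries scale/zero-point
    alongside the packed values.
    """
    magnitudes = weight_int8.to(torch.float32).abs()
    k = int(target_sparsity * magnitudes.numel())  # number of entries to zero out
    if k > 0:
        threshold = magnitudes.flatten().kthvalue(k).values  # magnitude cutoff
        weight_int8 = weight_int8 * (magnitudes > threshold)  # bool mask promotes back to int8
    return weight_int8.to_sparse_csr()  # pack the surviving entries as CSR

w = torch.randint(-128, 128, (128, 64), dtype=torch.int8)
w_csr = prune_and_pack_csr(w, target_sparsity=0.85)
print(w_csr.layout)  # torch.sparse_csr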

Usage

import torch, torch.nn as nn
from torchao.quantization import quantize_
from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight
from torchao.dtypes.uintx.csr_layout import CSRLayout, CSR_AQTTensorImpl

m = nn.Linear(64, 128, bias=False).eval()
cfg = int8_dynamic_activation_int8_weight(layout=CSRLayout(target_sparsity=0.85))
quantize_(m, cfg)

inner = m.weight.original_weight_tensor
assert isinstance(inner.tensor_impl, CSR_AQTTensorImpl)  # packed as CSR
y = m(torch.randn(4, 64))

This PR is a joint contribution by:

@agrawal-aka
@choudhary-devang

cc: @jerryzh168, @jcaip

pytorch-bot bot commented Sep 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2971

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures

As of commit 6a1b942 with merge base eadead5:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Sep 10, 2025
@choudhary-devang force-pushed the CSR_Layout branch 2 times, most recently from aa68b31 to c2562be on September 10, 2025 09:05
@choudhary-devang added the quantize_ label Sep 10, 2025
@choudhary-devang added the topic: new feature label Sep 10, 2025
@agrawal-aka (Contributor)

@pytorchbot label "sparsity"

@pytorch-bot bot added the sparsity label Sep 10, 2025
@jcaip self-requested a review September 10, 2025 13:27
@jcaip (Contributor) left a comment


This needs some tests and benchmarks (I guess that's also dependent on the kernel) but it's pretty cool. Thanks for the work @agrawal-aka @choudhary-devang!

The only thing is I think that we are changing the AQT internals to move away from Layout, do you have thoughts on that @jerryzh168?

x2d = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).to(torch.int32)

# For now we don't have any kernel for int8 SpMM computation on ARM, so we are upscaling to fp32 for computation
# Once we integrate the kernel in torch we can use that kernel instead of torch.mm()
Contributor

Nice! Is there a PR open in pytorch for the ARM kernel?

Contributor

Also might be a good idea to throw a warning if we use the upscaled fp32 mm.

Collaborator Author

Nice! Is there a PR open in pytorch for the ARM kernel?
We are currently working on it and will raise it in the near future; meanwhile, this change enables the API so that the kernel can be patched in later.
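
A rough sketch of what "patchable later" could look like, assuming a module-level kernel hook (all names here, including register_int8_spmm_kernel, are illustrative and not part of this PR):

import torch

_INT8_SPMM_KERNEL = None  # filled in later once a vendor/ARM int8 SpMM kernel lands

def register_int8_spmm_kernel(fn):
    """Let a later PR patch in a real int8 SpMM kernel without touching dispatch."""
    global _INT8_SPMM_KERNEL
    _INT8_SPMM_KERNEL = fn

def csr_int8_mm(w_csr_int8: torch.Tensor, x_int8: torch.Tensor) -> torch.Tensor:
    if _INT8_SPMM_KERNEL is not None:
        return _INT8_SPMM_KERNEL(w_csr_int8, x_int8)
    # Reference path: upscale to fp32 and use torch.mm (sparse CSR @ dense).
    return torch.mm(w_csr_int8.to(torch.float32), x_int8.to(torch.float32).t()).t()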

Collaborator Author

Also might be a good idea to throw a warning if we use the upscaled fp32 mm.

sure, we will add the warning as suggested.
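
A minimal sketch of that warning (the message text and placement are only a suggestion; it would go right before the fp32 upscale in the fallback path):

import warnings

# Emitted on the reference path, just before upscaling the int8 CSR weight to
# fp32 for torch.mm, so users know they are not hitting an int8 SpMM kernel.
warnings.warn(
    "No int8 SpMM kernel available for CSR weights on this platform; "
    "falling back to fp32 torch.mm, which may be slower.",
    UserWarning,
    stacklevel=2,
)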


# For now we don't have any kernel for int8 SpMM computation on ARM, so we are upscaling to fp32 for computation
# Once we integrate the kernel in torch we can use that kernel instead of torch.mm()
y_int32 = torch.mm(
Contributor

nit: I think the var names here could be a bit clearer - why is this y_int32? torch.mm should output float32 right?

Collaborator Author

sure, we will change the var name accordingly

    target = self.target_sparsity
else:
    # 2. fall back to env var or 0.9 default
    target = float(os.getenv("TORCHAO_CSR_TARGET_SPARSITY", "0.9"))
Contributor

Do we need an ENV variable for this? I think maybe it would be better to just pass this in the Config.

Collaborator Author

sure, we can remove the ENV variable as suggested.
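
For reference, dropping the env var would reduce the sparsity control to the config alone, roughly like this (a sketch assuming the dataclass-style layout pattern; the validation is illustrative):

from dataclasses import dataclass

@dataclass(frozen=True)
class CSRLayout:
    # target_sparsity now comes only from the config, with a plain default,
    # instead of falling back to TORCHAO_CSR_TARGET_SPARSITY.
    target_sparsity: float = 0.9

    def __post_init__(self):
        if not 0.0 <= self.target_sparsity < 1.0:
            raise ValueError(
                f"target_sparsity must be in [0, 1), got {self.target_sparsity}"
            )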

@jerryzh168 (Contributor)

The only thing is I think that we are changing the AQT internals to move away from Layout, do you have thoughts on that @jerryzh168?

Yeah, we are moving away from AQT; we have updated the official docs: https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html

new code is located in https://github.com/pytorch/ao/tree/main/torchao/quantization/quantize_/workflows, example sparse tensor for int4: https://github.com/pytorch/ao/blob/main/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py

@choudhary-devang force-pushed the CSR_Layout branch 8 times, most recently from b5646da to fab3a7b on September 22, 2025 10:05
Co-authored-by: Akash Agrawal <Akash.Agrawal@fujitsu.com>
@choudhary-devang force-pushed the CSR_Layout branch 3 times, most recently from 6f6d477 to 8deb14d on September 24, 2025 07:02
Co-authored-by: Akash Agrawal <Akash.Agrawal@fujitsu.com>
    shape: Tuple[int, int],
    block_size: Tuple[int, int],
    mapping: MappingType,
    a8_mode: str = "int8_sym_per_token",  # "noop" | "int8_sym_per_token" | "int8_asym_per_token"
@jerryzh168 (Contributor) commented Sep 26, 2025

if this is optional please move a8_mode to optional_tensor_attribute_names for the class, like this:

optional_tensor_attribute_names = ["optional_attr"]

Collaborator Author

sure
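
For context, the pattern being suggested looks roughly like this on the new-style tensor subclasses (a sketch; the class name and the specific entries in the lists are illustrative, not the ones in this PR):

from torchao.utils import TorchAOBaseTensor

class CSRQuantizedTensor(TorchAOBaseTensor):
    # a8_mode has a default, so it is declared in the optional list rather
    # than with the required attributes.
    tensor_data_names = ["csr_values", "scale", "zero_point"]
    tensor_attribute_names = ["block_size", "shape"]
    optional_tensor_attribute_names = ["a8_mode"]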



# ---- alias: return self unchanged (avoid touching strides) ----
for _name in ("default", "default_1", "default_2", "default_3"):
Contributor

what are these?

@choudhary-devang (Collaborator) commented Sep 29, 2025

The CI failures I am facing:

FAILED test/integration/test_integration.py::TestExport::test_export_06 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_07 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_08 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_09 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_10 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_11 - RuntimeError: Sparse CSR tensors do not have strides
===== 9 failed, 144 passed, 219 skipped, 110 warnings in 367.13s (0:06:07) =====

Those names (default, default_1, …) are the codegen’d overload entry points that PyTorch emits for some ops across versions (2.3 → nightlies). On a given install you might see only default, but on others there can be additional default_* variants. I iterated them to keep the override stable across minor/patch releases and CI images, instead of pinning to just one symbol and risking a missed dispatch.

That said, the intent here is simple: make aten.alias(*) a no-op for our wrapper so we don’t touch strides (CSR has none), which fixes the export tests.

I can make this clearer and safer by replacing the hardcoded tuple with a tiny helper that discovers the available overloads dynamically:

# Register against all available alias overloads (different PyTorch builds
# may expose default, default_1, ...). We keep alias a no-op to avoid
# any view/stride logic on a CSR-backed wrapper.
def _overloads(op_ns, name_prefix="default"):
    return [
        getattr(op_ns, k) for k in dir(op_ns)
        if k.startswith(name_prefix)
    ]

for _op in _overloads(aten.alias):
    @implements([_op])
    def _alias_noop(func, types, args, kwargs):
        self = args[0]
        return self.detach()

If you want, I’ll apply the same discovery pattern (and comment) to access_subclass_inner_tensor and alias_copy for consistency.
