Add CSR (Compressed Sparse Row) layout for INT8 AQT weights. #2971
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2971
Note: Links to docs will display an error until the docs builds have been completed.
❌ 6 New Failures as of commit 6a1b942 with merge base eadead5.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "sparsity"
This needs some tests and benchmarks (I guess that's also dependent on the kernel) but it's pretty cool. Thanks for the work @agrawal-aka @choudhary-devang!
The only thing is I think that we are changing the AQT internals to move away from Layout, do you have thoughts on that @jerryzh168?
torchao/dtypes/uintx/csr_layout.py
Outdated
x2d = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).to(torch.int32)

# For now we don't have any kernel for int8 spmm computation on ARM, so we upscale to fp32 for computation.
# Once we integrate the kernel into torch, we can use that kernel instead of torch.mm().
Nice! Is there a PR open in pytorch for the ARM kernel?
Also might be a good idea to throw a warning if we use the upscaled fp32 mm.
Nice! Is there a PR open in pytorch for the ARM kernel?
We are currently working on it and will raise it sometime in the future; meanwhile, this change enables the API so that the kernel can be patched in later.
Also might be a good idea to throw a warning if we use the upscaled fp32 mm.
sure, we will add the warning as suggested.
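A minimal sketch of what that warning plus the fp32 fallback could look like, assuming the dense torch.mm path from the diff stays in place; the helper name, message text, and operand ordering below are illustrative, not the PR's actual code:

import warnings

import torch


def _csr_int8_mm_fallback(x_int8_2d: torch.Tensor, w_int8_csr: torch.Tensor) -> torch.Tensor:
    # Illustrative fallback: no INT8 SpMM kernel is wired up yet, so warn and
    # compute in fp32 for correctness.
    warnings.warn(
        "CSRLayout: no INT8 SpMM kernel available on this platform; "
        "upscaling to fp32 and using torch.mm, which will be slower.",
        UserWarning,
    )
    # torch.mm accepts a sparse CSR tensor as its first operand, so the weight
    # stays sparse and only the dtype changes for the computation.
    w_fp32 = w_int8_csr.to(torch.float32)
    x_fp32 = x_int8_2d.to(torch.float32)
    return torch.mm(w_fp32, x_fp32.t()).t()  # (W @ X^T)^T == X @ W^T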
torchao/dtypes/uintx/csr_layout.py
Outdated
# For now we don't have any kernel for int8 spmm computation on ARM, so we upscale to fp32 for computation.
# Once we integrate the kernel into torch, we can use that kernel instead of torch.mm().
y_int32 = torch.mm(
nit: I think the var names here could be a bit clearer - why is this y_int32? torch.mm should output float32 right?
sure, we will change the var name accordingly
torchao/dtypes/uintx/csr_layout.py
Outdated
    target = self.target_sparsity
else:
    # 2. fall back to env var or 0.9 default
    target = float(os.getenv("TORCHAO_CSR_TARGET_SPARSITY", "0.9"))
Do we need an ENV variable for this? I think maybe it would be better to just pass this in the Config.
sure, we can remove the ENV variable as suggested.
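For illustration, a sketch of how the env-var fallback could collapse into the layout/config itself, assuming CSRLayout stays a dataclass subclassing torchao's Layout base; the body of pre_process is elided:

from dataclasses import dataclass

import torch
from torchao.dtypes.utils import Layout


@dataclass(frozen=True)
class CSRLayout(Layout):
    # Per-layer sparsity is carried by the config, so the
    # TORCHAO_CSR_TARGET_SPARSITY env-var lookup can be dropped.
    target_sparsity: float = 0.9

    def pre_process(self, weight: torch.Tensor) -> torch.Tensor:
        target = self.target_sparsity  # was: float(os.getenv("TORCHAO_CSR_TARGET_SPARSITY", "0.9"))
        # ... magnitude-prune `weight` to `target`, then hand off to CSR packing ...
        return weight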
Yeah, we are moving away from AQT. We have updated the official docs: https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html. The new code is located in https://github.com/pytorch/ao/tree/main/torchao/quantization/quantize_/workflows; an example sparse tensor for int4: https://github.com/pytorch/ao/blob/main/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py
shape: Tuple[int, int],
block_size: Tuple[int, int],
mapping: MappingType,
a8_mode: str = "int8_sym_per_token",  # "noop" | "int8_sym_per_token" | "int8_asym_per_token"
If this is optional, please move a8_mode to optional_tensor_attribute_names for the class, like this (Line 310 in 3d48174):
optional_tensor_attribute_names = ["optional_attr"]
sure
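For reference, a sketch of the pattern the reviewer is pointing at, mirroring the linked line; the class name and the other attribute lists here are illustrative, not the PR's actual definitions:

from torchao.utils import TorchAOBaseTensor


class Int8CsrTensor(TorchAOBaseTensor):  # illustrative name
    tensor_data_names = ["qdata", "scale", "zero_point"]
    tensor_attribute_names = ["block_size", "mapping", "shape"]
    # a8_mode has a default, so it belongs in the optional attribute list
    optional_tensor_attribute_names = ["a8_mode"]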
# ---- alias: return self unchanged (avoid touching strides) ----
for _name in ("default", "default_1", "default_2", "default_3"):
what are these?
CI failures which I am facing:
FAILED test/integration/test_integration.py::TestExport::test_export_06 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_07 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_08 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_09 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_10 - RuntimeError: Sparse CSR tensors do not have strides
FAILED test/integration/test_integration.py::TestExport::test_export_11 - RuntimeError: Sparse CSR tensors do not have strides
===== 9 failed, 144 passed, 219 skipped, 110 warnings in 367.13s (0:06:07) =====
Those names (default, default_1, …) are the codegen'd overload entry points that PyTorch emits for some ops across versions (2.3 → nightlies). On a given install you might see only default, but on others there can be additional default_* variants. I iterated over them to keep the override stable across minor/patch releases and CI images, instead of pinning to just one symbol and risking a missed dispatch.
That said, the intent here is simple: make aten.alias(*) a no-op for our wrapper so we don't touch strides (CSR has none), which fixes the export tests.
I can make this clearer and safer by replacing the hardcoded tuple with a small helper that discovers the available overloads dynamically:
# Register against all available alias overloads (different PyTorch builds
# may expose default, default_1, ...). We keep alias a no-op to avoid
# any view/stride logic on a CSR-backed wrapper.
def _overloads(op_packet, name_prefix="default"):
    return [
        getattr(op_packet, name) for name in dir(op_packet)
        if name.startswith(name_prefix)
    ]

for _op in _overloads(aten.alias):
    @implements([_op])
    def _alias_noop(func, types, args, kwargs):
        self = args[0]
        # Returning a detached handle preserves alias semantics without touching
        # strides, which sparse CSR tensors do not have.
        return self.detach()
If you want, I'll apply the same discovery pattern (and comment) to access_subclass_inner_tensor and alias_copy for consistency.
Introduces a CSR layout for affine-quantized INT8 weights so linear layers can run with sparse weights in TorchAO.
Adds per-layer sparsity control via CSRLayout(target_sparsity=...) (falls back to env var TORCHAO_CSR_TARGET_SPARSITY or 0.9). This makes experimentation deterministic and tunable per layer.
Key changes
New CSRLayout and CSR_AQTTensorImpl:
Stores weights as torch.sparse_csr_tensor(int8) + scale/zero-point.
pre_process() does magnitude pruning to the requested sparsity, then packs to CSR (see the sketch after this list).
Minimal dispatch: intercepts aten.linear (and addmm/mm) when the weight is CSR-packed and runs a CSR path.
Reference CPU fallback uses sparse matmul + dequantization for correctness; can be replaced with vendor INT8 SpMM later.
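A minimal sketch of the prune-and-pack step mentioned above, assuming plain magnitude pruning over the already-quantized int8 values; the helper name is illustrative:

import torch


def prune_and_pack_csr(w_int8: torch.Tensor, target_sparsity: float = 0.9) -> torch.Tensor:
    # Magnitude-prune a 2D int8 weight to roughly `target_sparsity` zeros, then pack to CSR.
    assert w_int8.dim() == 2
    num_prune = int(w_int8.numel() * target_sparsity)
    if num_prune > 0:
        magnitudes = w_int8.abs().to(torch.float32).flatten()
        threshold = magnitudes.kthvalue(num_prune).values  # k-th smallest magnitude
        w_int8 = torch.where(
            w_int8.abs().to(torch.float32) <= threshold,
            torch.zeros_like(w_int8),
            w_int8,
        )
    return w_int8.to_sparse_csr()

The resulting int8 CSR tensor would then be stored alongside scale/zero-point in CSR_AQTTensorImpl, as described in the first bullet.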
Usage
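No example is given here, so below is a hedged sketch of how the layout might be wired up, assuming CSRLayout plugs into the layout argument of Int8DynamicActivationInt8WeightConfig the way other layouts do; the import path and config entry point are assumptions, not confirmed by this PR:

import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
from torchao.dtypes import CSRLayout  # assumed export location for the new layout

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval()

# Per-layer sparsity is controlled through the layout (defaults to 0.9).
quantize_(
    model,
    Int8DynamicActivationInt8WeightConfig(layout=CSRLayout(target_sparsity=0.9)),
)

out = model(torch.randn(1, 1024))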
This PR is a joint contribution by:
@agrawal-aka
@choudhary-devang
cc: @jerryzh168, @jcaip