Add sparsify API to torchao (#473)
* Add sparsify API to torchao

* fix typo
jcaip authored Jul 5, 2024
1 parent a2e8e2a commit a35a1cd
Showing 6 changed files with 104 additions and 34 deletions.
38 changes: 27 additions & 11 deletions README.md
@@ -47,25 +47,41 @@ For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3b

And a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization, the weight is stored in int4 but upcast to a larger dtype like fp16 at runtime, so the "int4" matrix multiplication is actually computed as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher precision formats like bf16 to lower precision formats such as int8. This process, when supported by hardware, allows for direct low-precision computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers, so we typically set a group size that applies a scale factor per group of 64 elements in the case of `int4wo64`.
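
To make the weight-only vs. dynamic-quantization distinction concrete, here is a hedged toy sketch of the two compute paths. It uses int8 for simplicity, and the tensors, scales, and helper functions are illustrative stand-ins rather than torchao's actual kernels:

```python
import torch
import torch.nn.functional as F

def weight_only_linear(x: torch.Tensor, w_int8: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
    # Weight-only (wo): dequantize/upcast the stored low-precision weight,
    # then run an ordinary matmul in the activation's dtype.
    w = (w_int8.to(torch.float32) * w_scale[:, None]).to(x.dtype)
    return F.linear(x, w)

def dynamic_quant_linear(x: torch.Tensor, w_int8: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
    # Dynamic quantization (dq): quantize the activation on the fly to int8,
    # do the matmul in low precision (emulated in float here), then rescale.
    x_scale = x.abs().amax(dim=-1, keepdim=True) / 127
    x_int8 = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
    acc = x_int8.float() @ w_int8.float().t()  # stand-in for a real int8 matmul kernel
    return (acc * x_scale * w_scale).to(x.dtype)

x = torch.randn(8, 64, dtype=torch.bfloat16)
w_int8 = torch.randint(-128, 127, (32, 64), dtype=torch.int8)  # one int8 row per output channel
w_scale = torch.full((32,), 0.01)                              # one scale per output channel
y_wo = weight_only_linear(x, w_int8, w_scale)
y_dq = dynamic_quant_linear(x, w_int8, w_scale)
# Group-wise schemes like int4wo64 additionally split each row into groups of 64
# elements and store one scale per group to tame outliers.
```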

Sparsifying your model is also a one-liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute-bound models like SAM, specifically the MLP layers.
```python
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured

m = sparsify(m, to_sparse_semi_structured)
```

Sparsity can also be composed with int8 dynamic quantization for further speedups:

```python
from torchao.sparsity import sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight

m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
```
We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1, and 2:4 sparsity to mlp layer 2 yielded the best configuration.
We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIoU)**.
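
For reference, a minimal sketch of how that recipe can be wired up with the `quantize` and `sparsify` APIs. The toy `MLP`/`Block` modules, the factory-call form of `int8_dynamic_activation_int8_weight()`, and the filter names are assumptions that mirror `scripts/sam/eval_combo.py` (shown further down in this commit), not the exact script:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.sparse import to_sparse_semi_structured
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
from torchao.sparsity import sparsify, apply_fake_sparsity
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight

# Toy stand-in for the ViT image encoder; the 'attn' / 'mlp.lin1' / 'mlp.lin2'
# naming mirrors the blocks in segment-anything-fast.
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.lin1 = nn.Linear(dim, 4 * dim)
        self.lin2 = nn.Linear(4 * dim, dim)
    def forward(self, x):
        return self.lin2(F.gelu(self.lin1(x)))

class Block(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.mlp = MLP(dim)
    def forward(self, x):
        return x + self.mlp(self.attn(x))

model = nn.Sequential(Block(), Block()).cuda().half()  # the 2:4 kernels assume a recent CUDA GPU

def attn_only(mod, name):
    return isinstance(mod, nn.Linear) and 'attn' in name

def mlp_lin1_only(mod, name):
    return isinstance(mod, nn.Linear) and 'lin1' in name

def mlp_lin2_only(mod, name):
    return isinstance(mod, nn.Linear) and 'lin2' in name

def mlp_only(mod, name):
    return isinstance(mod, nn.Linear) and 'mlp' in name

# Prune the MLP weights to a 2:4 pattern first, then apply the per-layer recipe:
# int8 dynamic quant on attention, int8 dynamic quant + 2:4 sparsity on mlp.lin1,
# and 2:4 sparsity alone on mlp.lin2.
apply_fake_sparsity(model, filter_fn=mlp_only)
model = quantize(model, int8_dynamic_activation_int8_weight(), attn_only)
model = sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight(), mlp_lin1_only)
model = sparsify(model, to_sparse_semi_structured, mlp_lin2_only)
```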

The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32 and `bfloat16` dtype, with `torch.compile="max_autotune"`:

| Model Type | Technique                                                                                            | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
|------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
| ViT-h      | baseline (bfloat16, max-autotune)                                                                      | 22.75 | 15172        | 0.5811              |                  |                   |
|            | int8 dynamic quant (attn + mlp)                                                                        | 24.91 | 15154        | 0.5822              | **1.09x**        | **100.19%**       |
|            | 2:4 sparsity (mlp only)                                                                                | 24.81 | 15632        | 0.5672              | **1.10x**        | **97.61%**        |
|            | 2:4 sparsity (attn + mlp)                                                                              | 24.30 | 13429        | 0.5306              | **1.07x**        | **91.31%**        |
|            | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2)   | 26.46 | 14865        | 0.5668              | **1.16x**        | **97.54%**        |

To reproduce our benchmarks please follow these [instructions](/scripts/sam/README.md).

The relative speedup is measured purely across the image encoder (ViT) of the model, where we apply our model optimizations.
#### With intrusive code changes

In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch, i.e. with no custom C++/CUDA, achieving SOTA inference performance at the time. These rewrites involve more intrusive code changes.

* 8x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) (9.5x with int8 dynamic quantization + 2:4 sparsity)
* 10x speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
* 3x speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)

30 changes: 20 additions & 10 deletions scripts/sam/eval_combo.py
@@ -286,14 +286,20 @@ def run(
    elif compress == "sparse_mlp_only":
        def mlp_only(mod, name):
            return isinstance(mod, torch.nn.Linear) and 'mlp' in name
        from torchao.sparsity import sparsify, apply_fake_sparsity
        from torch.sparse import to_sparse_semi_structured
        # prune the MLP linear weights to a 2:4 pattern, then swap them for the
        # semi-structured sparse tensor subclass
        apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
    elif compress == "sparse":
        from torchao.sparsity import sparsify, apply_fake_sparsity
        from torch.sparse import to_sparse_semi_structured
        # prune all linear weights, then convert them to the sparse subclass
        apply_fake_sparsity(predictor.model.image_encoder)
        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
    elif compress == "int8_dynamic_quant_sparse":
        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
        SparseSemiStructuredTensor._FORCE_CUTLASS = False
        from torchao.sparsity import sparsify, apply_fake_sparsity
        from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
        from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
        from torchao.utils import unwrap_tensor_subclass

@@ -306,6 +312,7 @@ def mlp_lin2_only(mod, name):
        def mlp_only(mod, name):
            return isinstance(mod, torch.nn.Linear) and 'mlp' in name

        # apply sparsify first to set qparams
        apply_fake_sparsity(predictor.model.image_encoder,
                            filter_fn=mlp_only)

@@ -314,10 +321,13 @@ def mlp_only(mod, name):
                                                 attn_only)
        predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

        # int8 dynamic quant + 2:4 sparsity for the first MLP linear
        predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
                                                 int8_dynamic_activation_int8_2x4_sparse_weight(),
                                                 mlp_lin1_only)
        # 2:4 sparsity only for the second MLP linear
        predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
                                                 to_sparse_semi_structured,
                                                 mlp_lin2_only)
    else:
        assert compress is None, f"Unsupported compress mode {compress}"

9 changes: 5 additions & 4 deletions test/sparsity/test_sparse_api.py
@@ -3,9 +3,10 @@

import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

from torchao.sparsity import apply_fake_sparsity, sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
@@ -37,7 +38,7 @@ def test_sparse(self):
        apply_fake_sparsity(model)
        dense_result = model(input)

        model = sparsify(model, to_sparse_semi_structured)
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@@ -61,7 +62,7 @@ def test_quant_semi_sparse(self):
        apply_fake_sparsity(model)
        dense_result = model(input)

        sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
4 changes: 2 additions & 2 deletions torchao/sparsity/__init__.py
@@ -6,11 +6,11 @@

from .wanda import WandaSparsifier  # noqa: F403
from .utils import PerChannelNormObserver  # noqa: F403
from .sparse_api import apply_fake_sparsity, sparsify

__all__ = [
    "WandaSparsifier",
    "PerChannelNormObserver",
    "apply_fake_sparsity",
    "sparsify"
]
3 changes: 3 additions & 0 deletions torchao/sparsity/prototype/dynamic_quant_sparse.py
@@ -309,3 +309,6 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
            input_float.shape,
            dtype=input_float.dtype,
        )

def int8_dynamic_activation_int8_2x4_sparse_weight():
    return Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float
54 changes: 47 additions & 7 deletions torchao/sparsity/sparse_api.py
@@ -1,7 +1,13 @@
from typing import Callable, Optional

import torch
from torch.ao.pruning import WeightNormSparsifier
from torch.sparse import to_sparse_semi_structured
from torchao.quantization.quant_api import (
    _is_linear,
    _replace_with_custom_fn_if_matches_filter,
    _get_linear_subclass_inserter,
)

# Sparsity helper functions
def apply_fake_sparsity(model, **kwargs):
@@ -24,10 +30,44 @@ def apply_fake_sparsity(model, **kwargs):
    sparsifier.squash_mask()


def sparsify(model: torch.nn.Module,
             apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
             filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None) -> torch.nn.Module:
    """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
    This function is essentially the same as quantize, but for sparsity subclasses.

    Currently, we support two options for sparsity:
        - semi-structured (2:4) sparsity with `to_sparse_semi_structured`
        - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_2x4_sparse_weight`, which is also available via the quantize API

    Args:
        model (torch.nn.Module): input model
        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module

    Example::
        import torch
        import torch.nn as nn
        from torchao.sparsity import sparsify

        def filter_fn(module: nn.Module, fqn: str) -> bool:
            return isinstance(module, nn.Linear)

        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

        # for 2:4 sparsity
        from torch.sparse import to_sparse_semi_structured
        m = sparsify(m, to_sparse_semi_structured, filter_fn)

        # for int8 dynamic quantization + 2:4 sparsity
        from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight
        m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn)
    """
    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_linear_subclass_inserter(apply_tensor_subclass),
        _is_linear if filter_fn is None else filter_fn,
    )

    return model
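
Taken together, the two `quant_api` helpers amount to the same module walk that the removed `apply_sparse_semi_structured` did by hand. A rough sketch of the assumed net effect (not the helpers' actual implementation):

```python
import torch

def sparsify_sketch(model, apply_tensor_subclass, filter_fn):
    # Walk the module tree and, for every module the filter accepts, swap its
    # weight for the tensor-subclass version produced by apply_tensor_subclass.
    for name, mod in model.named_modules():
        if filter_fn(mod, name):
            mod.weight = torch.nn.Parameter(apply_tensor_subclass(mod.weight))
    return model
```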
