Skip to content

Commit

Permalink
Implement sparsity as a AQT Layout (#498)
Browse files Browse the repository at this point in the history
Summary:

This PR adds in sparsity as an AQTLayout, previously it was implemented using the QuantizedLinearBase subclass that will be deprecated shortly. 

I also added renamed `sparsify` to `sparsify_` and added in a `semi_sparse_weight()` function to be in line with our other APIs. 

The main code changes are in `torchao/dtypes/affine_quantized_tensor.py`, for the semi-structured cusparselt representation, we can reuse a lot of the existing PlainLayout implementation, since the compressed representation is stored in a single tensor like `int_data`. 

Test Plan:
```
python test/sparsity/test_sparse_api
```
  • Loading branch information
jcaip authored and Hanxian97 committed Jul 29, 2024
1 parent ec99f95 commit 4303ef2
Show file tree
Hide file tree
Showing 12 changed files with 168 additions and 383 deletions.
12 changes: 5 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,18 @@ And a quick crash course on inference quantization to help parse the above table

Sparsifying your model is also a 1 liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute bound models like SAM, specifically the MLP layers.
```python
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured
from torchao.sparsity import sparsify, semi_sparse_weight()

m = sparsify(m, to_sparse_semi_structured)
m = sparsify_(m, semi_sparse_weight())
```
Sparsity can also be composed with int8 dynamic quantization for further speedups:

```python
from torchao.sparsity import sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.sparsity import sparsify, int8_dynamic_activation_int8_semi_sparse_weight

m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight())
```
We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration.
We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + semi sparse (2:4) sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration.
We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**.

The following benchmarks were ran for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32 and `bfloat16` dtype, with `torch.compile="max_autotune"`:
Expand Down
1 change: 0 additions & 1 deletion scripts/sam/benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse

44 changes: 16 additions & 28 deletions scripts/sam/eval_combo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import time
import resource

from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_semi_sparse_weight, semi_sparse_weight
from torchao.utils import unwrap_tensor_subclass

torch._dynamo.config.cache_size_limit = 50000

def unbind_jagged(device, data, sizes, offsets):
Expand Down Expand Up @@ -279,30 +283,17 @@ def run(
block.attn.use_rel_pos = use_rel_pos

if compress == "int8_dynamic_quant":
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass
quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
elif compress == "sparse_mlp_only":
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
sparsify_(predictor.model.image_encoder, semi_sparse_weight(), filter_fn=mlp_only)
elif compress == "sparse":
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
apply_fake_sparsity(predictor.model.image_encoder)
predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
sparsify_(predictor.model.image_encoder, semi_sparse_weight())
elif compress == "int8_dynamic_quant_sparse":
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
SparseSemiStructuredTensor._FORCE_CUTLASS = False
from torchao.sparsity import sparsify, apply_fake_sparsity
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass

def attn_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'attn' in name
def mlp_lin1_only(mod, name):
Expand All @@ -316,20 +307,17 @@ def mlp_only(mod, name):
apply_fake_sparsity(predictor.model.image_encoder,
filter_fn=mlp_only)

quantize_(
predictor.model.image_encoder,
int8_dynamic_activation_int8_weight(),
attn_only
)
quantize_(predictor.model.image_encoder,
int8_dynamic_activation_int8_weight(),
attn_only)
quantize_(predictor.model.image_encoder,
int8_dynamic_activation_int8_semi_sparse_weight(),
mlp_lin1_only)
sparsify_(predictor.model.image_encoder,
semi_sparse_weight(),
mlp_lin2_only)
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
int8_dynamic_activation_int8_2x4_sparse_weight(),
mlp_lin1_only, prune=False)

predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
to_sparse_semi_structured,
mlp_lin2_only, prune=False)
else:
assert compress is None, f"Unsupported compress mode {compress}"

Expand Down Expand Up @@ -413,6 +401,6 @@ def mlp_only(mod, name):
vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile,
use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path]))
f.write(vals+"\n")

if __name__ == '__main__':
fire.Fire(run)
10 changes: 5 additions & 5 deletions scripts/sam/results.csv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15172,18,22.533401716616083,44.37856354651513,0.5812715827356921,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
27 changes: 16 additions & 11 deletions test/sparsity/test_sparse_api.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
import copy
import logging
import unittest

import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

from torchao.sparsity import apply_fake_sparsity, sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.sparsity import (
apply_fake_sparsity,
sparsify_,
int8_dynamic_activation_int8_semi_sparse_weight,
semi_sparse_weight,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_get_subclass_inserter,
_is_linear,
int8_dynamic_activation_int8_weight,
quantize_,
)
from torchao.utils import TORCH_VERSION_AFTER_2_3
from torchao.utils import TORCH_VERSION_AFTER_2_3, unwrap_tensor_subclass
from torch.testing._internal.common_utils import TestCase


Expand All @@ -38,12 +44,11 @@ def test_sparse(self):
apply_fake_sparsity(model)
dense_result = model(input)

model = sparsify(model, to_sparse_semi_structured)
sparsify_(model, semi_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)


class TestQuantSemiSparse(TestCase):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
Expand All @@ -58,15 +63,15 @@ def test_quant_semi_sparse(self):
.half()
.cuda()
)

apply_fake_sparsity(model)
dense_result = model(input)
model_copy = copy.deepcopy(model)
quantize_(model_copy, int8_dynamic_activation_int8_weight())
dense_result = model_copy(input)

sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)

assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)

if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
to_affine_quantized_static,
LayoutType,
PlainLayoutType,
SemiSparseLayoutType,
TensorCoreTiledLayoutType,
)

Expand All @@ -19,5 +20,6 @@
"to_affine_quantized_static",
"LayoutType",
"PlainLayoutType",
"SemiSparseLayoutType",
"TensorCoreTiledLayoutType",
]
77 changes: 77 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@
class PlainLayoutType(LayoutType):
pass

@dataclass(frozen=True)
class SemiSparseLayoutType(LayoutType):

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
# prune to 2:4 if not already
temp = input.detach()
pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
return temp


@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
inner_k_tiles: int = 8
Expand Down Expand Up @@ -473,6 +484,47 @@ def from_plain(
assert isinstance(layout_type, PlainLayoutType)
return cls(int_data, scale, zero_point, layout_type)

@register_layout_cls(SemiSparseLayoutType)
class SemiSparseAQTLayout(PlainAQTLayout):
"""
Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor
"""
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

raise NotImplementedError(
f"SparseAQTLayout dispatch: attempting to run {func}, this is not supported"
)

def get_plain(self):
# Currently we don't have cuSPARSELt expansion routines, so we matmul by
# the identity matrix to get the original dense matrix. This is slow though.
cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0])
int_data_expanded = torch._cslt_sparse_mm(self.int_data,
torch.eye(cols,
dtype=self.int_data.dtype,
device=self.int_data.device).t())
return int_data_expanded, self.scale, self.zero_point

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
assert isinstance(layout_type, SemiSparseLayoutType)
int_data_compressed = torch._cslt_compress(int_data)
return cls(int_data_compressed, scale, zero_point, layout_type)


@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
"""
Expand Down Expand Up @@ -669,6 +721,31 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
if bias is not None:
y += bias
return y
# handle int8 dynamic_quant + semi_structured_sparse
elif(
is_cuda and
input_is_int8 and
input_tensor.dtype == weight_qtensor.dtype and
isinstance(input_tensor.layout_type, PlainLayoutType) and
isinstance(weight_qtensor.layout_type, SemiSparseLayoutType)
):
x_vals_int8 = input_tensor.layout_tensor.int_data
x_scales = input_tensor.layout_tensor.scale
w_vals_int8 = weight_qtensor.layout_tensor.int_data
w_scales = weight_qtensor.layout_tensor.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
# we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
).t()
y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
*x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
)
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
else:
input_tensor = input_tensor.dequantize()

Expand Down
1 change: 1 addition & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
]
26 changes: 22 additions & 4 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
come along with it and because that is how we access the intended quantized
and mixed GEMM kernels
"""

from functools import partial
import torch
import torchao
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional

from torchao.dtypes import PlainLayoutType
from torchao.utils import (
TORCH_VERSION_AFTER_2_4,
unwrap_tensor_subclass,
Expand Down Expand Up @@ -57,6 +58,7 @@
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
]
Expand Down Expand Up @@ -410,7 +412,8 @@ def apply_int8wo_quant(weight):

return _get_linear_subclass_inserter(apply_int8wo_quant)

def int8_dynamic_activation_int8_weight():

def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
"""
Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
quantization to linear layers
Expand All @@ -432,16 +435,31 @@ def get_weight_block_size(x):
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

block_size = get_weight_block_size(weight)
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
weight = to_linear_act_quantized(weight, input_quant_func)
return weight

return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)


def int8_dynamic_activation_int8_semi_sparse_weight():
"""
Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
quantization + 2:4 sparsity to linear layers.
"""
from torchao.dtypes import SemiSparseLayoutType
return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
11 changes: 9 additions & 2 deletions torchao/sparsity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@

from .wanda import WandaSparsifier # noqa: F403
from .utils import PerChannelNormObserver # noqa: F403
from .sparse_api import apply_fake_sparsity, sparsify
from .sparse_api import (
apply_fake_sparsity,
sparsify_,
semi_sparse_weight,
int8_dynamic_activation_int8_semi_sparse_weight
)

__all__ = [
"WandaSparsifier",
"PerChannelNormObserver",
"apply_fake_sparsity",
"sparsify"
"sparsify_"
"semi_sparse_weight",
"int8_dynamic_activation_int8_semi_sparse_weight"
]
Loading

0 comments on commit 4303ef2

Please sign in to comment.