diff --git a/README.md b/README.md
index e99f565900..f0cb7e2c00 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,20 @@ And a quick crash course on inference quantization to help parse the above table
In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch (that is, no custom C++/CUDA) and achieved SOTA inference performance at the time. These rewrites involve more intrusive code changes.
-* 8x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai)
+* 9.5x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) compared to vanilla [sam](https://github.com/facebookresearch/segment-anything).
+* An additional 1.16x speedup when composing int8 quantization with 2:4 sparsity, measured against an already-accelerated baseline that uses the `bfloat16` dtype and `torch.compile="max-autotune"`.
+
+| Model Type | Technique | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
+|------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
+| ViT-h | sam (float32, eager) | 2.78 | 28806 | 0.58 | baseline | baseline |
+| | sam (bfloat16, eager) | 14.85 | 14424 | 0.58 | **5.34x** | **100%** |
+| | sam-fast (bfloat16, max-autotune) | 22.75 | 15172 | 0.58 | **8.18x** | **100%** |
+| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.58 | **8.96x** | **100%** |
+| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.57 | **8.92x** | **98%** |
+|            | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.57 | **9.52x** | **98%** |
+
+The relative speedup is measured over the image encoder (ViT) of the model only, which is where we apply our optimizations. Benchmarks were run on an NVIDIA A100-80GB GPU with batch_size=32.
+
* 10x speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
* 3x speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)
diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py
deleted file mode 100644
index 3a5d817602..0000000000
--- a/benchmarks/benchmark_sam.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import argparse
-from itertools import product
-
-import pandas as pd
-# to install segment-anything-fast you can run:
-# pip install git+https://github.com/pytorch-labs/segment-anything-fast.git
-from segment_anything_fast import sam_model_registry
-import torch
-from torch.utils.benchmark import Timer
-from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
-from torchao.quantization.quant_api import (
- _replace_with_custom_fn_if_matches_filter,
- _get_subclass_inserter,
- _is_linear,
- QuantizedLinearWeightBase,
- Int8DynamicallyQuantizedLinearWeight,
-)
-from torchao.quantization import change_linear_weights_to_int8_dqtensors
-from torchao.sparsity import (
- apply_sparse_semi_structured,
- apply_fake_sparsity,
-)
-from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
-from tqdm import tqdm
-
-sam_checkpoint_base_path = "/home/jessecai/local/MODELS"
-model_type = 'vit_h'
-model_name = 'sam_vit_h_4b8939.pth'
-checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
-
-torch._inductor.config.epilogue_fusion = True
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.coordinate_descent_check_all_directions = True
-torch._inductor.config.force_fuse_int_mm_with_mul = True
-
-@torch.no_grad()
-def benchmark(f, *args, **kwargs):
- for _ in range(3):
- f(*args, **kwargs)
- torch.cuda.synchronize()
-
- torch.cuda.reset_peak_memory_stats()
- t0 = Timer(
- stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
- )
- res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
- return {'time':res.median * 1e3, 'memory': torch.cuda.max_memory_allocated()/1e9}
-
-def get_sam_model(only_one_block=False, batchsize=1):
- sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
- model = sam.image_encoder.eval()
- image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')
-
- # code to use just a single block of the model
- if only_one_block:
- model = model.blocks[0]
- image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
- return model, image
-
-def qkv_only(mod, name):
- return isinstance(mod, torch.nn.Linear) and 'qkv' in name
-
-def proj_only(mod, name):
- return isinstance(mod, torch.nn.Linear) and 'proj' in name
-
-def lin1_only(mod, name):
- return isinstance(mod, torch.nn.Linear) and 'lin1' in name
-
-def lin2_only(mod, name):
- return isinstance(mod, torch.nn.Linear) and 'lin2' in name
-
-SUBCLASSES = {
- "quant" : Int8DynamicallyQuantizedLinearWeight,
- "quant+sparse (cutlass)" : Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight,
- "quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight,
- "sparse (cutlass)" : SparseSemiStructuredTensorCUTLASS,
- "sparse (cusparselt)" : SparseSemiStructuredTensorCUSPARSELT,
-}
-
-def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None):
- res = {
- "block_only": block_only,
- "batchsize": batchsize,
- "dtype": dtype,
- "compile": compile,
- "qkv" : qkv,
- "proj": proj,
- "lin1": lin1,
- "lin2": lin2,
- }
- with torch.no_grad():
- model, image = get_sam_model(block_only, batchsize)
- model = model.to(dtype)
- image = image.to(dtype)
-
- # 2:4 prune model
- apply_fake_sparsity(model)
- option_and_filter_fn = zip([qkv, proj, lin1, lin2], [qkv_only, proj_only, lin1_only, lin2_only])
-
- for option, filter_fn in option_and_filter_fn:
- subclass = SUBCLASSES.get(option, None)
- if subclass and issubclass(subclass, SparseSemiStructuredTensor):
- # replace with to_sparse_semi_structured
- for name, mod in model.named_modules():
- if filter_fn(mod, name):
- mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight))
- elif subclass and issubclass(subclass, QuantizedLinearWeightBase):
- _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn)
-
- if compile:
- model = torch.compile(model, mode='max-autotune')
-
- res.update(benchmark(model, image))
- res["img/s"] = 1 / (res['time'] / 1000 / res['batchsize'])
- return res
-
-if __name__ == "__main__":
- print("BENCHMARKING")
- parser = argparse.ArgumentParser(description='Process some integers.')
- parser.add_argument('--eager', action='store_true', help='enable/disable torch.compile')
- args = parser.parse_args()
- # ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
- ALL_RUNS = [
- run_once(compile=not args.eager),
- run_once(compile=not args.eager, lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
- run_once(compile=not args.eager, lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
- run_once(compile=not args.eager, qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
- run_once(compile=not args.eager, qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
- # run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
- # run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
- # run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
- # run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
- # run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
- ]
- df = pd.DataFrame(ALL_RUNS)
- df.to_csv("sam_benchmark_results.csv")
- print(df)
diff --git a/scripts/sam/.gitignore b/scripts/sam/.gitignore
new file mode 100644
index 0000000000..3e7cc59d91
--- /dev/null
+++ b/scripts/sam/.gitignore
@@ -0,0 +1,3 @@
+tmp
+checkpoints
+datasets
diff --git a/scripts/sam/README.md b/scripts/sam/README.md
new file mode 100644
index 0000000000..426d7fe6a8
--- /dev/null
+++ b/scripts/sam/README.md
@@ -0,0 +1,21 @@
+# benchmarking instructions:
+
+Set up your environment with:
+```
+conda create -n saf-ao python=3.10
+conda activate saf-ao
+pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
+pip3 install git+https://github.com/pytorch-labs/segment-anything-fast.git
+pip3 install tqdm fire pandas
+cd ../.. && python setup.py install
+```
+
+Then download data and models by running
+```
+sh setup.sh
+```
+
+Finally, you can run benchmarks with
+```
+sh benchmark.sh
+```
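+
+`benchmark.sh` sweeps all of the configurations reported in the README tables. To reproduce a single row you can invoke `eval_combo.py` directly; for example, the int8 dynamic quant + 2:4 sparsity configuration uses (flags copied from `benchmark.sh`):
+```
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
+```
+
+Results are appended to `results.csv` in the working directory.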
diff --git a/scripts/sam/benchmark.sh b/scripts/sam/benchmark.sh
new file mode 100755
index 0000000000..5c1262f9cc
--- /dev/null
+++ b/scripts/sam/benchmark.sh
@@ -0,0 +1,11 @@
+# baseline
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
+# int8 dynamic quant (all)
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
+# 2:4 sparsity (mlp only)
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
+# 2:4 sparsity (attn + mlp)
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
+# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
+
diff --git a/scripts/sam/data.py b/scripts/sam/data.py
new file mode 100644
index 0000000000..20f7632652
--- /dev/null
+++ b/scripts/sam/data.py
@@ -0,0 +1,297 @@
+import torch
+import diskcache
+from pycocotools.coco import COCO
+import numpy as np
+from scipy import ndimage
+import skimage.io as io
+import skimage.color as color
+
+
+def _get_center_point(mask, ann_id, cache):
+ """
+ This is a rudimentary version of https://arxiv.org/pdf/2304.02643.pdf,
+ section D.1.Point Sampling
+
+ From the paper: "The first point is chosen deterministically as the point
+ farthest from the object boundary."
+
+ The code below is an approximation of this.
+
+ First, we try to calculate the center of mass. If it's inside the mask, we
+ stop here.
+
+ The centroid may be outside of the mask for some mask shapes. In this case
+ we do a slow hack, specifically, we check for the
+    minimum of the maximum distance from the boundary in four directions
+ (up, right, down, left), and take the point with the maximum of these
+ minimums. Note: this is not performant for large masks.
+
+ Returns the center point in (x, y) format
+ """
+ if ann_id in cache:
+ return cache[ann_id]
+
+ # try the center of mass, keep it if it's inside the mask
+ com_y, com_x = ndimage.center_of_mass(mask)
+ com_y, com_x = int(round(com_y, 0)), int(round(com_x, 0))
+ if mask[com_y][com_x]:
+ cache[ann_id] = (com_x, com_y)
+ return (com_x, com_y)
+
+ # if center of mass didn't work, do the slow manual approximation
+
+ # up, right, down, left
+ # TODO(future): approximate better by adding more directions
+ distances_to_check_deg = [0, 90, 180, 270]
+
+ global_min_max_distance = float('-inf')
+ global_coords = None
+ # For now, terminate early to speed up the calculation as long as
+    # the point sample is good enough. This sacrifices the quality of point
+ # sampling for speed. In the future we can make this more accurate.
+ DISTANCE_GOOD_ENOUGH_THRESHOLD = 20
+
+ # Note: precalculating the bounding box could be somewhat
+    # helpful, but we checked the performance gain and it's not much,
+    # so we leave it out to keep the code simple.
+ # Note: tried binary search instead of incrementing by one to
+ # travel up/right/left/down, but that does not handle masks
+ # with all shapes properly (there could be multiple boundaries).
+ for row_idx in range(mask.shape[0]):
+ for col_idx in range(mask.shape[1]):
+ cur_point = mask[row_idx, col_idx]
+
+ # skip points inside bounding box but outside mask
+ if not cur_point:
+ continue
+
+ max_distances = []
+ for direction in distances_to_check_deg:
+ # TODO(future) binary search instead of brute forcing it if we
+ # need a speedup, with the cache it doesn't really matter though
+ if direction == 0:
+ # UP
+ cur_row_idx = row_idx
+
+ while cur_row_idx >= 0 and mask[cur_row_idx, col_idx]:
+ cur_row_idx = cur_row_idx - 1
+ cur_row_idx += 1
+ distance = row_idx - cur_row_idx
+ max_distances.append(distance)
+
+ elif direction == 90:
+ # RIGHT
+ cur_col_idx = col_idx
+
+ while cur_col_idx <= mask.shape[1] - 1 and \
+ mask[row_idx, cur_col_idx]:
+ cur_col_idx += 1
+ cur_col_idx -= 1
+ distance = cur_col_idx - col_idx
+ max_distances.append(distance)
+
+ elif direction == 180:
+ # DOWN
+ cur_row_idx = row_idx
+ while cur_row_idx <= mask.shape[0] - 1 and \
+ mask[cur_row_idx, col_idx]:
+ cur_row_idx = cur_row_idx + 1
+ cur_row_idx -= 1
+ distance = cur_row_idx - row_idx
+ max_distances.append(distance)
+
+ elif direction == 270:
+ # LEFT
+ cur_col_idx = col_idx
+ while cur_col_idx >= 0 and mask[row_idx, cur_col_idx]:
+ cur_col_idx -= 1
+ cur_col_idx += 1
+ distance = col_idx - cur_col_idx
+ max_distances.append(distance)
+
+ min_max_distance = min(max_distances)
+ if min_max_distance > global_min_max_distance:
+ global_min_max_distance = min_max_distance
+ global_coords = (col_idx, row_idx)
+ if global_min_max_distance >= DISTANCE_GOOD_ENOUGH_THRESHOLD:
+ break
+
+ cache[ann_id] = global_coords
+ return global_coords
+
+
+def build_datapoint(imgId,
+ coco,
+ pixel_mean,
+ pixel_std,
+ coco_root_dir,
+ coco_slice_name,
+ catIds,
+ cache,
+ predictor,
+ pad_input_image_batch):
+ img = coco.loadImgs(imgId)[0]
+
+ file_location = f'{coco_root_dir}/{coco_slice_name}/{img["file_name"]}'
+ I = io.imread(file_location)
+ if len(I.shape) == 2:
+ # some images, like img_id==61418, are grayscale
+ # convert to RGB to ensure the rest of the pipeline works
+ I = color.gray2rgb(I)
+
+ # load and display instance annotations
+ annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
+ anns = coco.loadAnns(annIds)
+
+ # approximate the center point of each mask
+ coords_list = []
+ gt_masks_list = []
+ for ann in anns:
+ ann_id = ann['id']
+ mask = coco.annToMask(ann)
+ gt_masks_list.append(torch.tensor(mask))
+ coords = _get_center_point(mask, ann_id, cache)
+ coords_list.append(coords)
+
+ image = I
+
+ # predictor_set_image begin
+ # Transform the image to the form expected by the model
+ input_image = predictor.transform.apply_image(image)
+ input_image_torch = torch.as_tensor(input_image)
+ input_image_torch = input_image_torch.permute(
+ 2, 0, 1).contiguous()[None, :, :, :]
+ predictor_input_size = input_image_torch.shape[-2:]
+
+ # Preprocess
+ x = input_image_torch
+ # Normalize colors
+ x = (x - pixel_mean) / pixel_std
+
+ if pad_input_image_batch:
+ # Pad
+ h, w = x.shape[-2:]
+ padh = predictor.model.image_encoder.img_size - h
+ padw = predictor.model.image_encoder.img_size - w
+ x = torch.nn.functional.pad(x, (0, padw, 0, padh))
+ else:
+ x = x.squeeze(0)
+
+ gt_masks_list = torch.stack(gt_masks_list) if len(gt_masks_list) else None
+ return image, coords_list, gt_masks_list, anns, x, predictor_input_size
+
+
+def build_data(coco_img_ids,
+ coco,
+ catIds,
+ coco_root_dir,
+ coco_slice_name,
+ point_sampling_cache_dir,
+ predictor,
+ use_half,
+ pad_input_image_batch):
+ cache = diskcache.Cache(point_sampling_cache_dir)
+ # make sure you clear the cache if you change the point sampling algorithm
+ # cache.clear()
+
+ pixel_mean = predictor.model.pixel_mean.cpu()
+ pixel_std = predictor.model.pixel_std.cpu()
+
+    def build_batch(indices):
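+        # batch is a list of 11 parallel slots:
+        #   0: input image tensors, 1: flattened point coords, 2: coord sizes,
+        #   3: coord offsets (prefix sums), 4: flattened gt masks, 5: mask sizes,
+        #   6: mask offsets (prefix sums), 7: annotations, 8: raw images,
+        #   9: predictor input sizes, 10: image indices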
+ batch = [[], [], [], [], [], [], [], [], [], [], []]
+ batch[3] = [0]
+ batch[6] = [0]
+        for img_idx in indices:
+ imgId = coco_img_ids[img_idx]
+
+ datapoint = build_datapoint(imgId,
+ coco,
+ pixel_mean,
+ pixel_std,
+ coco_root_dir,
+ coco_slice_name,
+ catIds,
+ cache,
+ predictor,
+ pad_input_image_batch)
+ I, coords_list, gt_masks_list, anns, x, predictor_input_size = datapoint
+ if len(coords_list) == 0:
+ continue
+ batch[0].append(x)
+ # batch[0].append(x[0])
+ coords_list = predictor.transform.apply_coords(
+ np.array(coords_list), I.shape[:2])
+ coords_list = torch.tensor(coords_list, dtype=torch.float)
+
+ batch[1].append(coords_list.reshape(-1))
+ batch[2].append(coords_list.size())
+ batch[3].append(coords_list.numel() + batch[3][-1])
+
+ batch[4].append(gt_masks_list.reshape(-1))
+ batch[5].append(gt_masks_list.size())
+ batch[6].append(gt_masks_list.numel() + batch[6][-1])
+
+ batch[7].append(anns)
+ batch[8].append(I)
+ batch[9].append(predictor_input_size)
+ batch[10].append(img_idx)
+
+ def cat_and_cast(b, use_half):
+ b = torch.cat(b) if len(b) > 0 else None
+ if use_half is not None and b is not None:
+ return b.to(use_half)
+ return b
+
+ def to_nested_tensor(data, sizes=None, use_half=None):
+ if len(data) == 0:
+ return None
+ dtype = use_half if use_half is not None else torch.float32
+
+ if sizes is not None:
+ data = [d.view(s) for (d, s) in zip(data, sizes)]
+
+ return torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged)
+
+ if pad_input_image_batch:
+ batch[0] = cat_and_cast(batch[0], use_half)
+ else:
+ batch[0] = to_nested_tensor(batch[0], use_half=use_half)
+
+ batch[1] = cat_and_cast(batch[1], use_half)
+
+ batch[4] = cat_and_cast(batch[4], False)
+
+ return batch
+
+ return build_batch
+
+
+def setup_coco_img_ids(coco_root_dir, coco_slice_name, coco_category_names, img_id):
+ annFile = '{}/annotations/instances_{}.json'.format(
+ coco_root_dir, coco_slice_name)
+
+ # initialize COCO api for instance annotations
+ coco = COCO(annFile)
+
+ # display COCO categories and supercategories
+ cats = coco.loadCats(coco.getCatIds())
+ cat_id_to_cat = {cat['id']: cat for cat in cats}
+ nms = [cat['name'] for cat in cats]
+ # print('COCO categories: \n{}\n'.format(' '.join(nms)))
+
+ # nms = set([cat['supercategory'] for cat in cats])
+ # print('COCO supercategories: \n{}'.format(' '.join(nms)))
+
+ if coco_category_names is not None:
+ catIds = coco.getCatIds(catNms=coco_category_names)
+ else:
+ catIds = coco.getCatIds()
+
+ if img_id is not None:
+ coco_img_ids = [img_id]
+ elif coco_category_names is None:
+ coco_img_ids = coco.getImgIds()
+ else:
+ coco_img_ids = coco.getImgIds(catIds=catIds)
+
+ return coco_img_ids, cat_id_to_cat, catIds, coco
diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py
new file mode 100644
index 0000000000..7671f9ce37
--- /dev/null
+++ b/scripts/sam/eval_combo.py
@@ -0,0 +1,406 @@
+import os
+import tqdm
+import torch
+import fire
+from metrics import calculate_miou, create_result_entry
+from data import build_data, setup_coco_img_ids
+import math
+import segment_anything_fast
+import time
+import resource
+
+torch._dynamo.config.cache_size_limit = 50000
+
+def unbind_jagged(device, data, sizes, offsets):
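+    # Move the flattened tensor to the device, then split it back into
+    # per-image tensors using the recorded sizes and prefix-sum offsets
+    # produced by build_batch in data.py.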
+ if data is None:
+ return None
+ data = data.to(device=device, non_blocking=True)
+ return [data[offsets[batch_idx]:offsets[batch_idx+1]].view(sizes[batch_idx]) for batch_idx in range(len(sizes))]
+
+PADDED_TENSOR=None
+
+# Preallocate a "landing" Tensor for incoming data and reuse it across launches.
+def pad_to_batch_size(batch, batch_size, device):
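+    # Pad a partial batch up to batch_size by repeating its first entry so the
+    # (compiled) encoder always sees a static input shape; the extra rows are
+    # sliced off again after the encoder runs.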
+ assert batch.dim() == 4
+ # assert batch.is_pinned()
+ global PADDED_TENSOR
+ if PADDED_TENSOR is None:
+ batch = batch.to(device=device, non_blocking=True)
+ full_batch_size = (batch_size, batch.size(1), batch.size(2), batch.size(3))
+ first_entry = batch[0].unsqueeze(0)
+ repeat_first_entry = first_entry.expand(full_batch_size)
+ padded_batch = torch.cat([batch, repeat_first_entry[batch.size(0):batch_size]], dim=0)
+ assert padded_batch.size() == full_batch_size
+ PADDED_TENSOR = padded_batch
+ PADDED_TENSOR[:batch.size(0)].copy_(batch, non_blocking=True)
+ return PADDED_TENSOR
+
+def get_features_batch(encoder, input_image_batch, pad_input_image_batch, batch_size, device):
+ if pad_input_image_batch:
+ features_batch = encoder(pad_to_batch_size(input_image_batch, batch_size, device))
+ return features_batch[:input_image_batch.size(0)]
+ return encoder(input_image_batch)
+
+def build_results_batch(predictor, batch, batch_size, pad_input_image_batch):
+ encoder = predictor.model.image_encoder
+ device = predictor.device
+
+ input_image_batch = batch[0]
+    if input_image_batch is None:
+        return (None, None, None)
+    # The number of valid data points varies slightly per batch
+    orig_input_image_batch_size = input_image_batch.size(0)
+
+ with torch.autograd.profiler.record_function("data transfer"):
+ coords_lists = unbind_jagged(*([device] + batch[1:4]))
+ gt_masks_lists = unbind_jagged(*([device] + batch[4:7]))
+ if coords_lists is None:
+ return (None, None, None)
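+        # Each datapoint is (anns, image, input_size, img_idx, coords, gt_masks),
+        # matching the slot order produced by build_batch in data.py.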
+ datapoints = list(zip(*(batch[7:] + [coords_lists, gt_masks_lists])))
+ if pad_input_image_batch:
+ # Pad to a static shape to avoid recompilation
+ input_image_batch = pad_to_batch_size(input_image_batch, batch_size, device)
+ else:
+ input_image_batch = input_image_batch.to(device=device, non_blocking=True)
+
+ # We explicitly exclude data transfers from the timing to focus
+ # only on the kernel performance.
+ # Next we synchronize and set two events to start timing.
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ start_event = torch.cuda.Event(enable_timing=True)
+ end_event = torch.cuda.Event(enable_timing=True)
+ start_event.record()
+ else:
+ t0 = time.time()
+
+ with torch.autograd.profiler.record_function("timed region"):
+ with torch.autograd.profiler.record_function("image encoder"):
+ features_batch = encoder(input_image_batch)
+ features_batch = features_batch[:orig_input_image_batch_size]
+
+ with torch.autograd.profiler.record_function("predict_torch"):
+ result_batch = []
+ for batch_idx, (anns, image, input_size, idx, coords, gt_masks) in enumerate(datapoints):
+ features = features_batch.narrow(0, batch_idx, 1)
+ predictor.reset_image()
+ predictor.original_size = image.shape[:2]
+ predictor.input_size = input_size
+ predictor.features = features
+ predictor.is_image_set = True
+ coords = coords.unsqueeze(1)
+ fg_labels = torch.ones(
+ (coords.size(0), 1), dtype=torch.int, device=device)
+ masks, scores, logits = predictor.predict_torch(
+ point_coords=coords,
+ point_labels=fg_labels,
+ multimask_output=True,
+ )
+ entry = create_result_entry(anns, gt_masks, masks, scores, idx)
+ result_batch += entry
+
+ # After all kernels have been launched we synchronize again and measure
+ # the amount of time spent on the GPU. This is a fairly tight measurement
+ # around the launched GPU kernels and excludes data movement from host
+ # to device.
+ if torch.cuda.is_available():
+ end_event.record()
+ torch.cuda.synchronize()
+ elapsed_time = start_event.elapsed_time(end_event)
+ else:
+ elapsed_time = time.time() - t0
+ return result_batch, orig_input_image_batch_size, elapsed_time
+
+
+def build_results(batched_data_iter,
+ predictor,
+ mask_debug_out_dir,
+ batch_size,
+ use_compile,
+ use_compile_decoder,
+ pad_input_image_batch,
+ compress,
+ use_fullgraph=False):
+
+ # TODO: Re-enable this for datapoints
+ assert not use_compile_decoder
+
+ batch_runner = build_results_batch
+
+ results = []
+ batch_idx = 0
+ num_images = 0
+ num_batches = 0
+ elapsed_time = 0
+ partial_batch = False
+ for batch in tqdm.tqdm(batched_data_iter):
+ with torch.no_grad():
+ if batch_idx == 0:
+ with torch.autograd.profiler.record_function("compilation and warmup"):
+ if str(use_compile) != "False":
+ predictor.model.image_encoder = torch.compile(predictor.model.image_encoder, mode=use_compile, fullgraph=use_fullgraph)
+ # Run first batch a few times for warmup and exclude it from the final timings
+ for _ in range(5):
+ _ = batch_runner(predictor, batch, batch_size, pad_input_image_batch)
+ result_batch, num_datapoints, kernel_time = batch_runner(predictor, batch, batch_size, pad_input_image_batch)
+ if result_batch is not None:
+ results += result_batch
+                # We expect a partial batch to only happen once at the end
+ assert not partial_batch
+ # Only measure timing on full batches
+ if num_datapoints == batch_size:
+ num_images += num_datapoints
+ num_batches += 1
+                    # We consistently exclude the last (512 - filtered) images.
+                    # Since batch sizes must be powers of two and less than or
+                    # equal to 512, this ensures consistent timing across
+                    # varying batch sizes.
+ if num_images <= 4488:
+ elapsed_time += kernel_time
+ else:
+ partial_batch = True
+ batch_idx += 1
+
+ avg_ms_per_img = None
+ if num_images > 0:
+ avg_ms_per_img = elapsed_time
+ avg_ms_per_img = avg_ms_per_img / num_images
+
+ return results, avg_ms_per_img, num_batches, num_images
+
+
+def identity_runner(fn, *args, **kwargs):
+ return fn(*args, **kwargs)
+
+
+def profiler_runner(path, fn, *args, **kwargs):
+ with torch.profiler.profile(
+ activities=[torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA],
+ record_shapes=True) as prof:
+ result = fn(*args, **kwargs)
+ prof.export_chrome_trace(path)
+ return result
+
+
+def profile_top_runner(fn, *args, **kwargs):
+ with torch.profiler.profile(
+ activities=[torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA],
+ record_shapes=True) as prof:
+ result = fn(*args, **kwargs)
+ if torch.cuda.is_available():
+ print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
+ else:
+ print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
+ return result
+
+
+def memory_runner(path, fn, *args, **kwargs):
+ print("Start memory recording")
+ torch.cuda.synchronize()
+ torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=100000, trace_alloc_record_context=True)
+ result = fn(*args, **kwargs)
+ torch.cuda.synchronize()
+ snapshot = torch.cuda.memory._snapshot()
+ print("Finish memory recording")
+ import pickle
+ with open(path, 'wb') as f:
+ pickle.dump(snapshot, f)
+ # Use to convert pickle file into html
+ # python torch/cuda/_memory_viz.py trace_plot .pickle -o .html
+ return result
+
+
+def run(
+ coco_root_dir,
+ coco_slice_name,
+ sam_checkpoint_base_path,
+ sam_model_type,
+ point_sampling_cache_dir,
+ mask_debug_out_dir,
+ batch_size=1,
+ print_header=False,
+ coco_category_names=None,
+ limit=None,
+ img_id=None,
+ use_half=None,
+ use_compile="False",
+ use_compile_decoder=False,
+ compress=None,
+ num_workers=0,
+ use_rel_pos=True,
+ pad_input_image_batch=True,
+ profile_path=None,
+ profile_top=False,
+ memory_path=None,
+ device="cuda"
+):
+ from torch._inductor import config as inductorconfig
+ inductorconfig.triton.unique_kernel_names = True
+ inductorconfig.epilogue_fusion = True
+ inductorconfig.coordinate_descent_tuning = True
+ inductorconfig.coordinate_descent_check_all_directions = True
+ inductorconfig.force_fuse_int_mm_with_mul = True
+ inductorconfig.use_mixed_mm = True
+ from torch.sparse import SparseSemiStructuredTensor
+ SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
+ if use_half is not None:
+ if use_half == "float16":
+ use_half = torch.float16
+ elif use_half == "bfloat16":
+ use_half = torch.bfloat16
+ else:
+            raise ValueError(f"Expected use_half to be one of float16 or bfloat16, but got {use_half}")
+
+
+    # Batch size needs to be a power of two and at most 512.
+ assert math.log2(batch_size).is_integer()
+ assert batch_size <= 512
+
+ # https://github.com/facebookresearch/segment-anything/tree/main#model-checkpoints
+ # largest to smallest: vit_h, vit_l, vit_b
+ model_type_to_checkpoint = {
+ 'vit_h': f'{sam_checkpoint_base_path}/sam_vit_h_4b8939.pth',
+ 'vit_l': f'{sam_checkpoint_base_path}/sam_vit_l_0b3195.pth',
+ 'vit_b': f'{sam_checkpoint_base_path}/sam_vit_b_01ec64.pth',
+ }
+
+ from segment_anything_fast import sam_model_registry, SamPredictor
+ checkpoint_path = model_type_to_checkpoint[sam_model_type]
+ sam = sam_model_registry[sam_model_type](checkpoint=checkpoint_path).to(torch.device(device))
+ predictor = SamPredictor(sam)
+
+ from segment_anything_fast import tools
+ tools.apply_eval_dtype_predictor(predictor, use_half)
+
+ for block in predictor.model.image_encoder.blocks:
+ block.attn.use_rel_pos = use_rel_pos
+
+ if compress == "int8_dynamic_quant":
+ from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+ from torchao.utils import unwrap_tensor_subclass
+ predictor.model.image_encoder = quantize(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
+ predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
+ elif compress == "sparse_mlp_only":
+ def mlp_only(mod, name):
+ return isinstance(mod, torch.nn.Linear) and 'mlp' in name
+ from torchao.sparsity import apply_sparse_semi_structured
+ apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_only)
+ elif compress == "sparse":
+ from torchao.sparsity import apply_sparse_semi_structured
+ apply_sparse_semi_structured(predictor.model.image_encoder)
+ elif compress == "int8_dynamic_quant_sparse":
+ from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight
+ from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
+ from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+ from torchao.utils import unwrap_tensor_subclass
+
+ def attn_only(mod, name):
+ return isinstance(mod, torch.nn.Linear) and 'attn' in name
+ def mlp_lin1_only(mod, name):
+ return isinstance(mod, torch.nn.Linear) and 'lin1' in name
+ def mlp_lin2_only(mod, name):
+ return isinstance(mod, torch.nn.Linear) and 'lin2' in name
+ def mlp_only(mod, name):
+ return isinstance(mod, torch.nn.Linear) and 'mlp' in name
+
+ apply_fake_sparsity(predictor.model.image_encoder,
+ filter_fn=mlp_only)
+
+ predictor.model.image_encoder = quantize(predictor.model.image_encoder,
+ int8_dynamic_activation_int8_weight(),
+ attn_only)
+ predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
+
+ predictor.model.image_encoder = quantize(predictor.model.image_encoder,
+ Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float,
+ mlp_lin1_only)
+ apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_lin2_only)
+ else:
+ assert compress is None, f"Unsupported compress mode {compress}"
+
+
+ coco_img_ids_, cat_id_to_cat, catIds, coco = setup_coco_img_ids(
+ coco_root_dir, coco_slice_name, coco_category_names, img_id)
+
+ coco_img_ids = []
+ for imgId in coco_img_ids_:
+ img = coco.loadImgs(imgId)[0]
+ annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
+ anns = coco.loadAnns(annIds)
+ if len(anns) != 0:
+ coco_img_ids.append(imgId)
+
+ build_batch = build_data(coco_img_ids,
+ coco,
+ catIds,
+ coco_root_dir,
+ coco_slice_name,
+ point_sampling_cache_dir,
+ predictor,
+ use_half,
+ pad_input_image_batch)
+
+ limit = len(coco_img_ids) if limit is None else limit
+ batched_data_iter = torch.utils.data.DataLoader(list(range(limit)),
+ batch_size=batch_size,
+ collate_fn=build_batch,
+ num_workers=num_workers,
+ pin_memory=False)
+ runner = identity_runner
+
+ if profile_path is not None:
+ import functools
+ runner = functools.partial(profiler_runner, profile_path)
+
+ if profile_top:
+ runner = profile_top_runner
+
+ if memory_path is not None:
+ assert use_compile != "max-autotune", f"Memory path does not support {use_compile}"
+ import functools
+ runner = functools.partial(memory_runner, memory_path)
+
+ results, avg_ms_per_img, num_batches, num_images = runner(build_results,
+ batched_data_iter,
+ predictor,
+ mask_debug_out_dir,
+ batch_size,
+ use_compile,
+ use_compile_decoder,
+ pad_input_image_batch,
+ compress)
+
+ results = [[r[0], r[1], r[2], r[3].item()] for r in results]
+
+ img_s, batch_ms_batch_size = None, None
+ if avg_ms_per_img is not None:
+ img_s = 1000 / avg_ms_per_img
+ batch_ms_batch_size = (avg_ms_per_img * num_images) / num_batches / batch_size
+
+ mIoU = calculate_miou(results, mask_debug_out_dir, True, cat_id_to_cat)
+ if torch.cuda.is_available():
+ max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
+ _, total_memory = torch.cuda.mem_get_info()
+ max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
+ max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
+ else:
+ import psutil
+ total_memory = psutil.virtual_memory().total
+ max_memory_allocated_bytes = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+ max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / (total_memory >> 10)))
+ max_memory_allocated_bytes = max_memory_allocated_bytes >> 10
+
+ with open("results.csv", "a") as f:
+ if print_header:
+ header = ",".join(["device", "sam_model_type", "batch_size", "memory(MiB)", "memory(%)", "img_s(avg)", "batch_ms(avg)/batch_size", "mIoU", "use_compile",
+ "use_half", "compress", "use_compile_decoder", "use_rel_pos", "pad_input_image_batch", "num_workers", "num_batches", "num_images", "profile_path", "memory_path"])
+ f.write(header+"\n")
+ vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile,
+ use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path]))
+ f.write(vals+"\n")
+
+if __name__ == '__main__':
+ fire.Fire(run)
diff --git a/scripts/sam/flash_4_configs.p b/scripts/sam/flash_4_configs.p
new file mode 100644
index 0000000000..4b6e234d0d
Binary files /dev/null and b/scripts/sam/flash_4_configs.p differ
diff --git a/scripts/sam/metrics.py b/scripts/sam/metrics.py
new file mode 100644
index 0000000000..7197d1e434
--- /dev/null
+++ b/scripts/sam/metrics.py
@@ -0,0 +1,61 @@
+import torch
+import pandas as pd
+
+def create_result_entry(anns, gt_masks_list, masks, scores, img_idx):
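+    # For each annotation, keep only the predicted mask with the highest score;
+    # IoU is then computed between that mask and the ground-truth mask.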
+ argmax_scores = torch.argmax(scores, dim=1)
+ inference_masks = masks.gather(1, argmax_scores.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(
+ (masks.size(0), 1, masks.size(2), masks.size(3)))).squeeze(1)
+
+ def _iou(mask1, mask2):
+ assert mask1.dim() == 3
+ assert mask2.dim() == 3
+ intersection = torch.logical_and(mask1, mask2)
+ union = torch.logical_or(mask1, mask2)
+ return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)))
+
+ top_score_ious = _iou(inference_masks, gt_masks_list)
+
+ entry = []
+ for idx in range(top_score_ious.size(0)):
+ entry.append(
+ [img_idx, anns[idx]['id'], anns[idx]['category_id'], top_score_ious[idx]])
+ return entry
+
+
+def calculate_miou(results, mask_debug_out_dir, silent, cat_id_to_cat):
+ df = pd.DataFrame(results, columns=['img_id', 'ann_id', 'cat_id', 'iou'])
+ df.to_csv(f'{mask_debug_out_dir}/df.csv')
+ df['supercategory'] = df['cat_id'].map(
+ lambda cat_id: cat_id_to_cat[cat_id]['supercategory'])
+ df['category'] = df['cat_id'].map(
+ lambda cat_id: cat_id_to_cat[cat_id]['name'])
+
+ # TODO: cross reference the specifics of how we calculate mIoU with
+ # the SAM folks (should it be per dataset, per category, per image, etc)
+ # currently, just calculate them all
+
+ # TODO: QOL save the summaries to file
+
+ # per category
+ per_category = pd.pivot_table(
+ df, values='iou', index=['cat_id', 'supercategory', 'category'],
+ aggfunc=('mean', 'count'))
+ if not silent:
+ print('\nmIoU averaged per category')
+ print(per_category)
+
+ # per super-category
+ per_supercategory = pd.pivot_table(
+ df, values='iou', index=['supercategory'],
+ aggfunc=('mean', 'count'))
+ if not silent:
+ print('\nmIoU averaged per supercategory')
+ print(per_supercategory)
+
+ # per all selected masks
+ per_all_masks_agg = df['iou'].agg(['mean', 'count'])
+ if not silent:
+ print('\nmIoU averaged per all selected masks')
+ print(per_all_masks_agg)
+
+ return df['iou'].agg(['mean', 'count'])['mean']
diff --git a/scripts/sam/results.csv b/scripts/sam/results.csv
new file mode 100644
index 0000000000..01aad5c022
--- /dev/null
+++ b/scripts/sam/results.csv
@@ -0,0 +1,6 @@
+device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
+cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
diff --git a/scripts/sam/setup.sh b/scripts/sam/setup.sh
new file mode 100644
index 0000000000..8c730ea9ff
--- /dev/null
+++ b/scripts/sam/setup.sh
@@ -0,0 +1,22 @@
+
+SETUP_HOME=$(pwd)
+
+
+mkdir -p checkpoints
+mkdir -p datasets
+
+mkdir -p tmp
+mkdir -p tmp/sam_coco_mask_center_cache
+mkdir -p tmp/sam_eval_masks_out
+
+wget -nc -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
+wget -nc -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
+wget -nc -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
+
+mkdir -p datasets/coco2017
+wget -nc -P datasets/coco2017 http://images.cocodataset.org/zips/val2017.zip
+wget -nc -P datasets/coco2017 http://images.cocodataset.org/annotations/annotations_trainval2017.zip
+
+cd datasets/coco2017 && unzip -n val2017.zip && cd $SETUP_HOME
+cd datasets/coco2017 && unzip -n annotations_trainval2017.zip && cd $SETUP_HOME
+
diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md
index bc0a61b202..239d52bb0e 100644
--- a/torchao/sparsity/README.md
+++ b/torchao/sparsity/README.md
@@ -20,29 +20,32 @@ More concretely, we hope to provide tutorials and APIs for both sparse kernels (
## Success Stories
-#### segment-anything
+#### segment-anything-fast
We applied 2:4 sparsity to accelerate segment-anything, as part of [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast).
-The results mentioned in the README of the repo compose sparsity with a suite of other inference acceleration techniques.
-From our [benchmarking](https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_sam.py), we see a 1.1x speedup when running with `SEGMENT_ANYTHING_FAST_USE_FLASH_4` enabled.
-To reproduce these benchmarks you can run the following command:
+We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIoU)**.
-The inference acceleration of semi-structured sparsity depends on the matmul shapes, which is why we don't see additional speedups when applying to all linear layers (attn + mlp) of segment-anything.
-We find that accelerating the MLP linear layers provied the most speedups (`lin1`, `lin2`). To repoduce our benchmarks you can run the following command:
+Overall, we found that accelerating the MLP linear layers (`lin1`, `lin2`) provided the most speedup while mitigating accuracy loss.
-```
-python benchmarks/benchmark_sam.py
-```
+Applying sparsity to the attention linear layers led to a slower model, likely due to two reasons:
+- We cannot fuse into our semi-structured sparse matmul with torch.compile.
+- The speedups we observe for sparse matmul depend on the matmul shapes, and the attention matmuls are smaller than the MLP ones.
+
+We were also able to compose int8 dynamic quantization with 2:4 sparsity for further speedups.
+
+We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1, and 2:4 sparsity to mlp layer 2 yielded the best configuration.
+
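+Below is a condensed sketch of roughly what `scripts/sam/eval_combo.py --compress int8_dynamic_quant_sparse` does to apply this configuration; `predictor` is assumed to be a `segment_anything_fast` `SamPredictor` built as in that script, and the filter functions select linear layers by module name:
+
+```python
+import torch
+from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+from torchao.utils import unwrap_tensor_subclass
+from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
+from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight
+
+def attn_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'attn' in name
+def mlp_lin1_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'lin1' in name
+def mlp_lin2_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'lin2' in name
+def mlp_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'mlp' in name
+
+# model is the SAM image encoder, where all of the optimizations are applied
+model = predictor.model.image_encoder
+
+# 2:4 prune the MLP weights before swapping in the sparse/quantized subclasses
+apply_fake_sparsity(model, filter_fn=mlp_only)
+# int8 dynamic quantization for the attention linear layers
+model = quantize(model, int8_dynamic_activation_int8_weight(), attn_only)
+model = unwrap_tensor_subclass(model)
+# int8 dynamic quantization fused with 2:4 sparsity for mlp lin1
+model = quantize(model, Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float, mlp_lin1_only)
+# plain 2:4 sparsity for mlp lin2
+apply_sparse_semi_structured(model, filter_fn=mlp_lin2_only)
+model = torch.compile(model, mode="max-autotune")
+```
+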
+We ran the following benchmarks for SAM ViT-h on an NVIDIA A100-80GB GPU with batch_size=32, the `bfloat16` dtype, and `torch.compile="max-autotune"`:
-The following benchmarks we run on an A100, with batch_size=32 and `bfloat16` dtype:
+| Model Type | Technique | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
+|------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
+| ViT-h | baseline (bfloat16, max-autotune) | 22.75 | 15172 | 0.5811 | baseline | baseline |
+| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.5822 | **1.09x** | **100.19%** |
+| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.5672 | **1.10x** | **97.61%** |
+| | 2:4 sparsity (attn + mlp) | 24.30 | 13429 | 0.5306 | **1.07x** | **91.31%** |
+|            | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.5668 | **1.16x** | **97.54%** |
-| qkv | proj | lin1 | lin2 | time | memory | img/s |
-| ---- | ---- | ---- | ---- | ---- | ------ | ----- |
-| None | None | None | None | 1361.73 | 15.81 | 23.50 |
-| None | None | sparse (cusparselt) | sparse (cusparselt) | 1245.15 | 15.46 | 25.70 |
-| None | None | sparse (cutlass) | sparse (cutlass) | 1251.047651 | 15.41 | 25.59 |
-| sparse (cusparselt) | sparse (cusparselt) | sparse (cusparselt) | sparse (cusparselt) | 1265.43 | 12.71 | 25.29|
-| sparse (cutlass) | sparse (cutlass) | sparse (cutlass) | sparse (cutlass) | 1274.96 | 12.70 | 25.10 |
+To reproduce our benchmarks please follow these [instructions](/scripts/sam/README.md).
#### BERT
diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py
index 6d7a856ea4..2601f166a8 100644
--- a/torchao/sparsity/prototype/dynamic_quant_sparse.py
+++ b/torchao/sparsity/prototype/dynamic_quant_sparse.py
@@ -149,6 +149,17 @@ def sparse_quant_int8_cutlass_matmul(
class Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight(
Int8DynamicallyQuantizedLinearWeight
):
+ def dequantize(self, dtype=None):
+ # overload dequantize op for __repr__
+ zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype)
+ int_data_expanded = torch._cslt_sparse_mm(self.int_data, torch.eye(self.shape[1],
+ dtype=self.int_data.dtype,
+ device=self.int_data.device))
+ dq_t = dequantize_per_channel(
+ int_data_expanded, self.q_scales, zero_points, self.dtype if dtype is None else dtype
+ ).to(self.dtype)
+
+ return dq_t if not self.transposed else dq_t.t()
@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
@@ -158,7 +169,7 @@ def _quantized_op(act_mat, w_qtensor, bias):
)
@classmethod
- def from_float(cls, input_float, qmin=-8, qmax=7):
+ def from_float(cls, input_float, qmin=-128, qmax=127):
assert input_float.is_cuda
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
index 90e35a4121..d8ec14a266 100644
--- a/torchao/sparsity/sparse_api.py
+++ b/torchao/sparsity/sparse_api.py
@@ -4,15 +4,16 @@
from torchao.quantization.quant_api import _is_linear
# Sparsity helper functions
-def apply_fake_sparsity(model):
+def apply_fake_sparsity(model, **kwargs):
"""
This function simulates 2:4 sparsity on all linear layers in a model.
It uses the torch.ao.pruning flow.
"""
+ filter_fn = kwargs.pop("filter_fn", _is_linear)
# torch.ao.pruning flow
sparse_config = []
for name, mod in model.named_modules():
- if isinstance(mod, torch.nn.Linear):
+ if filter_fn(mod, name):
sparse_config.append({"tensor_fqn": f"{name}.weight"})
sparsifier = WeightNormSparsifier(
@@ -26,7 +27,7 @@ def apply_fake_sparsity(model):
def apply_sparse_semi_structured(model, **kwargs):
filter_fn = kwargs.pop("filter_fn", _is_linear)
- apply_fake_sparsity(model)
+ apply_fake_sparsity(model, filter_fn=filter_fn)
for name, mod in model.named_modules():
if filter_fn(mod, name):
mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))