diff --git a/README.md b/README.md
index e99f565900..f0cb7e2c00 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,20 @@ And a quick crash course on inference quantization to help parse the above table
 In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch as in no C++/CUDA to achieve at the time SOTA inference performance. These involve more intrusive code changes.
-* 8x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai)
+* 9.5x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) compared to vanilla [sam](https://github.com/facebookresearch/segment-anything).
+* 1.16x speedup when composing int8 dynamic quantization with 2:4 sparsity, measured against an already accelerated baseline (`bfloat16` dtype with `torch.compile="max-autotune"`).
+
+| Model Type | Technique | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
+|------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
+| ViT-h | sam (float32, eager) | 2.78 | 28806 | 0.58 | baseline | baseline |
+| | sam (bfloat16, eager) | 14.85 | 14424 | 0.58 | **5.34x** | **100%** |
+| | sam-fast (bfloat16, max-autotune) | 22.75 | 15172 | 0.58 | **8.18x** | **100%** |
+| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.58 | **8.96x** | **100%** |
+| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.57 | **8.92x** | **98%** |
+| | int8 dynamic quant (attn)
int8 dynamic quant + 2:4 sparsity (mlp lin1)
2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.57 | **9.52x** | **98%** | + +The relative speedup is measured purely across the image encoder (ViT) of the model, where we apply our model optimizations. Benchmarks ran on an NVIDIA-A100-80GB with batch_size=32 + * 10x speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2) * 3x speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3) diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py deleted file mode 100644 index 3a5d817602..0000000000 --- a/benchmarks/benchmark_sam.py +++ /dev/null @@ -1,137 +0,0 @@ -import argparse -from itertools import product - -import pandas as pd -# to install segment-anything-fast you can run: -# pip install git+https://github.com/pytorch-labs/segment-anything-fast.git -from segment_anything_fast import sam_model_registry -import torch -from torch.utils.benchmark import Timer -from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT -from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, - _get_subclass_inserter, - _is_linear, - QuantizedLinearWeightBase, - Int8DynamicallyQuantizedLinearWeight, -) -from torchao.quantization import change_linear_weights_to_int8_dqtensors -from torchao.sparsity import ( - apply_sparse_semi_structured, - apply_fake_sparsity, -) -from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight -from tqdm import tqdm - -sam_checkpoint_base_path = "/home/jessecai/local/MODELS" -model_type = 'vit_h' -model_name = 'sam_vit_h_4b8939.pth' -checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}" - -torch._inductor.config.epilogue_fusion = True -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.coordinate_descent_check_all_directions = True -torch._inductor.config.force_fuse_int_mm_with_mul = True - -@torch.no_grad() -def benchmark(f, *args, **kwargs): - for _ in range(3): - f(*args, **kwargs) - torch.cuda.synchronize() - - torch.cuda.reset_peak_memory_stats() - t0 = Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20) - return {'time':res.median * 1e3, 'memory': torch.cuda.max_memory_allocated()/1e9} - -def get_sam_model(only_one_block=False, batchsize=1): - sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda() - model = sam.image_encoder.eval() - image = torch.randn(batchsize, 3, 1024, 1024, device='cuda') - - # code to use just a single block of the model - if only_one_block: - model = model.blocks[0] - image = torch.randn(batchsize, 64, 64, 1280, device='cuda') - return model, image - -def qkv_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'qkv' in name - -def proj_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'proj' in name - -def lin1_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'lin1' in name - -def lin2_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'lin2' in name - -SUBCLASSES = { - "quant" : Int8DynamicallyQuantizedLinearWeight, - "quant+sparse (cutlass)" : Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, - "quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, - "sparse (cutlass)" : 
SparseSemiStructuredTensorCUTLASS, - "sparse (cusparselt)" : SparseSemiStructuredTensorCUSPARSELT, -} - -def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None): - res = { - "block_only": block_only, - "batchsize": batchsize, - "dtype": dtype, - "compile": compile, - "qkv" : qkv, - "proj": proj, - "lin1": lin1, - "lin2": lin2, - } - with torch.no_grad(): - model, image = get_sam_model(block_only, batchsize) - model = model.to(dtype) - image = image.to(dtype) - - # 2:4 prune model - apply_fake_sparsity(model) - option_and_filter_fn = zip([qkv, proj, lin1, lin2], [qkv_only, proj_only, lin1_only, lin2_only]) - - for option, filter_fn in option_and_filter_fn: - subclass = SUBCLASSES.get(option, None) - if subclass and issubclass(subclass, SparseSemiStructuredTensor): - # replace with to_sparse_semi_structured - for name, mod in model.named_modules(): - if filter_fn(mod, name): - mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight)) - elif subclass and issubclass(subclass, QuantizedLinearWeightBase): - _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn) - - if compile: - model = torch.compile(model, mode='max-autotune') - - res.update(benchmark(model, image)) - res["img/s"] = 1 / (res['time'] / 1000 / res['batchsize']) - return res - -if __name__ == "__main__": - print("BENCHMARKING") - parser = argparse.ArgumentParser(description='Process some integers.') - parser.add_argument('--eager', action='store_true', help='enable/disable torch.compile') - args = parser.parse_args() - # ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")] - ALL_RUNS = [ - run_once(compile=not args.eager), - run_once(compile=not args.eager, lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"), - run_once(compile=not args.eager, lin1="sparse (cutlass)", lin2="sparse (cutlass)"), - run_once(compile=not args.eager, qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"), - run_once(compile=not args.eager, qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"), - # run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"), - # run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"), - # run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), - # run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"), - # run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), - ] - df = pd.DataFrame(ALL_RUNS) - df.to_csv("sam_benchmark_results.csv") - print(df) diff --git a/scripts/sam/.gitignore b/scripts/sam/.gitignore new file mode 100644 index 0000000000..3e7cc59d91 --- /dev/null +++ b/scripts/sam/.gitignore @@ -0,0 +1,3 @@ +tmp +checkpoints +datasets diff --git a/scripts/sam/README.md b/scripts/sam/README.md new file mode 100644 index 0000000000..426d7fe6a8 --- /dev/null +++ b/scripts/sam/README.md @@ -0,0 +1,21 @@ +# benchmarking instructions: + +Setup your enviornment with: +``` +conda env create -n "saf-ao" python=3.10 +conda activate saf-ao +pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124 +pip3 install 
git+https://github.com/pytorch-labs/segment-anything-fast.git
+pip3 install tqdm fire pandas
+cd ../.. && python setup.py install
+```
+
+Then download data and models by running
+```
+sh setup.sh
+```
+
+Finally, you can run benchmarks with
+```
+sh benchmark.sh
+```
diff --git a/scripts/sam/benchmark.sh b/scripts/sam/benchmark.sh
new file mode 100755
index 0000000000..5c1262f9cc
--- /dev/null
+++ b/scripts/sam/benchmark.sh
@@ -0,0 +1,11 @@
+# baseline
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
+# int8 dynamic quant (all)
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
+# 2:4 sparsity (mlp only)
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
+# 2:4 sparsity (all)
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
+# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
+
diff --git a/scripts/sam/data.py b/scripts/sam/data.py
new file mode 100644
index 0000000000..20f7632652
--- /dev/null
+++ b/scripts/sam/data.py
@@ -0,0 +1,297 @@
+import torch
+import diskcache
+from pycocotools.coco import COCO
+import numpy as np
+from scipy import ndimage
+import skimage.io as io
+import skimage.color as color
+
+
+def _get_center_point(mask, ann_id, cache):
+    """
+    This is a rudimentary version of https://arxiv.org/pdf/2304.02643.pdf,
+    section D.1. Point Sampling
+
+    From the paper: "The first point is chosen deterministically as the point
+    farthest from the object boundary."
+
+    The code below is an approximation of this.
+
+    First, we try to calculate the center of mass. If it's inside the mask, we
+    stop here.
+
+    The centroid may be outside of the mask for some mask shapes. In this case
+    we do a slow hack, specifically, we check for the
+    minimum of the maximum distance from the boundary in four directions
+    (up, right, down, left), and take the point with the maximum of these
+    minimums. Note: this is not performant for large masks.
+ + Returns the center point in (x, y) format + """ + if ann_id in cache: + return cache[ann_id] + + # try the center of mass, keep it if it's inside the mask + com_y, com_x = ndimage.center_of_mass(mask) + com_y, com_x = int(round(com_y, 0)), int(round(com_x, 0)) + if mask[com_y][com_x]: + cache[ann_id] = (com_x, com_y) + return (com_x, com_y) + + # if center of mass didn't work, do the slow manual approximation + + # up, right, down, left + # TODO(future): approximate better by adding more directions + distances_to_check_deg = [0, 90, 180, 270] + + global_min_max_distance = float('-inf') + global_coords = None + # For now, terminate early to speed up the calculation as long as + # the point sample is gooe enough. This sacrifices the quality of point + # sampling for speed. In the future we can make this more accurate. + DISTANCE_GOOD_ENOUGH_THRESHOLD = 20 + + # Note: precalculating the bounding box could be somewhat + # helpful, but checked the performance gain and it's not much + # so leaving it out to keep the code simple. + # Note: tried binary search instead of incrementing by one to + # travel up/right/left/down, but that does not handle masks + # with all shapes properly (there could be multiple boundaries). + for row_idx in range(mask.shape[0]): + for col_idx in range(mask.shape[1]): + cur_point = mask[row_idx, col_idx] + + # skip points inside bounding box but outside mask + if not cur_point: + continue + + max_distances = [] + for direction in distances_to_check_deg: + # TODO(future) binary search instead of brute forcing it if we + # need a speedup, with the cache it doesn't really matter though + if direction == 0: + # UP + cur_row_idx = row_idx + + while cur_row_idx >= 0 and mask[cur_row_idx, col_idx]: + cur_row_idx = cur_row_idx - 1 + cur_row_idx += 1 + distance = row_idx - cur_row_idx + max_distances.append(distance) + + elif direction == 90: + # RIGHT + cur_col_idx = col_idx + + while cur_col_idx <= mask.shape[1] - 1 and \ + mask[row_idx, cur_col_idx]: + cur_col_idx += 1 + cur_col_idx -= 1 + distance = cur_col_idx - col_idx + max_distances.append(distance) + + elif direction == 180: + # DOWN + cur_row_idx = row_idx + while cur_row_idx <= mask.shape[0] - 1 and \ + mask[cur_row_idx, col_idx]: + cur_row_idx = cur_row_idx + 1 + cur_row_idx -= 1 + distance = cur_row_idx - row_idx + max_distances.append(distance) + + elif direction == 270: + # LEFT + cur_col_idx = col_idx + while cur_col_idx >= 0 and mask[row_idx, cur_col_idx]: + cur_col_idx -= 1 + cur_col_idx += 1 + distance = col_idx - cur_col_idx + max_distances.append(distance) + + min_max_distance = min(max_distances) + if min_max_distance > global_min_max_distance: + global_min_max_distance = min_max_distance + global_coords = (col_idx, row_idx) + if global_min_max_distance >= DISTANCE_GOOD_ENOUGH_THRESHOLD: + break + + cache[ann_id] = global_coords + return global_coords + + +def build_datapoint(imgId, + coco, + pixel_mean, + pixel_std, + coco_root_dir, + coco_slice_name, + catIds, + cache, + predictor, + pad_input_image_batch): + img = coco.loadImgs(imgId)[0] + + file_location = f'{coco_root_dir}/{coco_slice_name}/{img["file_name"]}' + I = io.imread(file_location) + if len(I.shape) == 2: + # some images, like img_id==61418, are grayscale + # convert to RGB to ensure the rest of the pipeline works + I = color.gray2rgb(I) + + # load and display instance annotations + annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None) + anns = coco.loadAnns(annIds) + + # approximate the center point of each mask + 
coords_list = [] + gt_masks_list = [] + for ann in anns: + ann_id = ann['id'] + mask = coco.annToMask(ann) + gt_masks_list.append(torch.tensor(mask)) + coords = _get_center_point(mask, ann_id, cache) + coords_list.append(coords) + + image = I + + # predictor_set_image begin + # Transform the image to the form expected by the model + input_image = predictor.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image) + input_image_torch = input_image_torch.permute( + 2, 0, 1).contiguous()[None, :, :, :] + predictor_input_size = input_image_torch.shape[-2:] + + # Preprocess + x = input_image_torch + # Normalize colors + x = (x - pixel_mean) / pixel_std + + if pad_input_image_batch: + # Pad + h, w = x.shape[-2:] + padh = predictor.model.image_encoder.img_size - h + padw = predictor.model.image_encoder.img_size - w + x = torch.nn.functional.pad(x, (0, padw, 0, padh)) + else: + x = x.squeeze(0) + + gt_masks_list = torch.stack(gt_masks_list) if len(gt_masks_list) else None + return image, coords_list, gt_masks_list, anns, x, predictor_input_size + + +def build_data(coco_img_ids, + coco, + catIds, + coco_root_dir, + coco_slice_name, + point_sampling_cache_dir, + predictor, + use_half, + pad_input_image_batch): + cache = diskcache.Cache(point_sampling_cache_dir) + # make sure you clear the cache if you change the point sampling algorithm + # cache.clear() + + pixel_mean = predictor.model.pixel_mean.cpu() + pixel_std = predictor.model.pixel_std.cpu() + + def build_batch(indicies): + batch = [[], [], [], [], [], [], [], [], [], [], []] + batch[3] = [0] + batch[6] = [0] + for img_idx in indicies: + imgId = coco_img_ids[img_idx] + + datapoint = build_datapoint(imgId, + coco, + pixel_mean, + pixel_std, + coco_root_dir, + coco_slice_name, + catIds, + cache, + predictor, + pad_input_image_batch) + I, coords_list, gt_masks_list, anns, x, predictor_input_size = datapoint + if len(coords_list) == 0: + continue + batch[0].append(x) + # batch[0].append(x[0]) + coords_list = predictor.transform.apply_coords( + np.array(coords_list), I.shape[:2]) + coords_list = torch.tensor(coords_list, dtype=torch.float) + + batch[1].append(coords_list.reshape(-1)) + batch[2].append(coords_list.size()) + batch[3].append(coords_list.numel() + batch[3][-1]) + + batch[4].append(gt_masks_list.reshape(-1)) + batch[5].append(gt_masks_list.size()) + batch[6].append(gt_masks_list.numel() + batch[6][-1]) + + batch[7].append(anns) + batch[8].append(I) + batch[9].append(predictor_input_size) + batch[10].append(img_idx) + + def cat_and_cast(b, use_half): + b = torch.cat(b) if len(b) > 0 else None + if use_half is not None and b is not None: + return b.to(use_half) + return b + + def to_nested_tensor(data, sizes=None, use_half=None): + if len(data) == 0: + return None + dtype = use_half if use_half is not None else torch.float32 + + if sizes is not None: + data = [d.view(s) for (d, s) in zip(data, sizes)] + + return torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged) + + if pad_input_image_batch: + batch[0] = cat_and_cast(batch[0], use_half) + else: + batch[0] = to_nested_tensor(batch[0], use_half=use_half) + + batch[1] = cat_and_cast(batch[1], use_half) + + batch[4] = cat_and_cast(batch[4], False) + + return batch + + return build_batch + + +def setup_coco_img_ids(coco_root_dir, coco_slice_name, coco_category_names, img_id): + annFile = '{}/annotations/instances_{}.json'.format( + coco_root_dir, coco_slice_name) + + # initialize COCO api for instance annotations + coco = COCO(annFile) + + # display COCO 
categories and supercategories + cats = coco.loadCats(coco.getCatIds()) + cat_id_to_cat = {cat['id']: cat for cat in cats} + nms = [cat['name'] for cat in cats] + # print('COCO categories: \n{}\n'.format(' '.join(nms))) + + # nms = set([cat['supercategory'] for cat in cats]) + # print('COCO supercategories: \n{}'.format(' '.join(nms))) + + if coco_category_names is not None: + catIds = coco.getCatIds(catNms=coco_category_names) + else: + catIds = coco.getCatIds() + + if img_id is not None: + coco_img_ids = [img_id] + elif coco_category_names is None: + coco_img_ids = coco.getImgIds() + else: + coco_img_ids = coco.getImgIds(catIds=catIds) + + return coco_img_ids, cat_id_to_cat, catIds, coco diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py new file mode 100644 index 0000000000..7671f9ce37 --- /dev/null +++ b/scripts/sam/eval_combo.py @@ -0,0 +1,406 @@ +import os +import tqdm +import torch +import fire +from metrics import calculate_miou, create_result_entry +from data import build_data, setup_coco_img_ids +import math +import segment_anything_fast +import time +import resource + +torch._dynamo.config.cache_size_limit = 50000 + +def unbind_jagged(device, data, sizes, offsets): + if data is None: + return None + data = data.to(device=device, non_blocking=True) + return [data[offsets[batch_idx]:offsets[batch_idx+1]].view(sizes[batch_idx]) for batch_idx in range(len(sizes))] + +PADDED_TENSOR=None + +# Preallocate a "landing" Tensor for incoming data and reuse it across launches. +def pad_to_batch_size(batch, batch_size, device): + assert batch.dim() == 4 + # assert batch.is_pinned() + global PADDED_TENSOR + if PADDED_TENSOR is None: + batch = batch.to(device=device, non_blocking=True) + full_batch_size = (batch_size, batch.size(1), batch.size(2), batch.size(3)) + first_entry = batch[0].unsqueeze(0) + repeat_first_entry = first_entry.expand(full_batch_size) + padded_batch = torch.cat([batch, repeat_first_entry[batch.size(0):batch_size]], dim=0) + assert padded_batch.size() == full_batch_size + PADDED_TENSOR = padded_batch + PADDED_TENSOR[:batch.size(0)].copy_(batch, non_blocking=True) + return PADDED_TENSOR + +def get_features_batch(encoder, input_image_batch, pad_input_image_batch, batch_size, device): + if pad_input_image_batch: + features_batch = encoder(pad_to_batch_size(input_image_batch, batch_size, device)) + return features_batch[:input_image_batch.size(0)] + return encoder(input_image_batch) + +def build_results_batch(predictor, batch, batch_size, pad_input_image_batch): + encoder = predictor.model.image_encoder + device = predictor.device + + input_image_batch = batch[0] + # The number of valid data points varies slightly per batch + orig_input_image_batch_size = input_image_batch.size(0) + if input_image_batch is None: + return (None, None, None) + + with torch.autograd.profiler.record_function("data transfer"): + coords_lists = unbind_jagged(*([device] + batch[1:4])) + gt_masks_lists = unbind_jagged(*([device] + batch[4:7])) + if coords_lists is None: + return (None, None, None) + datapoints = list(zip(*(batch[7:] + [coords_lists, gt_masks_lists]))) + if pad_input_image_batch: + # Pad to a static shape to avoid recompilation + input_image_batch = pad_to_batch_size(input_image_batch, batch_size, device) + else: + input_image_batch = input_image_batch.to(device=device, non_blocking=True) + + # We explicitly exclude data transfers from the timing to focus + # only on the kernel performance. + # Next we synchronize and set two events to start timing. 
+ if torch.cuda.is_available(): + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + else: + t0 = time.time() + + with torch.autograd.profiler.record_function("timed region"): + with torch.autograd.profiler.record_function("image encoder"): + features_batch = encoder(input_image_batch) + features_batch = features_batch[:orig_input_image_batch_size] + + with torch.autograd.profiler.record_function("predict_torch"): + result_batch = [] + for batch_idx, (anns, image, input_size, idx, coords, gt_masks) in enumerate(datapoints): + features = features_batch.narrow(0, batch_idx, 1) + predictor.reset_image() + predictor.original_size = image.shape[:2] + predictor.input_size = input_size + predictor.features = features + predictor.is_image_set = True + coords = coords.unsqueeze(1) + fg_labels = torch.ones( + (coords.size(0), 1), dtype=torch.int, device=device) + masks, scores, logits = predictor.predict_torch( + point_coords=coords, + point_labels=fg_labels, + multimask_output=True, + ) + entry = create_result_entry(anns, gt_masks, masks, scores, idx) + result_batch += entry + + # After all kernels have been launched we synchronize again and measure + # the amount of time spent on the GPU. This is a fairly tight measurement + # around the launched GPU kernels and excludes data movement from host + # to device. + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + elapsed_time = start_event.elapsed_time(end_event) + else: + elapsed_time = time.time() - t0 + return result_batch, orig_input_image_batch_size, elapsed_time + + +def build_results(batched_data_iter, + predictor, + mask_debug_out_dir, + batch_size, + use_compile, + use_compile_decoder, + pad_input_image_batch, + compress, + use_fullgraph=False): + + # TODO: Re-enable this for datapoints + assert not use_compile_decoder + + batch_runner = build_results_batch + + results = [] + batch_idx = 0 + num_images = 0 + num_batches = 0 + elapsed_time = 0 + partial_batch = False + for batch in tqdm.tqdm(batched_data_iter): + with torch.no_grad(): + if batch_idx == 0: + with torch.autograd.profiler.record_function("compilation and warmup"): + if str(use_compile) != "False": + predictor.model.image_encoder = torch.compile(predictor.model.image_encoder, mode=use_compile, fullgraph=use_fullgraph) + # Run first batch a few times for warmup and exclude it from the final timings + for _ in range(5): + _ = batch_runner(predictor, batch, batch_size, pad_input_image_batch) + result_batch, num_datapoints, kernel_time = batch_runner(predictor, batch, batch_size, pad_input_image_batch) + if result_batch is not None: + results += result_batch + # We expect a partial batch to only happens once at the end + assert not partial_batch + # Only measure timing on full batches + if num_datapoints == batch_size: + num_images += num_datapoints + num_batches += 1 + # We consistently exclude the last (512 - filtered) images + # Since batch sizes must be powers of two and less than + # or equal 512 this ensures consistent timing across varying + # batch sizes. 
+ if num_images <= 4488: + elapsed_time += kernel_time + else: + partial_batch = True + batch_idx += 1 + + avg_ms_per_img = None + if num_images > 0: + avg_ms_per_img = elapsed_time + avg_ms_per_img = avg_ms_per_img / num_images + + return results, avg_ms_per_img, num_batches, num_images + + +def identity_runner(fn, *args, **kwargs): + return fn(*args, **kwargs) + + +def profiler_runner(path, fn, *args, **kwargs): + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA], + record_shapes=True) as prof: + result = fn(*args, **kwargs) + prof.export_chrome_trace(path) + return result + + +def profile_top_runner(fn, *args, **kwargs): + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA], + record_shapes=True) as prof: + result = fn(*args, **kwargs) + if torch.cuda.is_available(): + print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + else: + print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1)) + return result + + +def memory_runner(path, fn, *args, **kwargs): + print("Start memory recording") + torch.cuda.synchronize() + torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=100000, trace_alloc_record_context=True) + result = fn(*args, **kwargs) + torch.cuda.synchronize() + snapshot = torch.cuda.memory._snapshot() + print("Finish memory recording") + import pickle + with open(path, 'wb') as f: + pickle.dump(snapshot, f) + # Use to convert pickle file into html + # python torch/cuda/_memory_viz.py trace_plot .pickle -o .html + return result + + +def run( + coco_root_dir, + coco_slice_name, + sam_checkpoint_base_path, + sam_model_type, + point_sampling_cache_dir, + mask_debug_out_dir, + batch_size=1, + print_header=False, + coco_category_names=None, + limit=None, + img_id=None, + use_half=None, + use_compile="False", + use_compile_decoder=False, + compress=None, + num_workers=0, + use_rel_pos=True, + pad_input_image_batch=True, + profile_path=None, + profile_top=False, + memory_path=None, + device="cuda" +): + from torch._inductor import config as inductorconfig + inductorconfig.triton.unique_kernel_names = True + inductorconfig.epilogue_fusion = True + inductorconfig.coordinate_descent_tuning = True + inductorconfig.coordinate_descent_check_all_directions = True + inductorconfig.force_fuse_int_mm_with_mul = True + inductorconfig.use_mixed_mm = True + from torch.sparse import SparseSemiStructuredTensor + SparseSemiStructuredTensor._FORCE_CUTLASS = False + + if use_half is not None: + if use_half == "float16": + use_half = torch.float16 + elif use_half == "bfloat16": + use_half = torch.bfloat16 + else: + raise ValueError("Expected one of float16 or bfloat for specified {use_half}") + + + # Batch size needs to be a multiple of two and at most 512. 
+ assert math.log2(batch_size).is_integer() + assert batch_size <= 512 + + # https://github.com/facebookresearch/segment-anything/tree/main#model-checkpoints + # largest to smallest: vit_h, vit_l, vit_b + model_type_to_checkpoint = { + 'vit_h': f'{sam_checkpoint_base_path}/sam_vit_h_4b8939.pth', + 'vit_l': f'{sam_checkpoint_base_path}/sam_vit_l_0b3195.pth', + 'vit_b': f'{sam_checkpoint_base_path}/sam_vit_b_01ec64.pth', + } + + from segment_anything_fast import sam_model_registry, SamPredictor + checkpoint_path = model_type_to_checkpoint[sam_model_type] + sam = sam_model_registry[sam_model_type](checkpoint=checkpoint_path).to(torch.device(device)) + predictor = SamPredictor(sam) + + from segment_anything_fast import tools + tools.apply_eval_dtype_predictor(predictor, use_half) + + for block in predictor.model.image_encoder.blocks: + block.attn.use_rel_pos = use_rel_pos + + if compress == "int8_dynamic_quant": + from torchao.quantization import quantize, int8_dynamic_activation_int8_weight + from torchao.utils import unwrap_tensor_subclass + predictor.model.image_encoder = quantize(predictor.model.image_encoder, int8_dynamic_activation_int8_weight()) + predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) + elif compress == "sparse_mlp_only": + def mlp_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'mlp' in name + from torchao.sparsity import apply_sparse_semi_structured + apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_only) + elif compress == "sparse": + from torchao.sparsity import apply_sparse_semi_structured + apply_sparse_semi_structured(predictor.model.image_encoder) + elif compress == "int8_dynamic_quant_sparse": + from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight + from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured + from torchao.quantization import quantize, int8_dynamic_activation_int8_weight + from torchao.utils import unwrap_tensor_subclass + + def attn_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'attn' in name + def mlp_lin1_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'lin1' in name + def mlp_lin2_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'lin2' in name + def mlp_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'mlp' in name + + apply_fake_sparsity(predictor.model.image_encoder, + filter_fn=mlp_only) + + predictor.model.image_encoder = quantize(predictor.model.image_encoder, + int8_dynamic_activation_int8_weight(), + attn_only) + predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) + + predictor.model.image_encoder = quantize(predictor.model.image_encoder, + Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float, + mlp_lin1_only) + apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_lin2_only) + else: + assert compress is None, f"Unsupported compress mode {compress}" + + + coco_img_ids_, cat_id_to_cat, catIds, coco = setup_coco_img_ids( + coco_root_dir, coco_slice_name, coco_category_names, img_id) + + coco_img_ids = [] + for imgId in coco_img_ids_: + img = coco.loadImgs(imgId)[0] + annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None) + anns = coco.loadAnns(annIds) + if len(anns) != 0: + coco_img_ids.append(imgId) + + build_batch = build_data(coco_img_ids, + coco, + catIds, + coco_root_dir, + coco_slice_name, + point_sampling_cache_dir, + predictor, + 
use_half, + pad_input_image_batch) + + limit = len(coco_img_ids) if limit is None else limit + batched_data_iter = torch.utils.data.DataLoader(list(range(limit)), + batch_size=batch_size, + collate_fn=build_batch, + num_workers=num_workers, + pin_memory=False) + runner = identity_runner + + if profile_path is not None: + import functools + runner = functools.partial(profiler_runner, profile_path) + + if profile_top: + runner = profile_top_runner + + if memory_path is not None: + assert use_compile != "max-autotune", f"Memory path does not support {use_compile}" + import functools + runner = functools.partial(memory_runner, memory_path) + + results, avg_ms_per_img, num_batches, num_images = runner(build_results, + batched_data_iter, + predictor, + mask_debug_out_dir, + batch_size, + use_compile, + use_compile_decoder, + pad_input_image_batch, + compress) + + results = [[r[0], r[1], r[2], r[3].item()] for r in results] + + img_s, batch_ms_batch_size = None, None + if avg_ms_per_img is not None: + img_s = 1000 / avg_ms_per_img + batch_ms_batch_size = (avg_ms_per_img * num_images) / num_batches / batch_size + + mIoU = calculate_miou(results, mask_debug_out_dir, True, cat_id_to_cat) + if torch.cuda.is_available(): + max_memory_allocated_bytes = torch.cuda.max_memory_allocated() + _, total_memory = torch.cuda.mem_get_info() + max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory)) + max_memory_allocated_bytes = max_memory_allocated_bytes >> 20 + else: + import psutil + total_memory = psutil.virtual_memory().total + max_memory_allocated_bytes = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / (total_memory >> 10))) + max_memory_allocated_bytes = max_memory_allocated_bytes >> 10 + + with open("results.csv", "a") as f: + if print_header: + header = ",".join(["device", "sam_model_type", "batch_size", "memory(MiB)", "memory(%)", "img_s(avg)", "batch_ms(avg)/batch_size", "mIoU", "use_compile", + "use_half", "compress", "use_compile_decoder", "use_rel_pos", "pad_input_image_batch", "num_workers", "num_batches", "num_images", "profile_path", "memory_path"]) + f.write(header+"\n") + vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile, + use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path])) + f.write(vals+"\n") + +if __name__ == '__main__': + fire.Fire(run) diff --git a/scripts/sam/flash_4_configs.p b/scripts/sam/flash_4_configs.p new file mode 100644 index 0000000000..4b6e234d0d Binary files /dev/null and b/scripts/sam/flash_4_configs.p differ diff --git a/scripts/sam/metrics.py b/scripts/sam/metrics.py new file mode 100644 index 0000000000..7197d1e434 --- /dev/null +++ b/scripts/sam/metrics.py @@ -0,0 +1,61 @@ +import torch +import pandas as pd + +def create_result_entry(anns, gt_masks_list, masks, scores, img_idx): + argmax_scores = torch.argmax(scores, dim=1) + inference_masks = masks.gather(1, argmax_scores.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand( + (masks.size(0), 1, masks.size(2), masks.size(3)))).squeeze(1) + + def _iou(mask1, mask2): + assert mask1.dim() == 3 + assert mask2.dim() == 3 + intersection = torch.logical_and(mask1, mask2) + union = torch.logical_or(mask1, mask2) + return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2))) + + top_score_ious = 
_iou(inference_masks, gt_masks_list) + + entry = [] + for idx in range(top_score_ious.size(0)): + entry.append( + [img_idx, anns[idx]['id'], anns[idx]['category_id'], top_score_ious[idx]]) + return entry + + +def calculate_miou(results, mask_debug_out_dir, silent, cat_id_to_cat): + df = pd.DataFrame(results, columns=['img_id', 'ann_id', 'cat_id', 'iou']) + df.to_csv(f'{mask_debug_out_dir}/df.csv') + df['supercategory'] = df['cat_id'].map( + lambda cat_id: cat_id_to_cat[cat_id]['supercategory']) + df['category'] = df['cat_id'].map( + lambda cat_id: cat_id_to_cat[cat_id]['name']) + + # TODO: cross reference the specifics of how we calculate mIoU with + # the SAM folks (should it be per dataset, per category, per image, etc) + # currently, just calculate them all + + # TODO: QOL save the summaries to file + + # per category + per_category = pd.pivot_table( + df, values='iou', index=['cat_id', 'supercategory', 'category'], + aggfunc=('mean', 'count')) + if not silent: + print('\nmIoU averaged per category') + print(per_category) + + # per super-category + per_supercategory = pd.pivot_table( + df, values='iou', index=['supercategory'], + aggfunc=('mean', 'count')) + if not silent: + print('\nmIoU averaged per supercategory') + print(per_supercategory) + + # per all selected masks + per_all_masks_agg = df['iou'].agg(['mean', 'count']) + if not silent: + print('\nmIoU averaged per all selected masks') + print(per_all_masks_agg) + + return df['iou'].agg(['mean', 'count'])['mean'] diff --git a/scripts/sam/results.csv b/scripts/sam/results.csv new file mode 100644 index 0000000000..01aad5c022 --- /dev/null +++ b/scripts/sam/results.csv @@ -0,0 +1,6 @@ +device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path +cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None +cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None +cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None +cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None +cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None diff --git a/scripts/sam/setup.sh b/scripts/sam/setup.sh new file mode 100644 index 0000000000..8c730ea9ff --- /dev/null +++ b/scripts/sam/setup.sh @@ -0,0 +1,22 @@ + +SETUP_HOME=$(pwd) + + +mkdir -p checkpoints +mkdir -p datasets + +mkdir -p tmp +mkdir -p tmp/sam_coco_mask_center_cache +mkdir -p tmp/sam_eval_masks_out + +wget -nc -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth +wget -nc -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth +wget -nc -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth + +mkdir -p datasets/coco2017 +wget -nc -P datasets/coco2017 http://images.cocodataset.org/zips/val2017.zip +wget -nc -P datasets/coco2017 http://images.cocodataset.org/annotations/annotations_trainval2017.zip + +cd datasets/coco2017 && 
unzip -n val2017.zip && cd $SETUP_HOME
+cd datasets/coco2017 && unzip -n annotations_trainval2017.zip && cd $SETUP_HOME
+
diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md
index bc0a61b202..239d52bb0e 100644
--- a/torchao/sparsity/README.md
+++ b/torchao/sparsity/README.md
@@ -20,29 +20,32 @@ More concretely, we hope to provide tutorials and APIs for both sparse kernels (
 ## Success Stories
-#### segment-anything
+#### segment-anything-fast
 We applied 2:4 sparsity to accelerate segment-anything, as part of [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast).
-The results mentioned in the README of the repo compose sparsity with a suite of other inference acceleration techniques.
-From our [benchmarking](https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_sam.py), we see a 1.1x speedup when running with `SEGMENT_ANYTHING_FAST_USE_FLASH_4` enabled.
-To reproduce these benchmarks you can run the following command:
+We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIoU)**.
-The inference acceleration of semi-structured sparsity depends on the matmul shapes, which is why we don't see additional speedups when applying to all linear layers (attn + mlp) of segment-anything.
-We find that accelerating the MLP linear layers provied the most speedups (`lin1`, `lin2`). To repoduce our benchmarks you can run the following command:
+Overall, we found that accelerating the MLP linear layers provided the most speedups (`lin1`, `lin2`), while mitigating accuracy loss.
-```
-python benchmarks/benchmark_sam.py
-```
+Applying sparsity to the attention linear layers led to a slower model, likely due to two reasons:
+- We cannot fuse into our semi-structured sparse matmul with torch.compile.
+- The speedups we observe for sparse matmul depend on the matmul shapes, and the attention matmuls are smaller than the MLP ones.
+
+We were also able to compose int8 dynamic quantization with 2:4 sparsity for further speedups.
+
+We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1, and 2:4 sparsity to mlp layer 2 yielded the best configuration.
+
+We ran the following benchmarks for SAM ViT-h on an NVIDIA A100-80GB, with batch_size=32, `bfloat16` dtype, and `torch.compile="max-autotune"`:
-The following benchmarks we run on an A100, with batch_size=32 and `bfloat16` dtype:
+| Model Type | Technique | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
+|------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
+| ViT-h | baseline (bfloat16, max-autotune) | 22.75 | 15172 | 0.5811 | | |
+| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.5822 | **1.09x** | **100.19%** |
+| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.5672 | **1.10x** | **97.61%** |
+| | 2:4 sparsity (attn + mlp) | 24.30 | 13429 | 0.5306 | **1.07x** | **91.31%** |
+| | int8 dynamic quant (attn)
int8 dynamic quant + 2:4 sparsity (mlp lin1)
2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.5668 | **1.16x** | **97.54%** | -| qkv | proj | lin1 | lin2 | time | memory | img/s | -| ---- | ---- | ---- | ---- | ---- | ------ | ----- | -| None | None | None | None | 1361.73 | 15.81 | 23.50 | -| None | None | sparse (cusparselt) | sparse (cusparselt) | 1245.15 | 15.46 | 25.70 | -| None | None | sparse (cutlass) | sparse (cutlass) | 1251.047651 | 15.41 | 25.59 | -| sparse (cusparselt) | sparse (cusparselt) | sparse (cusparselt) | sparse (cusparselt) | 1265.43 | 12.71 | 25.29| -| sparse (cutlass) | sparse (cutlass) | sparse (cutlass) | sparse (cutlass) | 1274.96 | 12.70 | 25.10 | +To reproduce our benchmarks please follow these [instructions](/scripts/sam/README.md). #### BERT diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py index 6d7a856ea4..2601f166a8 100644 --- a/torchao/sparsity/prototype/dynamic_quant_sparse.py +++ b/torchao/sparsity/prototype/dynamic_quant_sparse.py @@ -149,6 +149,17 @@ def sparse_quant_int8_cutlass_matmul( class Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight( Int8DynamicallyQuantizedLinearWeight ): + def dequantize(self, dtype=None): + # overload dequantize op for __repr__ + zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype) + int_data_expanded = torch._cslt_sparse_mm(self.int_data, torch.eye(self.shape[1], + dtype=self.int_data.dtype, + device=self.int_data.device)) + dq_t = dequantize_per_channel( + int_data_expanded, self.q_scales, zero_points, self.dtype if dtype is None else dtype + ).to(self.dtype) + + return dq_t if not self.transposed else dq_t.t() @staticmethod def _quantized_op(act_mat, w_qtensor, bias): @@ -158,7 +169,7 @@ def _quantized_op(act_mat, w_qtensor, bias): ) @classmethod - def from_float(cls, input_float, qmin=-8, qmax=7): + def from_float(cls, input_float, qmin=-128, qmax=127): assert input_float.is_cuda diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 90e35a4121..d8ec14a266 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -4,15 +4,16 @@ from torchao.quantization.quant_api import _is_linear # Sparsity helper functions -def apply_fake_sparsity(model): +def apply_fake_sparsity(model, **kwargs): """ This function simulates 2:4 sparsity on all linear layers in a model. It uses the torch.ao.pruning flow. """ + filter_fn = kwargs.pop("filter_fn", _is_linear) # torch.ao.pruning flow sparse_config = [] for name, mod in model.named_modules(): - if isinstance(mod, torch.nn.Linear): + if filter_fn(mod, name): sparse_config.append({"tensor_fqn": f"{name}.weight"}) sparsifier = WeightNormSparsifier( @@ -26,7 +27,7 @@ def apply_fake_sparsity(model): def apply_sparse_semi_structured(model, **kwargs): filter_fn = kwargs.pop("filter_fn", _is_linear) - apply_fake_sparsity(model) + apply_fake_sparsity(model, filter_fn=filter_fn) for name, mod in model.named_modules(): if filter_fn(mod, name): mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
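
For convenience, the sketch below condenses the best-performing configuration from the benchmarks above (int8 dynamic quant on the attention linears, int8 dynamic quant + 2:4 sparsity on `mlp.lin1`, and 2:4 sparsity on `mlp.lin2`). It is a minimal sketch assembled from the `int8_dynamic_quant_sparse` branch of `eval_combo.py` and the updated `apply_fake_sparsity`/`apply_sparse_semi_structured` APIs in this diff, not a drop-in replacement for the benchmark script: it assumes a CUDA device and the `vit_h` checkpoint downloaded by `setup.sh`, and it omits the COCO evaluation loop, the inductor config flags, and the padding/timing logic.

```python
import torch
from segment_anything_fast import sam_model_registry, SamPredictor, tools
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass
from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
from torchao.sparsity.prototype.dynamic_quant_sparse import (
    Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight,
)

# Load SAM ViT-h and run the image encoder in bfloat16, as eval_combo.py does.
sam = sam_model_registry["vit_h"](checkpoint="checkpoints/sam_vit_h_4b8939.pth").cuda()
predictor = SamPredictor(sam)
tools.apply_eval_dtype_predictor(predictor, torch.bfloat16)
encoder = predictor.model.image_encoder

# Name-based filters matching the SAM ViT block layout (attn.qkv / attn.proj, mlp.lin1 / mlp.lin2).
def attn_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "attn" in name

def mlp_lin1_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "lin1" in name

def mlp_lin2_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "lin2" in name

def mlp_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "mlp" in name

# Magnitude-prune the MLP weights to a 2:4 pattern before swapping in the
# quantized / sparse weight representations.
apply_fake_sparsity(encoder, filter_fn=mlp_only)

# int8 dynamic quantization for the attention linears.
encoder = quantize(encoder, int8_dynamic_activation_int8_weight(), attn_only)
encoder = unwrap_tensor_subclass(encoder)

# int8 dynamic quantization fused with 2:4 (cuSPARSELt) sparsity for mlp.lin1,
# plain 2:4 semi-structured sparsity for mlp.lin2.
encoder = quantize(encoder, Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float, mlp_lin1_only)
apply_sparse_semi_structured(encoder, filter_fn=mlp_lin2_only)

# Compile the encoder the same way the benchmark does.
predictor.model.image_encoder = torch.compile(encoder, mode="max-autotune")
```

Note that `apply_fake_sparsity` and `apply_sparse_semi_structured` accept the same `(module, name)` filter convention as `quantize`, which is exactly what the `sparse_api.py` change above enables: pruning is restricted to the layers selected by `filter_fn` instead of every `nn.Linear`.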