Merge #312

Merged: 13 commits, Dec 14, 2023
README.md: 22 changes (15 additions, 7 deletions)
@@ -1,8 +1,8 @@
<br>
<div align="center">
<img src="https://user-images.githubusercontent.com/18592211/232830417-0b21a874-516e-4420-8984-4de414a35085.png" width="400px"></img>
<h2></h2>
<h3>Towards Any Structural Pruning</h3>

<img src="assets/intro.png" width="50%">
</div>

@@ -11,7 +11,7 @@
<a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-1.8 %20%7C%201.12 %20%7C%202.0-673ab7.svg" alt="Tested PyTorch Versions"></a>
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-4caf50.svg" alt="License"></a>
<a href="https://pepy.tech/project/Torch-Pruning"><img src="https://static.pepy.tech/badge/Torch-Pruning?color=2196f3" alt="Downloads"></a>
<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.3.2-3f51b5.svg" alt="Latest Version"></a>
<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.3.3-3f51b5.svg" alt="Latest Version"></a>
<a href="https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
@@ -23,8 +23,8 @@

Torch-Pruning (TP) is a library for structural pruning with the following features:

-* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including [Large Language Models (LLMs)](https://github.com/horseee/LLM-Pruner), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [Vision Transformers](examples/transformers/), [Swin Transformers](examples/transformers#swin-transformers-from-hf-transformers), [BERT](examples/transformers#bert-from-hf-transformers), FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, DeepLab, etc. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys an algorithm called **[DepGraph](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)** to remove parameters physically. Check our [Paper List](https://github.com/VainF/Torch-Pruning/wiki/0.-Paper-List) for more details.
-* **[Examples](examples)**: Play around with off-the-shelf models from Huggingface Transformers, Timm, Torchvision, Yolo, etc.
+* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including [Large Language Models (LLMs)](https://github.com/horseee/LLM-Pruner), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [YOLOv7](examples/yolov7/), [YOLOv8](examples/yolov8/), [Vision Transformers](examples/transformers/), [Swin Transformers](examples/transformers#swin-transformers-from-hf-transformers), [BERT](examples/transformers#bert-from-hf-transformers), FasterRCNN, SSD, ResNe(X)t, ConvNeXt, DenseNet, RegNet, DeepLab, etc. Unlike [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html), which zeroizes parameters through masking, Torch-Pruning deploys the **[DepGraph](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)** algorithm to physically remove parameters; a short sketch contrasting the two approaches follows this list.
+* **[Examples](examples)**: Pruning off-the-shelf models from Timm, Huggingface Transformers, Torchvision, Yolo, etc.
* **[Benchmark](benchmarks)**: Reproduce our results from the DepGraph paper.
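
For illustration, a minimal sketch contrasting the two approaches on a single layer (the layer shape and pruning amounts are arbitrary choices for this example):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch_pruning as tp

# Masking with torch.nn.utils.prune: zeroed channels still occupy memory and FLOPs
layer = nn.Conv2d(3, 16, 3)
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)
print(layer.weight.shape)  # torch.Size([16, 3, 3, 3]) -- shape unchanged

# Physical removal with Torch-Pruning's functional API
layer = nn.Conv2d(3, 16, 3)
tp.prune_conv_out_channels(layer, idxs=list(range(8)))
print(layer.weight.shape)  # torch.Size([8, 3, 3, 3]) -- parameters actually removed
```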

For more technical details, please refer to our CVPR'23 paper:
@@ -52,7 +52,7 @@ For more technical details, please refer to our CVPR'23 paper:
Please do not hesitate to open an [issue](https://github.com/VainF/Torch-Pruning/issues) if you encounter any problems with the library or the paper.
Or join our Discord or WeChat group for a chat:
* Discord: [link](https://discord.gg/Z6r34MnE)
-* WeChat (Group size exceeded 300): [QR Code](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3)
+* WeChat Group (Group size exceeded 400): [QR Code](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3)

## Table of Contents
- [Installation](#installation)
@@ -77,7 +77,7 @@ Or join our Discord or WeChat group for a chat:

## Installation

-Torch-Pruning is compatible with both PyTorch 1.x and 2.x versions. However, it is highly recommended to use PyTorch 1.12.1 or higher.
+Torch-Pruning is compatible with both PyTorch 1.x and 2.x versions. However, it is highly recommended to use PyTorch 2.0.

```bash
pip install torch-pruning
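
# To pick up unreleased fixes (such as those in this PR) before the next
# PyPI release, pip can also install straight from the repository -- this is
# standard pip usage rather than a project-documented command, and it assumes
# the default branch contains the fix:
pip install git+https://github.com/VainF/Torch-Pruning.git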
```

@@ -93,7 +93,7 @@ Here we provide a quick start for Torch-Pruning.

### How It Works

-In structural pruning, a "Group" is defined as the minimal unit that can be removed within deep networks. These groups are composed of multiple layers that are interdependent and need to be pruned together in order to maintain the integrity of the resulting structures. However, deep networks often have complex dependencies among their layers, making structural pruning a challenging task. This work addresses this challenge by introducing an automated mechanism called "DepGraph." DepGraph allows for seamless parameter grouping and facilitates pruning in various types of deep networks.
+In structural pruning, a "Group" is defined as the minimal removable unit within deep networks. Most groups are composed of multiple layers that are interdependent and need to be pruned together in order to maintain the integrity of the resulting structures. However, deep networks often have complex dependencies among their layers, making structural pruning a challenging task. This work addresses this challenge by introducing an automated mechanism called "DepGraph," which allows for seamless parameter grouping and facilitates pruning in various types of deep networks.

<div align="center">
<img src="assets/dep.png" width="100%">
@@ -419,6 +419,14 @@ Please refer to [benchmarks](benchmarks) for more details.
> *Gongfan Fang, Xinyin Ma, Xinchao Wang*
> NeurIPS 2023

+> **DeepCache: Accelerating Diffusion Models for Free** [[Project]](https://github.com/horseee/DeepCache) [[Arxiv]](https://arxiv.org/abs/2312.00858)
+> *Xinyin Ma, Gongfan Fang, Xinchao Wang*
+> Preprint 2023
+
+> **0.1% Data Makes Segment Anything Slim** [[Project]](https://github.com/czg1225/SlimSAM) [[Arxiv]](https://arxiv.org/abs/2312.05284)
+> *Zigeng Chen, Gongfan Fang, Xinyin Ma, Xinchao Wang*
+> Preprint 2023


## Citation
```
@inproceedings{fang2023depgraph,
  title={DepGraph: Towards Any Structural Pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2023}
}
```
setup.py: 2 changes (1 addition, 1 deletion)
@@ -5,7 +5,7 @@

setuptools.setup(
name="torch-pruning",
version="v1.3.2",
version="v1.3.3",
author="Gongfan Fang",
author_email="gongfan@u.nus.edu",
description="Towards Any Structural Pruning",
torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py: 4 changes (2 additions, 2 deletions)
@@ -101,8 +101,8 @@ def regularize(self, model, reg=None, bias=False):
m.weight.grad.data.add_(reg*torch.sign(m.weight.data))
else:
for group in self._groups:
-group_l2norm_sq = self._l2_imp(group)
-if group_l2norm_sq is None:
+group_l2norm_sq = self._l2_imp(group) + 1e-9  # epsilon avoids inf from 1/sqrt(0) below
+if group_l2norm_sq is None or torch.any(torch.isnan(group_l2norm_sq)):  # skip groups with nan importance
continue
gamma = reg * (1 / group_l2norm_sq.sqrt())
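
For context, a standalone sketch (illustrative values, not part of the patch) of why the epsilon matters here: when a group's squared L2 norm is exactly zero, `1 / sqrt(0)` yields `inf`, and the resulting gradient update would corrupt the weights.

```python
import torch

reg = 1e-4
group_l2norm_sq = torch.tensor([0.0, 4.0])  # one dead group, one healthy group

gamma_unsafe = reg * (1.0 / group_l2norm_sq.sqrt())         # [inf, 5.0e-05]
gamma_safe = reg * (1.0 / (group_l2norm_sq + 1e-9).sqrt())  # finite everywhere

print(torch.isfinite(gamma_unsafe).all())  # tensor(False)
print(torch.isfinite(gamma_safe).all())    # tensor(True)
```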

torch_pruning/pruner/algorithms/group_norm_pruner.py: 2 changes (2 additions, 0 deletions)
@@ -92,6 +92,8 @@ def regularize(self, model, alpha=2**4, bias=False):
for i, group in enumerate(self._groups):
ch_groups = self._get_channel_groups(group)
imp = self.estimate_importance(group).sqrt()
+if torch.any(torch.isnan(imp)): # avoid nan
+    continue
gamma = alpha**((imp.max() - imp) / (imp.max() - imp.min()))

# Update Gradient
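
Similarly, a standalone sketch (illustrative values, not part of the patch) of how `nan` can reach this code path: taking `sqrt` of a negative importance score yields `nan`, which would otherwise propagate silently into `gamma` and the gradient update.

```python
import torch

alpha = 2**4
imp = torch.tensor([4.0, -1.0, 9.0]).sqrt()  # sqrt(-1.0) -> nan

if torch.any(torch.isnan(imp)):  # mirrors the guard added above
    print("skip this group")
else:
    gamma = alpha ** ((imp.max() - imp) / (imp.max() - imp.min()))
```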
torch_pruning/pruner/algorithms/growing_reg_pruner.py: 2 changes (1 addition, 1 deletion)
@@ -110,7 +110,7 @@ def step(self, interactive=False):
def regularize(self, model, bias=False):
for i, group in enumerate(self._groups):
group_l2norm_sq = self.estimate_importance(group)
-if group_l2norm_sq is None:
+if group_l2norm_sq is None or torch.any(torch.isnan(group_l2norm_sq)): # avoid nan
continue
gamma = self.group_reg[group]
for k, (dep, idxs) in enumerate(group):