Commit 14b392c: Sparse Flow (Node) paper and v2.2.0 release
Lucas Liebenwein committed Nov 16, 2022 (1 parent: 2e59069)
Showing 211 changed files with 16,859 additions and 3,126 deletions.
Two files renamed without changes.
New file (68 lines): README for the Sparse Flows (Node) paper.
# Sparse flows: Pruning continuous-depth models

[Lucas Liebenwein*](https://people.csail.mit.edu/lucasl/),
[Ramin Hasani*](http://www.raminhasani.com),
[Alexander Amini](https://www.mit.edu/~amini/),
[Daniela Rus](http://danielarus.csail.mit.edu/)

***Equal contribution**

<p align="center">
<img align="center" src="../../misc/imgs/node_overview.png" width="80%">
</p>
<!-- <br clear="left"/> -->
Continuous deep learning architectures enable learning of flexible
probabilistic models for predictive modeling as neural ordinary differential
equations (ODEs), and for generative modeling as continuous normalizing flows.
In this work, we design a framework to decipher the internal dynamics of these
continuous-depth models by pruning their network architectures. Our empirical
results suggest that pruning improves generalization for neural ODEs in
generative modeling. We empirically show that the improvement comes from
pruning helping to avoid mode collapse and flattening the loss surface.
Moreover, pruning finds efficient neural ODE representations with up to 98%
fewer parameters than the original network, without loss of accuracy. We hope
our results will invigorate further research into the performance-size
trade-offs of modern continuous-depth models.
## Setup
Check out the main [README.md](../../README.md) and the respective packages for
more information on the code base.

## Overview
### Run compression experiments
The experiment configurations are located [here](./param). To reproduce the
experiments for a specific configuration, run:
```bash
python -m experiment.main param/toy/ffjord/spirals/vanilla_l4_h64.yaml
```

The pruning experiments run fully automatically and store all of their
results; a sketch for running several configurations in sequence follows
below.
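As a convenience, one might wrap the command above in a small driver script. This is a hypothetical sketch, not part of the repository: only the spirals config shown above comes from this README, and any further config paths are placeholders following the same `param/` layout.

```python
# Hypothetical driver: run several pruning experiment configs in sequence.
# Each invocation is equivalent to `python -m experiment.main <config>`.
import subprocess
import sys

CONFIGS = [
    "param/toy/ffjord/spirals/vanilla_l4_h64.yaml",  # from the README above
    # ... add further .yaml files from ./param here (placeholders)
]

for cfg in CONFIGS:
    print(f"Running pruning experiment for {cfg} ...")
    result = subprocess.run([sys.executable, "-m", "experiment.main", cfg])
    if result.returncode != 0:
        sys.exit(f"Experiment failed for {cfg} (exit code {result.returncode})")
```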
### Experimental evaluations

The [script](./script) folder contains the evaluation and plotting scripts
used to evaluate and analyze the various experiments. Please take a look at
each of them to understand how to load and analyze the pruning experiments.

Each plot and experiment presented in the paper can be reproduced this way.
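To get a first look at what an experiment produced before diving into the plotting scripts, one could simply list the stored artifacts. This is a hypothetical sketch: the results root matches `results` in the directories.yaml file shown further below, but the artifact file types are assumptions, not the repository's documented output format.

```python
# Hypothetical: enumerate stored experiment artifacts for manual inspection.
# RESULTS_ROOT mirrors `results` in directories.yaml; suffixes are assumed.
from pathlib import Path

RESULTS_ROOT = Path("./data/node/results")

if RESULTS_ROOT.exists():
    for path in sorted(RESULTS_ROOT.rglob("*")):
        if path.suffix in {".pt", ".npz", ".yaml", ".pdf"}:  # assumed types
            print(path.relative_to(RESULTS_ROOT))
else:
    print(f"No results yet at {RESULTS_ROOT}")
```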
## Citation
Please cite the following paper when using our work.

### Paper link
[Sparse flows: Pruning continuous-depth models](https://proceedings.neurips.cc/paper/2021/hash/bf1b2f4b901c21a1d8645018ea9aeb05-Abstract.html)
### Bibtex
```
@article{liebenwein2021sparse,
  title={Sparse flows: Pruning continuous-depth models},
  author={Liebenwein, Lucas and Hasani, Ramin and Amini, Alexander and Rus, Daniela},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  pages={22628--22642},
  year={2021}
}
```
New file (68 lines): FFJORD multiscale CIFAR-10 experiment configuration.
```yaml
network:
  name: "ffjord_multiscale_cifar"
  dataset: "CIFAR10"
  outputSize: 10

training:
  transformsTrain:
    - type: RandomHorizontalFlip
      kwargs: {}
  transformsTest: []
  transformsFinal:
    - type: Resize
      kwargs: { size: 32 }
    - type: ToTensor
      kwargs: {}
    - type: RandomNoise
      kwargs: { "normalization": 255.0 }

  loss: "NLLBitsLoss"
  lossKwargs: {}

  metricsTest:
    - type: NLLBits
      kwargs: {}
    - type: Dummy
      kwargs: {}

  batchSize: 200  # do not change: this batch size is hard-coded elsewhere

  optimizer: "Adam"
  optimizerKwargs:
    lr: 1.0e-3
    weight_decay: 0.0

  numEpochs: 50
  earlyStopEpoch: 0
  enableAMP: False

  lrSchedulers:
    - type: MultiStepLR
      stepKwargs: { milestones: [45] }
      kwargs: { gamma: 0.1 }

file: "paper/node/param/directories.yaml"

retraining:
  startEpoch: 0

experiments:
  methods:
    - "ThresNet"
    - "FilterThresNet"
  mode: "cascade"

  numRepetitions: 1
  numNets: 1

  plotting:
    minVal: 0.02
    maxVal: 0.85

  spacing:
    - type: "geometric"
      numIntervals: 12
      maxVal: 0.80
      minVal: 0.05

  retrainIterations: -1
```
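The `spacing` entry above suggests a geometric schedule of 12 sparsity targets between 0.05 and 0.80. As a rough illustration (an assumption about the semantics, not the code base's actual implementation), such a schedule could be generated like this:

```python
# Illustrative only: a geometric schedule over [minVal, maxVal] with
# numIntervals points, mirroring the `spacing` block above. Whether the
# values act as prune ratios or keep ratios is an assumption here.
import numpy as np

min_val, max_val, num_intervals = 0.05, 0.80, 12
schedule = np.geomspace(min_val, max_val, num=num_intervals)
print(np.round(schedule, 3))
# -> [0.05  0.064 0.083 0.107 0.137 0.176 0.227 0.292 0.376 0.483 0.622 0.8]
```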
New file (66 lines): FFJORD multiscale MNIST experiment configuration.
```yaml
network:
  name: "ffjord_multiscale_mnist"
  dataset: "MNIST"
  outputSize: 10

training:
  transformsTrain: []
  transformsTest: []
  transformsFinal:
    - type: Resize
      kwargs: { size: 28 }
    - type: ToTensor
      kwargs: {}
    - type: RandomNoise
      kwargs: { "normalization": 255.0 }

  loss: "NLLBitsLoss"
  lossKwargs: {}

  metricsTest:
    - type: NLLBits
      kwargs: {}
    - type: Dummy
      kwargs: {}

  batchSize: 200  # do not change: this batch size is hard-coded elsewhere

  optimizer: "Adam"
  optimizerKwargs:
    lr: 1.0e-3
    weight_decay: 0.0

  numEpochs: 50
  earlyStopEpoch: 0
  enableAMP: False

  lrSchedulers:
    - type: MultiStepLR
      stepKwargs: { milestones: [45] }
      kwargs: { gamma: 0.1 }

file: "paper/node/param/directories.yaml"

retraining:
  startEpoch: 0

experiments:
  methods:
    - "ThresNet"
    - "FilterThresNet"
  mode: "cascade"

  numRepetitions: 1
  numNets: 1

  plotting:
    minVal: 0.02
    maxVal: 0.85

  spacing:
    - type: "geometric"
      numIntervals: 12
      maxVal: 0.80
      minVal: 0.05

  retrainIterations: -1
```
New file (6 lines): directories configuration (referenced by the configs above as paper/node/param/directories.yaml).
```yaml
# directories, relative to the location from which main.py is called
directories:
  results: "./data/node/results"
  trained_networks: null
  training_data: "./data/training"
  local_data: "./local"
```
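Since all paths in this file are relative, a small pre-flight snippet could create them before launching an experiment. This is a hypothetical convenience, not something the code base necessarily requires; it assumes PyYAML is installed.

```python
# Hypothetical pre-flight: create the relative directories named in the
# directories file before launching main.py. Requires PyYAML.
from pathlib import Path

import yaml

# Path as referenced by the experiment configs above.
with open("paper/node/param/directories.yaml") as f:
    dirs = yaml.safe_load(f)["directories"]

for name, rel_path in dirs.items():
    if rel_path is not None:  # trained_networks is null above
        Path(rel_path).mkdir(parents=True, exist_ok=True)
        print(f"{name}: {rel_path}")
```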