Implementation of Distributional Deep Equilibrium Models.


j-geuter/DDEQs

Distributional Deep Equilibrium Models

Welcome to the DDEQ repository! This is the official repository for the paper "DDEQs: Distributional Deep Equilibrium Models through Wasserstein Gradient Flows" (AISTATS 2025).

Installation

To install the necessary dependencies, run:

pip install -r requirements.txt

Usage

To use this project, first download the two datasets: MNIST Point Cloud and ModelNet40. Save the MNIST Point Cloud files in the MNISTPointCloud folder. Save ModelNet40 anywhere on disk and set DATA_PATH in src/modelnet.py to that location. You can then build the ModelNet40 dataset with the load_modelnet function in src/modelnet.py. Because building it takes some time, it is a good idea to save the built dataset once and reload it in later runs with the load_modelnet_saved function.
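The build-once, reload-later workflow described above can be sketched in Python. This is a minimal sketch, not the repository's code: load_or_build is a hypothetical helper, it assumes the built dataset is picklable, and the cache path is a placeholder. The repository's own save format and its load_modelnet_saved function may differ, so prefer those when they fit.

```python
import os
import pickle
from typing import Any, Callable


def load_or_build(cache_path: str, build_fn: Callable[[], Any]) -> Any:
    """Return the cached dataset if it exists, otherwise build it once and cache it.

    Hypothetical helper: build_fn stands in for an expensive dataset
    construction step, such as a call to load_modelnet with whatever
    arguments the repository expects.
    """
    if os.path.exists(cache_path):
        # Fast path: reuse the dataset saved by a previous run.
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    # Slow path: build the dataset once, then write it to disk.
    data = build_fn()
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    with open(cache_path, "wb") as f:
        pickle.dump(data, f)
    return data
```

For example, wrapping the expensive build in a zero-argument function, load_or_build("cache/modelnet40.pkl", build_fn) pays the construction cost only on the first run; later calls read the pickle from disk. Check src/modelnet.py for the actual arguments of load_modelnet before wiring it in.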

Once the datasets are set up, simply run the train_torchdeq.py file.

Citation

If you find this implementation useful, please consider citing our paper:

@inproceedings{
geuter2025ddeqs,
title={{DDEQ}s: Distributional Deep Equilibrium Models through Wasserstein Gradient Flows},
author={Jonathan Geuter and Cl{\'e}ment Bonet and Anna Korba and David Alvarez-Melis},
booktitle={The 28th International Conference on Artificial Intelligence and Statistics},
year={2025},
url={https://openreview.net/forum?id=rFfNuzzXXW}
}

License

This project is licensed under the GNU License.
