This repository contains a PyTorch implementation of the Mixed-curvature Variational Autoencoder, or M-VAE, as presented in [1]. The arXiv version of the paper is available at https://arxiv.org/abs/1911.08411.
Install Python 3.7+. To install all dependencies, make sure conda is installed, then run:
make conda
conda activate pt
make download_data
chkpt/ - Checkpoints for trained models.
data/ - Data folder. Contains a script for downloading the datasets, as well as the downloaded data.
mathematica/ - Mathematica scripts (various formula derivations, etc.).
mt/ - Source folder ("mt" stands for Master Thesis).
    data/ - Data loading, preprocessing, batching, and pre-trained embeddings.
    examples/ - Contains the main executable file, which reads flags and runs the corresponding training and/or evaluation.
    mvae/ - Model directory. Note that the models make heavy use of inheritance!
    test_data/ - Data used for testing.
    visualization/ - Utilities for visualizing latent spaces or training statistics.

plots/ - Folder for generated plots.
scripts/ - Scripts to run experiments and plot the results.
tests/ - (A few) unit tests.
Makefile - Defines "aliases" for various tasks.
README.md - This manual.
LICENSE - Apache License 2.0.
environment.yml - Required Python packages.
THIRD_PARTY.md - List of third-party software used in this thesis.
To run training and inference, activate the created conda environment and run the examples:
conda activate pt
# MNIST:
python -m mt.examples.run --dataset="mnist" --model="h2,s2,e2" --fixed_curvature=False
# CIFAR:
python -m mt.examples.run --dataset="cifar" --model="h2,s2,e2" --fixed_curvature=False --h_dim=8192 --architecture="conv"
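The `--model` flag describes the latent space as a product of component spaces. The sketch below shows one way a spec such as "h2,s2,e2" can be read, assuming (from the M-VAE paper's notation, not from the repository's code) that `h`, `s`, and `e` denote hyperbolic, spherical, and Euclidean components and that the trailing integer is each component's dimension. `parse_model_spec`, `Component`, and `total_dim` are hypothetical helpers, not functions from `mt/`; see `mt/examples/run.py` for how the flag is actually handled.

```python
from typing import List, NamedTuple

# Assumed mapping from component prefix to the sign of its curvature:
# "h" = hyperbolic (negative), "s" = spherical (positive), "e" = Euclidean (zero).
CURVATURE_SIGN = {"h": -1, "s": 1, "e": 0}


class Component(NamedTuple):
    kind: str            # "h", "s", or "e"
    dim: int             # intrinsic dimension of this component manifold
    curvature_sign: int  # -1, +1, or 0


def parse_model_spec(spec: str) -> List[Component]:
    """Split a hypothetical spec like "h2,s2,e2" into product-space components."""
    components = []
    for token in spec.split(","):
        token = token.strip()
        kind, dim_str = token[:1], token[1:]
        # Reject unknown prefixes and missing or non-numeric dimensions.
        if kind not in CURVATURE_SIGN or not dim_str.isdigit():
            raise ValueError(f"Unrecognized component: {token!r}")
        components.append(Component(kind, int(dim_str), CURVATURE_SIGN[kind]))
    return components


def total_dim(spec: str) -> int:
    """Sum of the intrinsic dimensions of all components."""
    return sum(c.dim for c in parse_model_spec(spec))


if __name__ == "__main__":
    for component in parse_model_spec("h2,s2,e2"):
        print(component)
    print("total latent dimension:", total_dim("h2,s2,e2"))
```

Under these assumptions, "h2,s2,e2" is a 6-dimensional latent space built from one 2-dimensional component of each curvature sign.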
Take a look at mt/examples/run.py for a list of command line arguments.
For an evaluation run, see mt/examples/eval.py.
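When running several configurations (for example, with fixed versus learnable curvature), it can help to build the command lines programmatically. The following is a minimal sketch that only uses the flags shown in the examples above and the `--name=value` form they use; `build_run_command` is a hypothetical helper, not part of this repository (the scripts/ folder contains the actual experiment scripts).

```python
import shlex
from typing import Dict, List, Union

FlagValue = Union[str, int, bool]


def build_run_command(flags: Dict[str, FlagValue]) -> List[str]:
    """Build an argv list for `python -m mt.examples.run` from a flag dict."""
    argv = ["python", "-m", "mt.examples.run"]
    for name, value in flags.items():
        # str(False) == "False", matching the --fixed_curvature=False style.
        argv.append(f"--{name}={value}")
    return argv


if __name__ == "__main__":
    # Print one shell-quoted command per curvature setting.
    for fixed in (True, False):
        cmd = build_run_command(
            {"dataset": "mnist", "model": "h2,s2,e2", "fixed_curvature": fixed}
        )
        # shlex.quote is used instead of shlex.join to stay compatible with Python 3.7.
        print(" ".join(shlex.quote(arg) for arg in cmd))
```

Passing the resulting list to `subprocess.run` (from the repository root, inside the activated `pt` environment) avoids shell-quoting issues with values such as "h2,s2,e2".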
Please cite [1] in your work when using this repository in your experiments.
make clean # format source code
make check # check for formatting and code errors
make test # run tests
For questions and comments, feel free to contact Ondrej Skopek.
This project is licensed under the Apache License 2.0 (ASL 2.0).
[1] Ondrej Skopek, Octavian-Eugen Ganea, Gary Bécigneul. Mixed-curvature Variational Autoencoders. International Conference on Learning Representations (ICLR) 2020. URL https://openreview.net/forum?id=S1g6xeSKDS
BibTeX format:
@inproceedings{skopek2020mixedcurvature,
title={Mixed-curvature Variational Autoencoders},
author={Ondrej Skopek and Octavian-Eugen Ganea and Gary B{\'e}cigneul},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=S1g6xeSKDS}
}