This repository contains an implementation of the Gaussian Mixture Variational Autoencoder (GMVAE), based on the paper "A Note on Deep Variational Models for Unsupervised Clustering" by James Brofos, Rui Shu, and Curtis Langlotz, together with a modified version of the M2 model proposed by D. P. Kingma et al. in "Semi-Supervised Learning with Deep Generative Models."
The repository has the following structure:
```
.
├── configs
│   ├── config.yaml
│   └── model
│       └── gmvae_fc.yaml
├── loss.py
├── models.py
├── modules.py
├── README.md
├── test.ipynb
├── train.py
└── utils.py
```
- `configs/`: Contains the configuration files for the GMVAE model, including `config.yaml` for general settings and `model/gmvae_fc.yaml` for model-specific settings.
- `loss.py`: Implements the loss functions used in the GMVAE.
- `models.py`: Defines the GMVAE model architecture.
- `modules.py`: Contains custom modules used in the GMVAE.
- `train.py`: The main script for training the GMVAE model.
- `utils.py`: Contains utility functions used in the GMVAE implementation.
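As a concrete illustration of the kind of term computed in a VAE-style loss such as the one in `loss.py` — a minimal sketch for intuition, not the repository's actual code — here is the closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior, a standard component of the ELBO:

```python
import math

def kl_diag_gaussian_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) for a single sample.

    mu, log_var: per-dimension posterior means and log-variances.
    Closed form: 0.5 * sum(exp(log_var) + mu^2 - 1 - log_var).
    """
    return 0.5 * sum(
        math.exp(lv) + m * m - 1.0 - lv
        for m, lv in zip(mu, log_var)
    )

# A posterior identical to the prior has zero KL.
print(kl_diag_gaussian_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # → 0.0
```

In the GMVAE the prior is a mixture of Gaussians rather than a single standard normal, so the actual loss also involves the cluster-assignment posterior; this snippet only shows the basic Gaussian KL building block.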
To train and evaluate the GMVAE model, follow these steps:
- Configure the model settings in `config.yaml` and `gmvae_fc.yaml` as needed.
- Run the `train.py` script to train the GMVAE model.
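The split between `configs/config.yaml` and `configs/model/gmvae_fc.yaml` suggests a composed, Hydra-style configuration. A hypothetical sketch of what `config.yaml` might look like — all key names here are illustrative assumptions, not taken from the repository:

```yaml
# configs/config.yaml (hypothetical keys, for illustration only)
defaults:
  - model: gmvae_fc   # selects configs/model/gmvae_fc.yaml

batch_size: 128
lr: 1.0e-3
epochs: 50
```

Check the actual files in `configs/` for the real key names before editing.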
Make sure to install the required dependencies before running the code. You may use a virtual environment and install the dependencies from the provided `requirements.txt` file.
D. P. Kingma, D. J. Rezende, S. Mohamed, and M. Welling. Semi-Supervised Learning with Deep Generative Models. ArXiv e-prints, June 2014.