Fashion MNIST - VAE

Purpose

The purpose of this repository is to train different forms of Variational Autoencoders (VAEs) on the Fashion MNIST dataset to better understand the capabilities and inner workings of this modeling technique, specifically for generative purposes. A secondary goal is to become more familiar with Pyro, the Python package for probabilistic programming that was used to train the VAEs in this repository.
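
For orientation, here is a minimal sketch of the model/guide pattern Pyro uses for VAEs. The latent dimension, layer sizes, and Bernoulli likelihood are illustrative assumptions, not necessarily this repository's exact architecture; see src/fashion_mnist_vae for the actual implementation.

import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from torch import nn


class Encoder(nn.Module):
    def __init__(self, z_dim, hidden):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(784, hidden), nn.ReLU())
        self.loc = nn.Linear(hidden, z_dim)
        self.log_scale = nn.Linear(hidden, z_dim)

    def forward(self, x):
        h = self.body(x)
        return self.loc(h), self.log_scale(h).exp()


class VAE(nn.Module):
    def __init__(self, z_dim=10, hidden=400):
        super().__init__()
        self.z_dim = z_dim
        self.encoder = Encoder(z_dim, hidden)
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 784), nn.Sigmoid(),
        )

    def model(self, x):
        # Generative side p(x | z) p(z): standard-normal prior, Bernoulli likelihood.
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            z = pyro.sample(
                "z",
                dist.Normal(x.new_zeros(x.shape[0], self.z_dim),
                            x.new_ones(x.shape[0], self.z_dim)).to_event(1),
            )
            pyro.sample("obs", dist.Bernoulli(self.decoder(z)).to_event(1), obs=x)

    def guide(self, x):
        # Inference side q(z | x): the encoder amortizes the variational posterior.
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            loc, scale = self.encoder(x)
            pyro.sample("z", dist.Normal(loc, scale).to_event(1))


vae = VAE()
svi = SVI(vae.model, vae.guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
# loss = svi.step(batch)  # batch: flattened, binarized images, shape (batch, 784)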

Reference

The work in this repository is heavily inspired by the VAE tutorials found in the documentation of the Pyro project, although the data used here is the Fashion MNIST dataset and the specific implementation deviates from the Pyro tutorial.

For an introduction to VAEs with Pyro, please see here.

For a more general introduction to Pyro with an example implementation of a VAE for the MNIST Digits dataset, including exercises and solutions, please see here.

For a discussion of the choices made in this repository, please refer to this Medium Article.

Getting Started

Dependencies

To use the code in this repository, please clone it to your local or remote machine with:

git clone https://github.com/GiuliaCaglia/fashion-mnist-vae.git

All dependencies can be installed with

poetry install

if you use poetry for dependency management, or

pip3 install -r requirements.txt

otherwise.

When not using poetry, it is recommended to use another environment management system, such as pipenv, conda, or virtualenv.

Note that the requirements.txt file includes CUDA dependencies and that the required Python version is >=3.11. This repository might work with other Python versions, but it has not been tested with them.

Training Models

Once dependencies are installed, a VAE can be trained with:

python3 src/fashion_mnist_vae/scripts/train_vae.py -e $EPOCHS -d $DEVICE

Where:

  • EPOCHS denotes the number of epochs for which to train the model
  • DEVICE is one of cuda or cpu (default)

The loss plot, example images, and the model itself will be stored in the assets folder once training is concluded.
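
Once a model has been saved, new items can be generated by decoding draws from the prior. The snippet below is a sketch under assumptions: the artifact name assets/vae.pt is hypothetical (check the assets folder for what the script actually writes, and whether it saves the full module or a state dict), and the model is assumed to expose z_dim and decoder as in the sketch above.

import torch

vae = torch.load("assets/vae.pt", map_location="cpu")  # hypothetical file name
vae.eval()
with torch.no_grad():
    z = torch.randn(16, vae.z_dim)                 # 16 draws from the N(0, I) prior
    samples = vae.decoder(z).reshape(-1, 28, 28)   # decode to 28x28 images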

Results

The results are assessed in more detail in notebooks/Visualizations & Samples.ipynb.

Here, we see samples generated by the VAE:

[Image: samples generated by the VAE]

And here we see how two items morph into each other as we move through the latent space:

[Image: latent-space interpolation between two items]
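
The morphing effect can be reproduced by interpolating linearly between the latent codes of two encoded items and decoding each intermediate point. A minimal sketch, reusing the assumed encoder/decoder interface from above (x_a and x_b are flattened 28x28 images):

import torch

def interpolate(vae, x_a, x_b, steps=8):
    # Encode both items to their posterior means, walk the line between them,
    # and decode each intermediate latent code back to image space.
    with torch.no_grad():
        z_a, _ = vae.encoder(x_a.unsqueeze(0))
        z_b, _ = vae.encoder(x_b.unsqueeze(0))
        ts = torch.linspace(0.0, 1.0, steps).unsqueeze(1)
        z = (1 - ts) * z_a + ts * z_b
        return vae.decoder(z).reshape(steps, 28, 28)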
