A rectified flow implementation in Flax
This repository implements rectified flows, as proposed in "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow", using JAX and Flax.
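For orientation, the training objective from the paper can be sketched in a few lines of plain JAX. This is a minimal illustration and not this repository's API: the names `init_params`, `velocity` and `rectified_flow_loss`, the small MLP, and the toy Gaussian data are all assumptions made for the example. The idea is to draw a pair (x0, x1), interpolate linearly between them at a random time t, and regress a velocity network onto the straight-line direction x1 - x0.

```python
import jax
import jax.numpy as jnp


def init_params(key, dim=2, hidden=32):
    # Hypothetical two-layer MLP parameters for the velocity field v(x, t).
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (dim + 1, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, dim)) * 0.1,
        "b2": jnp.zeros(dim),
    }


def velocity(params, x, t):
    # x has shape (batch, dim), t has shape (batch, 1).
    inputs = jnp.concatenate([x, t], axis=-1)
    h = jnp.tanh(inputs @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def rectified_flow_loss(params, key, x0, x1):
    # Sample one time point per pair, uniformly on [0, 1).
    t = jax.random.uniform(key, (x0.shape[0], 1))
    # Linear interpolation between source and target samples.
    xt = t * x1 + (1.0 - t) * x0
    # The regression target is the constant straight-line velocity.
    target = x1 - x0
    pred = velocity(params, xt, t)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))


key = jax.random.PRNGKey(0)
k_init, k_x0, k_x1, k_t = jax.random.split(key, 4)
params = init_params(k_init)
# Toy stand-ins for the source and target data sets (assumed for illustration).
x0 = jax.random.normal(k_x0, (128, 2))
x1 = jax.random.normal(k_x1, (128, 2)) + 3.0
loss, grads = jax.value_and_grad(rectified_flow_loss)(params, k_t, x0, x1)
```

The gradients in `grads` have the same structure as `params`, so they can be passed to any optimizer that operates on such dictionaries.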
The experiments folder contains a use case where samples from the "Two Moons" data set are transported
to the "Eight Gaussians" data set. To train a model and make visualizations, call:
cd experiments/eight_gaussians_two_moons
python main.py
Shown below are samples from the two moons data set (black) that have been transported to the eight Gaussians data set (blue). Each figure shows the transport map after x training iterations.
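Transporting samples with a trained flow amounts to integrating the learned velocity field from t = 0 to t = 1. The sketch below uses a fixed-step Euler scheme and is only an illustration: `euler_transport` and the stand-in `constant_velocity` are hypothetical names, not functions from this repository, and the experiment script may use a different solver.

```python
import jax.numpy as jnp


def euler_transport(velocity_fn, x0, n_steps=100):
    # Integrate dx/dt = v(x, t) from t = 0 to t = 1 with fixed-step Euler.
    dt = 1.0 / n_steps
    x = x0
    for k in range(n_steps):
        # Broadcast the current time to one value per sample.
        t = jnp.full((x.shape[0], 1), k * dt)
        x = x + dt * velocity_fn(x, t)
    return x


def constant_velocity(x, t):
    # Stand-in for a trained network: every point moves by (2, -1) over unit time.
    return jnp.ones_like(x) * jnp.array([2.0, -1.0])


x_start = jnp.zeros((4, 2))
x_end = euler_transport(constant_velocity, x_start, n_steps=10)
```

With a constant velocity field the Euler scheme is exact up to floating point error, so every row of `x_end` ends up close to (2, -1). A trained velocity network with the same `(x, t)` signature could be passed in place of `constant_velocity`.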
To install the latest GitHub release, run the following on the command line:
pip install git+https://github.com/dirmeier/rflow@<RELEASE>
Simon Dirmeier sfyrbnd @ pm me