This is the training code for a two-stage autoregressive video model. The code for training latents works; the second stage is a work in progress.
Local run:
python -m monkfish.main.main config.json local [args...]
Distributed run (start a Ray head node, launch training, then stop Ray when done):
ray start --head --num-cpus=1 --port=6379
PROJECT_SOURCE=path/to/monkfish python -m monkfish.main.main config.json distributed [args...]
ray stop
Parameter scaling:
- A Spectral Condition for Feature Learning
- Scaling Exponents Across Parameterizations and Optimizers (NTK init with global LR is used for most experiments)
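
As a rough illustration of what "NTK init with global LR" can mean, the sketch below assumes the standard NTK parameterization (unit-variance weights with a 1/sqrt(fan_in) multiplier applied in the forward pass) and plain SGD with a single learning rate shared by all layers. The helper names `init_ntk_dense`, `apply_ntk_dense`, and `sgd_step` are hypothetical and are not taken from the monkfish code; the layer sizes and optimizer are likewise assumptions.

```python
import jax
import jax.numpy as jnp


def init_ntk_dense(key, fan_in, fan_out):
    # NTK parameterization (assumed): weights are drawn with unit variance,
    # independent of the layer width.
    return jax.random.normal(key, (fan_in, fan_out))


def apply_ntk_dense(w, x):
    # The 1/sqrt(fan_in) factor sits in the forward pass rather than in the
    # initializer, so pre-activations stay O(1) as the width grows.
    fan_in = w.shape[0]
    return (x @ w) / jnp.sqrt(fan_in)


def sgd_step(params, grads, lr):
    # One global learning rate for every layer (no per-layer scaling).
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def loss_fn(params, x):
    # Toy two-layer network with a dummy squared-output loss.
    h = jax.nn.relu(apply_ntk_dense(params[0], x))
    y = apply_ntk_dense(params[1], h)
    return jnp.mean(y ** 2)


key = jax.random.PRNGKey(0)
k1, k2, kx = jax.random.split(key, 3)
params = [init_ntk_dense(k1, 512, 512), init_ntk_dense(k2, 512, 8)]
x = jax.random.normal(kx, (32, 512))

grads = jax.grad(loss_fn)(params, x)
new_params = sgd_step(params, grads, lr=0.1)
```

Because the width-dependent scaling is carried by the forward-pass multiplier, the same `lr` value is passed to every parameter in this sketch; other parameterizations discussed in the papers above would instead fold the scaling into the initializer or into per-layer learning rates.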
JAX sharding:
Data loader design: