Commit d4ca28a (1 parent: 07de467)
Co-authored-by: Marc Rigter <t-marcrigter@microsoft.com>
Showing 178 changed files with 34,192 additions and 1 deletion.
# Adapting Video Diffusion Models to World Models
### Marc Rigter, Tarun Gupta, Agrin Hilmkil and Chao Ma

[![Project Page](https://img.shields.io/badge/Project%20Page-green)](https://sites.google.com/view/avid-world-model-adapters/home)

<p float="left">
<img src="images/rt1_1.gif" width="142" />
<img src="images/rt1_2.gif" width="142" />
<img src="images/rt1_3.gif" width="142" />
<img src="images/rt1_4.gif" width="142" />
<img src="images/rt1_5.gif" width="142" />
</p>
<p float="left">
<img src="images/coinrun1.gif" width="100" />
<img src="images/coinrun2.gif" width="100" />
<img src="images/coinrun3.gif" width="100" />
<img src="images/coinrun4.gif" width="100" />
<img src="images/coinrun5.gif" width="100" />
<img src="images/coinrun6.gif" width="100" />
<img src="images/coinrun7.gif" width="100" />
</p>

Official code to reproduce the experiments for ["Adapting Video Diffusion Models to World Models"](https://sites.google.com/view/avid-world-model-adapters/home) (AVID), which adapts pretrained video diffusion models into action-conditioned world models.

AVID is implemented for both pixel-space and latent-space diffusion. For instructions on using each codebase, see [pixel_diffusion/README.md](pixel_diffusion/README.md) and [latent_diffusion/README.md](latent_diffusion/README.md). Results are logged to Weights and Biases.

### Acknowledgements

This project uses code from [video-diffusion-pytorch](https://github.com/lucidrains/video-diffusion-pytorch), [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), and [octo](https://github.com/octo-models/octo).
# Adapting Video Diffusion Models to World Models (Latent Space)

Implements the latent diffusion experiments for RT1.

## Installing dependencies
We use [Poetry](https://python-poetry.org/) to manage the project dependencies, which are specified in the [pyproject.toml](pyproject.toml) file. To install Poetry, run:

```console
curl -sSL https://install.python-poetry.org | python3 -
```
To install the environment, run `poetry install` in avid/latent_diffusion.

Loading the RT1 data requires TensorFlow Datasets. To avoid the dataloader consuming unnecessarily large amounts of CPU memory, we recommend using tcmalloc (see [here](https://github.com/tensorflow/tensorflow/issues/44176) for details). It can be installed with:
```
sudo apt update
sudo apt install libtcmalloc-minimal4
export LD_PRELOAD="/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4"
```

## Downloading the DynamiCrafter checkpoint
We use the 512 x 320 resolution version of [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter) as the pretrained base model. The checkpoint is available on HuggingFace and can be downloaded with:
```
wget https://huggingface.co/Doubiiu/DynamiCrafter_512/resolve/main/model.ckpt
```
The AVID codebase expects the checkpoint to be located at `/host_home/avid/dynamicrafter_512/model.ckpt`. You will need to update the configs if the checkpoint is stored elsewhere.
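If the checkpoint lives somewhere else, every eval config that references the default path needs updating. A minimal sketch of a helper that does this with a plain text replacement (this helper is illustrative and not part of the AVID codebase; it assumes the configs store the path verbatim):

```python
from pathlib import Path

# Default checkpoint location expected by the AVID configs.
DEFAULT_CKPT = "/host_home/avid/dynamicrafter_512/model.ckpt"


def retarget_checkpoint(config_text: str, new_path: str) -> str:
    """Return the config text with the default checkpoint path replaced."""
    return config_text.replace(DEFAULT_CKPT, new_path)


def retarget_config_file(path: Path, new_path: str) -> None:
    """Rewrite one config file in place to point at `new_path`."""
    path.write_text(retarget_checkpoint(path.read_text(), new_path))
```

A plain string replacement avoids reserializing the YAML, so comments and key ordering in the configs are preserved.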

## Training models
To train the models, run:
```
./scripts/train.sh --config configs/train/{CONFIG}.yaml --script scripts/train_{MODEL_TYPE}.py
```

For example, to train each of the 145M/170M models:
```
./scripts/train.sh --config configs/train/avid/avid_145M.yaml --script scripts/train_avid.py
./scripts/train.sh --config configs/train/control_net/control_net_lite_170M.yaml --script scripts/train_control_net.py
./scripts/train.sh --config configs/train/act_cond_diffusion_145M.yaml --script scripts/train_diffusion.py
```

The number of GPUs used for training is set in train.sh. In our experiments we use 4 A100 GPUs to train each model, giving a global batch size of 64. Be sure to *adjust the batch size and gradient accumulation if you are using a different number of GPUs* for training.
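The global batch size is the per-GPU batch size times the number of GPUs times the number of gradient accumulation steps, so changing the GPU count means compensating with one of the other two factors. A small sketch of the arithmetic (the function name is illustrative, not from the codebase):

```python
def global_batch_size(per_gpu_batch: int, num_gpus: int, grad_accum: int) -> int:
    """Effective batch size seen by the optimizer per update step."""
    return per_gpu_batch * num_gpus * grad_accum


# 4 GPUs with a per-GPU batch of 16 gives the global batch size of 64.
assert global_batch_size(16, 4, 1) == 64

# On 2 GPUs, doubling gradient accumulation keeps the global batch at 64.
assert global_batch_size(16, 2, 2) == 64
```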

## Evaluating models
Update the config files in configs/eval to point to the correct checkpoints. To evaluate each of the baselines, the command is of the form:
```console
python scripts/eval/eval_{MODEL_TYPE}.py --config configs/eval/{CONFIG}.yaml
```

For example, to evaluate each of the 145M/170M models, run:
```console
python scripts/eval/eval_avid.py --config configs/eval/avid_145M.yaml
python scripts/eval/eval_diffusion.py --config configs/eval/act_cond_diffusion_145M.yaml
python scripts/eval/eval_controlnet.py --config configs/eval/control_net_170M.yaml
```
**`research_experiments/avid/latent_diffusion/configs/eval/act_cond_diffusion_145M.yaml`** (41 additions)

```yaml
name: act_cond_diffusion_145M
group: act_cond_diffusion_145M

num_batches: 64

model_config_file: configs/train/act_cond_diffusion_145M.yaml
act_cond_unet_checkpoint: /host_home/avid/act_cond_model_145M/model.ckpt

ddim_kwargs:
  ddim_steps: 50
  unconditional_guidance_scale: 1.0  # don't use cfg within ddim class
  timestep_spacing: uniform_trailing
  guidance_rescale: 0.7
  ddim_eta: 1.0
  verbose: True

logger:
  target: pytorch_lightning.loggers.WandbLogger
  params:
    save_dir: /host_home/wandb/
    offline: False
    project: avid-eval
    entity: causica

video_logger_callback:
  target: lvdm.utils.callbacks.ImageLogger
  params:
    batch_frequency: 1
    reset_metrics_per_batch: False  # aggregate metrics over all batches
    max_wandb_images: 4

data:
  target: ldwma.lightning.data_modules.rtx.RTXDataModule
  params:
    batch_size: 16
    target_height: 320
    target_width: 512
    dataset_name: "fractal20220817_data"
    shuffle_buffer: 100
    traj_len: 16
    deterministic: True  # slow but deterministic data loading for eval
```
**`research_experiments/avid/latent_diffusion/configs/eval/avid_145M.yaml`** (50 additions)

```yaml
name: avid_145M
group: avid_145M

num_batches: 64

base_config_file: configs/train/dynamicrafter_512.yaml
action_config_file: configs/train/act_cond_diffusion_145M.yaml

base_model_checkpoint: /host_home/avid/dynamicrafter_512/model.ckpt
action_model_checkpoint: /host_home/avid/avid_145M/model.ckpt

target_module: ldwma.models.avid.AVIDAdapter
adapter_params:
  condition_adapter_on_base_outputs: True
  learnt_mask: True
  init_mask_bias: 0.0

ddim_kwargs:
  ddim_steps: 50
  unconditional_guidance_scale: 1.0  # don't use cfg within ddim class
  timestep_spacing: uniform_trailing
  guidance_rescale: 0.7
  ddim_eta: 1.0
  verbose: True

logger:
  target: pytorch_lightning.loggers.WandbLogger
  params:
    save_dir: /host_home/wandb/
    offline: False
    project: avid-eval
    entity: causica

video_logger_callback:
  target: lvdm.utils.callbacks.ImageLogger
  params:
    batch_frequency: 1
    reset_metrics_per_batch: False  # aggregate metrics over all batches
    max_wandb_images: 4

data:
  target: ldwma.lightning.data_modules.rtx.RTXDataModule
  params:
    batch_size: 16
    target_height: 320
    target_width: 512
    dataset_name: "fractal20220817_data"
    shuffle_buffer: 100
    traj_len: 16
    deterministic: True  # slow but deterministic data loading for eval
```
**`research_experiments/avid/latent_diffusion/configs/eval/control_net_170M.yaml`** (46 additions)

```yaml
name: control_net_170M
group: control_net_170M

num_batches: 64

base_config_file: configs/train/control_net/dynamicrafter_512.yaml
control_config_file: configs/train/control_net/act_control_lite_170M.yaml

base_model_checkpoint: /host_home/avid/dynamicrafter_512/model.ckpt
action_model_checkpoint: /host_home/avid/control_net_170M/model.ckpt

target_module: ldwma.models.control_net.ControlNetAdapter

ddim_kwargs:
  ddim_steps: 50
  unconditional_guidance_scale: 1.0  # don't use cfg within ddim class
  timestep_spacing: uniform_trailing
  guidance_rescale: 0.7
  ddim_eta: 1.0
  verbose: True

logger:
  target: pytorch_lightning.loggers.WandbLogger
  params:
    save_dir: /host_home/wandb/
    offline: False
    project: avid-eval
    entity: causica

video_logger_callback:
  target: lvdm.utils.callbacks.ImageLogger
  params:
    batch_frequency: 1
    reset_metrics_per_batch: False  # aggregate metrics over all batches
    max_wandb_images: 4

data:
  target: ldwma.lightning.data_modules.rtx.RTXDataModule
  params:
    batch_size: 16
    target_height: 320
    target_width: 512
    dataset_name: "fractal20220817_data"
    shuffle_buffer: 100
    traj_len: 16
    deterministic: True  # slow but deterministic data loading for eval
```
**`research_experiments/avid/latent_diffusion/configs/eval/control_net_full.yaml`** (46 additions)

```yaml
name: control_net_full
group: control_net_full

num_batches: 64

base_config_file: configs/train/control_net/dynamicrafter_512.yaml
control_config_file: configs/train/control_net/act_control.yaml

base_model_checkpoint: /host_home/avid/dynamicrafter_512/model.ckpt
action_model_checkpoint: /host_home/avid/control_net_full/model.ckpt

target_module: ldwma.models.control_net.ControlNetAdapter

ddim_kwargs:
  ddim_steps: 50
  unconditional_guidance_scale: 1.0  # don't use cfg within ddim class
  timestep_spacing: uniform_trailing
  guidance_rescale: 0.7
  ddim_eta: 1.0
  verbose: True

logger:
  target: pytorch_lightning.loggers.WandbLogger
  params:
    save_dir: /host_home/wandb/
    offline: False
    project: avid-eval
    entity: causica

video_logger_callback:
  target: lvdm.utils.callbacks.ImageLogger
  params:
    batch_frequency: 1
    reset_metrics_per_batch: False  # aggregate metrics over all batches
    max_wandb_images: 4

data:
  target: ldwma.lightning.data_modules.rtx.RTXDataModule
  params:
    batch_size: 16
    target_height: 320
    target_width: 512
    dataset_name: "fractal20220817_data"
    shuffle_buffer: 100
    traj_len: 16
    deterministic: True  # slow but deterministic data loading for eval
```
**`research_experiments/avid/latent_diffusion/configs/eval/dynamicrafter_pretrained.yaml`** (40 additions)

```yaml
name: dynamicrafter_pretrained
group: dynamicrafter_pretrained

num_batches: 64

model_config_file: configs/train/dynamicrafter_512.yaml

ddim_kwargs:
  ddim_steps: 50
  unconditional_guidance_scale: 1.0  # don't use cfg within ddim class
  timestep_spacing: uniform_trailing
  guidance_rescale: 0.7
  ddim_eta: 1.0
  verbose: True

logger:
  target: pytorch_lightning.loggers.WandbLogger
  params:
    save_dir: /host_home/wandb/
    offline: False
    project: avid-eval
    entity: causica

video_logger_callback:
  target: lvdm.utils.callbacks.ImageLogger
  params:
    batch_frequency: 1
    reset_metrics_per_batch: False  # aggregate metrics over all batches
    max_wandb_images: 4

data:
  target: ldwma.lightning.data_modules.rtx.RTXDataModule
  params:
    batch_size: 16
    target_height: 320
    target_width: 512
    dataset_name: "fractal20220817_data"
    shuffle_buffer: 100
    traj_len: 16
    deterministic: True  # slow but deterministic data loading for eval
```
**`research_experiments/avid/latent_diffusion/configs/eval/full_finetune.yaml`** (41 additions)

```yaml
name: dynamicrafter_finetune
group: dynamicrafter_finetune

num_batches: 64

model_config_file: configs/train/dynamicrafter_512_action_finetune.yaml
act_cond_unet_checkpoint: /host_home/avid/dynamicrafter_finetune/model.ckpt

ddim_kwargs:
  ddim_steps: 50
  unconditional_guidance_scale: 1.0  # don't use cfg within ddim class
  timestep_spacing: uniform_trailing
  guidance_rescale: 0.7
  ddim_eta: 1.0
  verbose: True

logger:
  target: pytorch_lightning.loggers.WandbLogger
  params:
    save_dir: /host_home/wandb/
    offline: False
    project: avid-eval
    entity: causica

video_logger_callback:
  target: lvdm.utils.callbacks.ImageLogger
  params:
    batch_frequency: 1
    reset_metrics_per_batch: False  # aggregate metrics over all batches
    max_wandb_images: 4

data:
  target: ldwma.lightning.data_modules.rtx.RTXDataModule
  params:
    batch_size: 16
    target_height: 320
    target_width: 512
    dataset_name: "fractal20220817_data"
    shuffle_buffer: 100
    traj_len: 16
    deterministic: True  # slow but deterministic data loading for eval
```
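Each eval config shares the same data settings, which pin down exactly how much data one evaluation run covers. A quick sanity check of the arithmetic (the helper function is illustrative, not part of the codebase; the values are read from the YAML above):

```python
def eval_totals(num_batches: int, batch_size: int, traj_len: int) -> tuple[int, int]:
    """Return (trajectories evaluated, total frames) for one eval run."""
    trajectories = num_batches * batch_size
    return trajectories, trajectories * traj_len


# Every eval config above uses num_batches: 64, batch_size: 16, traj_len: 16,
# so each run scores 1024 trajectories, i.e. 16384 frames.
trajs, frames = eval_totals(64, 16, 16)
assert (trajs, frames) == (1024, 16384)
```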