Commit d4ca28a

Updated to "Release of AVID" (#109)
Co-authored-by: Marc Rigter <t-marcrigter@microsoft.com>
marc-rigter and Marc Rigter authored Oct 4, 2024
1 parent 07de467 commit d4ca28a
Showing 178 changed files with 34,192 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "causica"
version = "0.4.2"
version = "0.4.3"
description = ""
readme = "README.md"
authors = ["Microsoft Research - Causica"]
@@ -69,8 +69,15 @@ junit_family = "xunit1"

[tool.mypy]
ignore_missing_imports = true
exclude = [
"research_experiments/avid"
]

[tool.pylint.main]
ignore-paths = [
"research_experiments/avid/latent_diffusion/libs",
]

# Specify a score threshold to be exceeded before program exits with error.
fail-under = 9.94
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
30 changes: 30 additions & 0 deletions research_experiments/avid/README.md
@@ -0,0 +1,30 @@
# Adapting Video Diffusion Models to World Models
### Marc Rigter, Tarun Gupta, Agrin Hilmkil and Chao Ma

[![Project Page](https://img.shields.io/badge/Project%20Page-green)](https://sites.google.com/view/avid-world-model-adapters/home)

<p float="left">
<img src="images/rt1_1.gif" width="142" />
<img src="images/rt1_2.gif" width="142" />
<img src="images/rt1_3.gif" width="142" />
<img src="images/rt1_4.gif" width="142" />
<img src="images/rt1_5.gif" width="142" />
</p>
<p float="left">
<img src="images/coinrun1.gif" width="100" />
<img src="images/coinrun2.gif" width="100" />
<img src="images/coinrun3.gif" width="100" />
<img src="images/coinrun4.gif" width="100" />
<img src="images/coinrun5.gif" width="100" />
<img src="images/coinrun6.gif" width="100" />
<img src="images/coinrun7.gif" width="100" />
</p>

Official code to reproduce the experiments for ["Adapting Video Diffusion Models to World Models"](https://sites.google.com/view/avid-world-model-adapters/home) (AVID), which proposes adapting pretrained video diffusion models into action-conditioned world models.

AVID is implemented using both pixel-space and latent-space diffusion. For instructions on how to use each codebase, please see [pixel_diffusion/README.md](pixel_diffusion/README.md) and [latent_diffusion/README.md](latent_diffusion/README.md). Results are logged to Weights & Biases.


### Acknowledgements

This project utilises code from [video-diffusion-pytorch](https://github.com/lucidrains/video-diffusion-pytorch), [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), and [octo](https://github.com/octo-models/octo).
Binary file added research_experiments/avid/images/coinrun1.gif
Binary file added research_experiments/avid/images/coinrun2.gif
Binary file added research_experiments/avid/images/coinrun3.gif
Binary file added research_experiments/avid/images/coinrun4.gif
Binary file added research_experiments/avid/images/coinrun5.gif
Binary file added research_experiments/avid/images/coinrun6.gif
Binary file added research_experiments/avid/images/coinrun7.gif
Binary file added research_experiments/avid/images/rt1_1.gif
Binary file added research_experiments/avid/images/rt1_2.gif
Binary file added research_experiments/avid/images/rt1_3.gif
Binary file added research_experiments/avid/images/rt1_4.gif
Binary file added research_experiments/avid/images/rt1_5.gif
54 changes: 54 additions & 0 deletions research_experiments/avid/latent_diffusion/README.md
@@ -0,0 +1,54 @@
# Adapting Video Diffusion Models to World Models (Latent Space)

This directory implements the latent-diffusion experiments for RT1.

## Installing dependencies
We use [Poetry](https://python-poetry.org/) to manage the project dependencies, which are specified in the [pyproject.toml](pyproject.toml) file. To install Poetry, run:

```console
curl -sSL https://install.python-poetry.org | python3 -
```
To install the environment, run `poetry install` in `avid/latent_diffusion`.

Loading the RT1 data requires TensorFlow Datasets. To avoid the dataloader consuming unnecessarily large amounts of CPU memory, we recommend using tcmalloc (see [here](https://github.com/tensorflow/tensorflow/issues/44176) for details). It can be installed and enabled with:
```console
sudo apt update
sudo apt install libtcmalloc-minimal4
export LD_PRELOAD="/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4"
```
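Because a forgotten `LD_PRELOAD` export only shows up later as runaway memory use, it can be worth checking before launching training. A minimal Python sketch (ours, not part of the codebase) that inspects the environment the data loader will inherit:

```python
import os


def tcmalloc_preloaded() -> bool:
    """Return True if LD_PRELOAD points at a tcmalloc library.

    This only inspects the environment variable; the actual preload
    happens when the process is launched with LD_PRELOAD exported.
    """
    return "libtcmalloc" in os.environ.get("LD_PRELOAD", "")
```

Calling this at the top of a training script gives an early, readable failure instead of silent memory growth.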

## Downloading DynamiCrafter checkpoint
We use the 512 x 320 resolution version of [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter) as the pretrained base model. The checkpoint is available on HuggingFace and can be downloaded with:
```console
wget https://huggingface.co/Doubiiu/DynamiCrafter_512/resolve/main/model.ckpt
```
The AVID codebase expects the checkpoint to be located at `/host_home/avid/dynamicrafter_512/model.ckpt`. You will need to update the configs if the checkpoint is stored elsewhere.
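Since every config hard-codes this path, a fail-fast check saves a long startup before the error surfaces. A small sketch (the helper name is ours, not from the repo):

```python
from pathlib import Path

# Default location assumed by the AVID configs (see above).
DEFAULT_CKPT = Path("/host_home/avid/dynamicrafter_512/model.ckpt")


def resolve_checkpoint(path: Path = DEFAULT_CKPT) -> Path:
    """Fail fast with an actionable message if the checkpoint is missing."""
    if not path.is_file():
        raise FileNotFoundError(
            f"DynamiCrafter checkpoint not found at {path}; download it "
            "with wget or update the configs to point at its location."
        )
    return path
```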

## Training models
To train the models, run:
```console
./scripts/train.sh --config configs/train/{CONFIG}.yaml --script scripts/train_{MODEL_TYPE}.py
```

For example, to train each of the 145M/170M models:
```console
./scripts/train.sh --config configs/train/avid/avid_145M.yaml --script scripts/train_avid.py
./scripts/train.sh --config configs/train/control_net/control_net_lite_170M.yaml --script scripts/train_control_net.py
./scripts/train.sh --config configs/train/act_cond_diffusion_145M.yaml --script scripts/train_diffusion.py
```

The number of GPUs used for training is set in `train.sh`. Note that in our experiments we use 4 A100 GPUs to train each model, giving a global batch size of 64. Be sure to *adjust the batch size and gradient accumulation if you are using a different number of GPUs* for training.
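The arithmetic behind that adjustment is `global batch = per-GPU batch × num GPUs × accumulation steps`. A small sketch (function name ours) that computes the accumulation steps needed to keep the global batch size fixed:

```python
def grad_accum_steps(global_batch: int, per_gpu_batch: int, num_gpus: int) -> int:
    """Gradient-accumulation steps needed to reach a target global batch.

    global batch = per-GPU batch * num GPUs * accumulation steps.
    """
    world_batch = per_gpu_batch * num_gpus
    if global_batch % world_batch != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by "
            f"per_gpu_batch * num_gpus = {world_batch}"
        )
    return global_batch // world_batch


# Paper setup: 4 A100s at a per-GPU batch of 16 reach 64 without accumulation;
# halving the GPU count doubles the accumulation steps.
print(grad_accum_steps(64, 16, 4))  # -> 1
print(grad_accum_steps(64, 16, 2))  # -> 2
```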


## Evaluating models
Update the config files in `configs/eval` to point to the correct checkpoints. To evaluate each of the baselines, the command has the format:
```console
python scripts/eval/eval_{MODEL_TYPE}.py --config configs/eval/{CONFIG}.yaml
```

For example, to evaluate each of the 145M/170M models, you would run:
```console
python scripts/eval/eval_avid.py --config configs/eval/avid_145M.yaml
python scripts/eval/eval_diffusion.py --config configs/eval/act_cond_diffusion_145M.yaml
python scripts/eval/eval_controlnet.py --config configs/eval/control_net_170M.yaml
```
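All three commands share one shape, so sweeping over the baselines is a short loop. A sketch (ours; model-type/config pairings taken from the examples above, paths assumed relative to `avid/latent_diffusion`):

```python
import subprocess

# Model types mapped to the eval configs from the examples above.
EVALS = {
    "avid": "configs/eval/avid_145M.yaml",
    "diffusion": "configs/eval/act_cond_diffusion_145M.yaml",
    "controlnet": "configs/eval/control_net_170M.yaml",
}


def eval_command(model_type: str, config: str) -> list[str]:
    """Build the eval command line for one baseline."""
    return ["python", f"scripts/eval/eval_{model_type}.py", "--config", config]


def run_all() -> None:
    """Evaluate every baseline in sequence, stopping on the first failure."""
    for model_type, config in EVALS.items():
        subprocess.run(eval_command(model_type, config), check=True)
```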
@@ -0,0 +1,41 @@
name: act_cond_diffusion_145M
group: act_cond_diffusion_145M

num_batches: 64

model_config_file: configs/train/act_cond_diffusion_145M.yaml
act_cond_unet_checkpoint: /host_home/avid/act_cond_model_145M/model.ckpt

ddim_kwargs:
ddim_steps: 50
unconditional_guidance_scale: 1.0 # don't use cfg within ddim class
timestep_spacing: uniform_trailing
guidance_rescale: 0.7
ddim_eta: 1.0
verbose: True

logger:
target: pytorch_lightning.loggers.WandbLogger
params:
save_dir: /host_home/wandb/
offline: False
project: avid-eval
entity: causica

video_logger_callback:
target: lvdm.utils.callbacks.ImageLogger
params:
batch_frequency: 1
reset_metrics_per_batch: False # aggregate metrics over all batches
max_wandb_images: 4

data:
target: ldwma.lightning.data_modules.rtx.RTXDataModule
params:
batch_size: 16
target_height: 320
target_width: 512
dataset_name: "fractal20220817_data"
shuffle_buffer: 100
traj_len: 16
deterministic: True # slow but deterministic data loading for eval
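The `target`/`params` entries in blocks like `logger`, `video_logger_callback`, and `data` above follow the convention common to latent-diffusion codebases: `target` is a dotted import path and `params` are its keyword arguments. A minimal sketch of that resolution (helper name is ours; the repo's own loader may differ):

```python
import importlib
from typing import Any


def instantiate_from_config(config: dict) -> Any:
    """Import config['target'] and construct it with config['params']."""
    module_name, _, attr = config["target"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), attr)
    return cls(**config.get("params", {}))
```

Under this convention the `logger` block above would resolve to a `pytorch_lightning.loggers.WandbLogger` constructed with the listed `params`.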
@@ -0,0 +1,50 @@
name: avid_145M
group: avid_145M

num_batches: 64

base_config_file: configs/train/dynamicrafter_512.yaml
action_config_file: configs/train/act_cond_diffusion_145M.yaml

base_model_checkpoint: /host_home/avid/dynamicrafter_512/model.ckpt
action_model_checkpoint: /host_home/avid/avid_145M/model.ckpt

target_module: ldwma.models.avid.AVIDAdapter
adapter_params:
condition_adapter_on_base_outputs: True
learnt_mask: True
init_mask_bias: 0.0

ddim_kwargs:
ddim_steps: 50
unconditional_guidance_scale: 1.0 # don't use cfg within ddim class
timestep_spacing: uniform_trailing
guidance_rescale: 0.7
ddim_eta: 1.0
verbose: True

logger:
target: pytorch_lightning.loggers.WandbLogger
params:
save_dir: /host_home/wandb/
offline: False
project: avid-eval
entity: causica

video_logger_callback:
target: lvdm.utils.callbacks.ImageLogger
params:
batch_frequency: 1
reset_metrics_per_batch: False # aggregate metrics over all batches
max_wandb_images: 4

data:
target: ldwma.lightning.data_modules.rtx.RTXDataModule
params:
batch_size: 16
target_height: 320
target_width: 512
dataset_name: "fractal20220817_data"
shuffle_buffer: 100
traj_len: 16
deterministic: True # slow but deterministic data loading for eval
@@ -0,0 +1,46 @@
name: control_net_170M
group: control_net_170M

num_batches: 64

base_config_file: configs/train/control_net/dynamicrafter_512.yaml
control_config_file: configs/train/control_net/act_control_lite_170M.yaml

base_model_checkpoint: /host_home/avid/dynamicrafter_512/model.ckpt
action_model_checkpoint: /host_home/avid/control_net_170M/model.ckpt

target_module: ldwma.models.control_net.ControlNetAdapter

ddim_kwargs:
ddim_steps: 50
unconditional_guidance_scale: 1.0 # don't use cfg within ddim class
timestep_spacing: uniform_trailing
guidance_rescale: 0.7
ddim_eta: 1.0
verbose: True

logger:
target: pytorch_lightning.loggers.WandbLogger
params:
save_dir: /host_home/wandb/
offline: False
project: avid-eval
entity: causica

video_logger_callback:
target: lvdm.utils.callbacks.ImageLogger
params:
batch_frequency: 1
reset_metrics_per_batch: False # aggregate metrics over all batches
max_wandb_images: 4

data:
target: ldwma.lightning.data_modules.rtx.RTXDataModule
params:
batch_size: 16
target_height: 320
target_width: 512
dataset_name: "fractal20220817_data"
shuffle_buffer: 100
traj_len: 16
deterministic: True # slow but deterministic data loading for eval
@@ -0,0 +1,46 @@
name: control_net_full
group: control_net_full

num_batches: 64

base_config_file: configs/train/control_net/dynamicrafter_512.yaml
control_config_file: configs/train/control_net/act_control.yaml

base_model_checkpoint: /host_home/avid/dynamicrafter_512/model.ckpt
action_model_checkpoint: /host_home/avid/control_net_full/model.ckpt

target_module: ldwma.models.control_net.ControlNetAdapter

ddim_kwargs:
ddim_steps: 50
unconditional_guidance_scale: 1.0 # don't use cfg within ddim class
timestep_spacing: uniform_trailing
guidance_rescale: 0.7
ddim_eta: 1.0
verbose: True

logger:
target: pytorch_lightning.loggers.WandbLogger
params:
save_dir: /host_home/wandb/
offline: False
project: avid-eval
entity: causica

video_logger_callback:
target: lvdm.utils.callbacks.ImageLogger
params:
batch_frequency: 1
reset_metrics_per_batch: False # aggregate metrics over all batches
max_wandb_images: 4

data:
target: ldwma.lightning.data_modules.rtx.RTXDataModule
params:
batch_size: 16
target_height: 320
target_width: 512
dataset_name: "fractal20220817_data"
shuffle_buffer: 100
traj_len: 16
deterministic: True # slow but deterministic data loading for eval
@@ -0,0 +1,40 @@
name: dynamicrafter_pretrained
group: dynamicrafter_pretrained

num_batches: 64

model_config_file: configs/train/dynamicrafter_512.yaml

ddim_kwargs:
ddim_steps: 50
unconditional_guidance_scale: 1.0 # don't use cfg within ddim class
timestep_spacing: uniform_trailing
guidance_rescale: 0.7
ddim_eta: 1.0
verbose: True

logger:
target: pytorch_lightning.loggers.WandbLogger
params:
save_dir: /host_home/wandb/
offline: False
project: avid-eval
entity: causica

video_logger_callback:
target: lvdm.utils.callbacks.ImageLogger
params:
batch_frequency: 1
reset_metrics_per_batch: False # aggregate metrics over all batches
max_wandb_images: 4

data:
target: ldwma.lightning.data_modules.rtx.RTXDataModule
params:
batch_size: 16
target_height: 320
target_width: 512
dataset_name: "fractal20220817_data"
shuffle_buffer: 100
traj_len: 16
deterministic: True # slow but deterministic data loading for eval
@@ -0,0 +1,41 @@
name: dynamicrafter_finetune
group: dynamicrafter_finetune

num_batches: 64

model_config_file: configs/train/dynamicrafter_512_action_finetune.yaml
act_cond_unet_checkpoint: /host_home/avid/dynamicrafter_finetune/model.ckpt

ddim_kwargs:
ddim_steps: 50
unconditional_guidance_scale: 1.0 # don't use cfg within ddim class
timestep_spacing: uniform_trailing
guidance_rescale: 0.7
ddim_eta: 1.0
verbose: True

logger:
target: pytorch_lightning.loggers.WandbLogger
params:
save_dir: /host_home/wandb/
offline: False
project: avid-eval
entity: causica

video_logger_callback:
target: lvdm.utils.callbacks.ImageLogger
params:
batch_frequency: 1
reset_metrics_per_batch: False # aggregate metrics over all batches
max_wandb_images: 4

data:
target: ldwma.lightning.data_modules.rtx.RTXDataModule
params:
batch_size: 16
target_height: 320
target_width: 512
dataset_name: "fractal20220817_data"
shuffle_buffer: 100
traj_len: 16
deterministic: True # slow but deterministic data loading for eval