Skip to content

Commit

Permalink
Fixes Imagen sampling example and updates container (#868)
Browse files Browse the repository at this point in the history
1. Fix imagen sampling loop when prompt_ct is a multiple of `batch_size
// gen_per_prompt`
2. Add comment explaining the invocation of sampling script of what
exactly is expected due to implicit checkpoint dir requirements and
quoting
3. add 2B base model generation gin configs
4. parametrize imagen sampling scripts
5. Updates the imagen image with the fix in (1); built as follows:
```
docker buildx build --push -t ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 ./JAX-Toolbox -f - <<EOF
FROM ghcr.io/nvidia/t5x:imagen-2023-10-02
COPY rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh /opt/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh
COPY rosetta/rosetta/projects/imagen/README.md /opt/rosetta/rosetta/projects/imagen/README.md
COPY rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin /opt/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin
COPY rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin /opt/rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin
COPY rosetta/rosetta/projects/imagen/imagen_pipe.py /opt/rosetta/rosetta/projects/imagen/imagen_pipe.py
COPY rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub /opt/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub
COPY rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh /opt/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh
COPY rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh /opt/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh
EOF
```

---------

Signed-off-by: Terry Kong <terryk@nvidia.com>
  • Loading branch information
terrykong committed Aug 26, 2024
1 parent 400d83a commit de12e17
Show file tree
Hide file tree
Showing 9 changed files with 347 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ We currently support the following frameworks and models. More details about eac
| :--- | :---: | :---: | :---: |
| [Paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
| [T5X](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02` |
| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
| [Big Vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` |
| maxtext| LLaMA, Gemma | pretraining | `ghcr.io/nvidia/jax:maxtext` |
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_disable_async_collectives=allreduce,allgather,reducescatter,collectivebroadcast,alltoall,collectivepermute ${XLA_FLAGS}"
# These XLA flags are meant to be used with the JAX version in the imagen container
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}"
24 changes: 18 additions & 6 deletions rosetta/rosetta/projects/imagen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ For maximum flexibility and low disk requirements, this repo supports a **distri
We provide [scripts](scripts) to run [interactively](scripts/singlenode_inf_train.sh) or on [SLURM](scripts/example_slurm_inf_train.sub).

### Container
We provide a fully built and ready-to-use container here: `ghcr.io/nvidia/t5x:imagen-2023-10-02`.
We provide a fully built and ready-to-use container here: `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3`.

We do not currently have custom-built container workflows, but are actively working on supporting this, stay tuned for updates!
Imagen will also be available in our T5x container in future releases.
Expand All @@ -37,7 +37,7 @@ You will need to acquire the LLM checkpoint for T5 (for multimodal training) fro
**Note**: this should only be done with singlenode jobs

```bash
CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02
CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02.v3
docker run --rm --gpus=all -it --net=host --ipc=host -v ${PWD}:/opt/rosetta -v ${DATASET_PATH}:/mnt/datasets --privileged $CONTAINER bash
```

Expand Down Expand Up @@ -99,15 +99,27 @@ sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
You can find example sampling scripts that use the 500M base model and EfficientUnet SR models in [scripts](scripts). Prompts should be specified as in [example](../diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt)

#### Sampling 256x256 images
Defaults to [imagen_256_sample.gin](configs/imagen_256_sample.gin) config (can be adjusted in script)
Defaults to [imagen_256_sample.gin](configs/imagen_256_sample.gin) config (can be adjusted in script, e.g., [imagen_256_sample_2b.gin](configs/imagen_256_sample_2b.gin)).
```
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 GLOBAL_BATCH_SIZE=<GBS> GEN_PER_PROMPT=1 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
```

Here is an example:
```
# Note:
# - the quoting of double quotes wrapping single quotes is necessary.
# - BASE_PATH/SR1_PATH are checkpoint dirs, and are expected to contain a `checkpoint` file, e.g., the file $BASE_PATH/checkpoint should exist
# - GLOBAL_BATCH_SIZE should be set with number of GPUs in mind. For instance GLOBAL_BATCH_SIZE >= num gpus,
# to ensure at least one example is sent to each GPU.
# - Currently there is a limitation where the number of lines in PROMPT_TEXT_FILES should be divisible by the number of GPUs.
# The easiest way to ensure that is just to pad the files with dummy prompts until it is divisible
CUDA_VISIBLE_DEVICES=0,1 CFG=5.0 GLOBAL_BATCH_SIZE=4 GEN_PER_PROMPT=1 BASE_PATH='"/mnt/imagen_ckpt/checkpoint_585000"' SR1_PATH='"/mnt/sr1_ckpt/checkpoint_5000"' PROMPT_TEXT_FILES='"./rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt"' ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
```

#### Sampling 1024x1024 images
Defaults to [imagen_1024_sample.gin](configs/imagen_1024_sample.gin) config (can be adjusted in script).
Defaults to [imagen_1024_sample.gin](configs/imagen_1024_sample.gin) config (can be adjusted in script, e.g., [imagen_1024_sample_2b.gin](configs/imagen_1024_sample_2b.gin)).
```
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> SR2_PATH=<SR2_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 GLOBAL_BATCH_SIZE=<GBS> GEN_PER_PROMPT=1 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> SR2_PATH=<SR2_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh
```


Expand Down
78 changes: 78 additions & 0 deletions rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Imagen Sampling pipeline
include "rosetta/projects/imagen/configs/imagen_256_sample_2b.gin"

from __gin__ import dynamic_registration
import __main__ as sample_script
from t5x import gin_utils
from t5x import utils
from t5x import partitioning

from rosetta.projects.imagen import network_sr
from rosetta.projects.diffusion import models
from rosetta.projects.diffusion import denoisers
from rosetta.projects.diffusion import samplers
from rosetta.projects.diffusion import losses
from rosetta.projects.diffusion import augmentations

#---------------- SR1024 Model -------------------------------------------------

# ------------------- Model ----------------------------------------------------
SR1024 = @sr1024/models.DenoisingDiffusionModel()
SIGMA_DATA = 0.5
sr1024/models.DenoisingDiffusionModel:
denoiser= @sr1024/denoisers.EDMTextConditionedSuperResDenoiser()
diffusion_loss= None
diffusion_sampler= @sr1024/samplers.EDMSampler()
optimizer_def = None

# |--- Denoiser
sr1024/denoisers.EDMTextConditionedSuperResDenoiser:
raw_model= @sr1024/network_sr.ImagenEfficientUNet()

sr1024/samplers.EDMSampler:
dim_noise_scalar = 4.

# ------------------- Network specification ------------------------------------
sr1024/network_sr.ImagenEfficientUNet.config = @sr1024/network_sr.ImagenEfficientUNetConfig()
sr1024/network_sr.ImagenEfficientUNetConfig:
dtype = %DTYPE
model_dim = 128
cond_dim = 1024
resblocks_per_level = (2, 4, 8, 8, 8)
width_multipliers = (1, 2, 4, 6, 6)
attn_resolutions_divs = {16: 'cross'}
mha_head_dim = 64
attn_heads = 8
resblock_activation = 'silu'
resblock_zero_out = True
resblock_scale_skip = True
dropout_rate = %DROPOUT_RATE
cond_strategy = 'shift_scale'
norm_32 = True
scale_attn_logits = True
float32_attention_logits=False
text_conditionable = True

sr1024/samplers.CFGSamplingConfig:
num_steps=30
cf_guidance_weight=0.0
cf_guidance_nulls={'text': None, 'text_mask': None}

sr1024/partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

sr1024/utils.RestoreCheckpointConfig:
mode = 'specific'
dtype = 'bfloat16'

sr1024/sample_script.DiffusionModelSetupData:
model = %SR1024
sampling_cfg = @sr1024/samplers.CFGSamplingConfig()
restore_checkpoint_cfg = @sr1024/utils.RestoreCheckpointConfig()
partitioner = @partitioning.PjitPartitioner()
input_shapes = {'samples': (1, 1024, 1024, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 256, 256, 3)}
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'}

sample_script.sample:
sr1024_setupdata = @sr1024/sample_script.DiffusionModelSetupData()
220 changes: 220 additions & 0 deletions rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Imagen Sampling pipeline
from __gin__ import dynamic_registration

import __main__ as sample_script
from t5x import gin_utils
from t5x import utils
from t5x import partitioning

SAVE_DIR='generations'
PROMPT_TEXT_FILE='custom_text.txt'
GLOBAL_BATCH_SIZE=32
MAX_GENERATE=50000000
GEN_PER_PROMPT=2
NOISE_COND_AUG=0.002

TXT_SHAPE=(1, 128, 4096) #T5 xxl, seqlen x embed_dim
TXT_SEQLEN=(1, 128, )
TXT_SEQLEN_SINGLE=128
DTYPE='bfloat16'
DROPOUT_RATE=0
RESUME_FROM=0 #Sampling count to resume from
#---------------- Base Model -------------------------------------------------
from rosetta.projects.imagen import network
from rosetta.projects.imagen import network_sr
from rosetta.projects.diffusion import models
from rosetta.projects.diffusion import denoisers
from rosetta.projects.diffusion import samplers
from rosetta.projects.diffusion import losses
from rosetta.projects.diffusion import augmentations

# ------------------- Model ----------------------------------------------------
BASE = @base_model/models.DenoisingDiffusionModel()
base_model/models.DenoisingDiffusionModel:
denoiser= @base_model/denoisers.EDMTextConditionedDenoiser()
diffusion_loss = None
diffusion_sampler= @base_model/samplers.EDMSampler()
optimizer_def = None

# |--- Denoiser
base_model/denoisers.EDMTextConditionedDenoiser:
raw_model= @base_model/network.ImagenUNet()

# ------------------- Network specification ------------------------------------
base_model/network.ImagenUNet.config = @base_model/network.DiffusionConfig()
base_model/network.DiffusionConfig:
dtype = %DTYPE
model_dim = 512
attn_cond_dim = 2048
cond_dim = 2048
resblocks_per_level = 3
width_multipliers = (1, 2, 3, 4)
attn_resolutions = (32, 16, 8)
mha_head_dim = 64
attn_heads = 8
resblock_activation = 'silu'
dropout_rate = %DROPOUT_RATE
upsample_mode = 'shuffle'
downsample_mode = 'shuffle'
spatial_skip = False
cond_strategy = 'shift_scale'
norm_32 = True
scale_attn_logits = True
float32_attention_logits = False
text_conditionable = True


BASE_SAMPLING_CONFIG = @base_model/samplers.CFGSamplingConfig()
base_model/samplers.CFGSamplingConfig:
num_steps=50
cf_guidance_weight=5.00
cf_guidance_nulls=None

base_model/partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

base_model/utils.RestoreCheckpointConfig:
mode = 'specific'
dtype = 'bfloat16'

base_model/sample_script.DiffusionModelSetupData:
model = %BASE
sampling_cfg = @base_model/samplers.CFGSamplingConfig()
restore_checkpoint_cfg = @base_model/utils.RestoreCheckpointConfig()
partitioner = @partitioning.PjitPartitioner()
input_shapes = {'samples': (1, 64, 64, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN}
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int'}

#---------------- SR256 Model -------------------------------------------------

# ------------------- Model ----------------------------------------------------
SR256 = @sr256/models.DenoisingDiffusionModel()
SIGMA_DATA = 0.5
sr256/models.DenoisingDiffusionModel:
denoiser= @sr256/denoisers.EDMTextConditionedSuperResDenoiser()
diffusion_loss= None
diffusion_sampler= @sr256/samplers.EDMSampler()
optimizer_def = None

# |--- Denoiser
sr256/denoisers.EDMTextConditionedSuperResDenoiser:
raw_model= @sr256/network_sr.ImagenEfficientUNet()

sr256/samplers.EDMSampler:
dim_noise_scalar = 4.

# ------------------- Network specification ------------------------------------
sr256/network_sr.ImagenEfficientUNet.config = @sr256/network_sr.ImagenEfficientUNetConfig()
sr256/network_sr.ImagenEfficientUNetConfig:
dtype = %DTYPE
model_dim = 128
cond_dim = 512
attn_cond_dim = 1024
resblocks_per_level = (2, 4, 8, 8, 2)
width_multipliers = (1, 2, 4, 8, 8)
attn_resolutions_divs = {8: 'fused', 16: 'fused'}
mha_head_dim = 64
attn_heads = 8
resblock_activation = 'silu'
resblock_zero_out = True
resblock_scale_skip = True
dropout_rate = %DROPOUT_RATE
cond_strategy = 'shift_scale'
norm_32 = True
scale_attn_logits = True
float32_attention_logits=False
text_conditionable = True

sr256/samplers.CFGSamplingConfig:
num_steps=50
cf_guidance_weight=4
cf_guidance_nulls={'text': None, 'text_mask': None}

sr256/partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

sr256/utils.RestoreCheckpointConfig:
mode = 'specific'
dtype = 'bfloat16'

sr256/sample_script.DiffusionModelSetupData:
model = %SR256
sampling_cfg = @sr256/samplers.CFGSamplingConfig()
restore_checkpoint_cfg = @sr256/utils.RestoreCheckpointConfig()
partitioner = @partitioning.PjitPartitioner()
input_shapes = {'samples': (1, 256, 256, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 64, 64, 3)}
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'}

#---------------- Text Model -------------------------------------------------
import seqio
from rosetta.projects.inference_serving.t5 import network as t5x_network
from rosetta.projects.inference_serving.t5 import models as t5x_models

# =====================================
# === T5 Encoder only configuration ===
# =====================================
T5_CHECKPOINT_PATH = "/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl"
BATCH_SIZE = 256 # Will be overridden
SEQ_LEN = 128 # MAX seqlen

# Vocabulary
VOCABULARY = @seqio.SentencePieceVocabulary()
seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use.

# --------------- Model ------------------
TEXT_ENC = @text_enc/t5x_models.EncoderOnlyModel()
text_enc/t5x_models.EncoderOnlyModel:
module = @t5x_network.TransformerEncoderOnly()
input_vocabulary = %VOCABULARY
output_vocabulary = %VOCABULARY
optimizer_def = None
z_loss = 0.0001
label_smoothing = 0.0
loss_normalizing_factor = None

# -------- Network specification ---------
t5x_network.TransformerEncoderOnly.config = @t5x_network.T5Config()
t5x_network.T5Config:
vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency
dtype = 'bfloat16'
emb_dim = 4096
num_heads = 64
num_encoder_layers = 24
num_decoder_layers = 0
head_dim = 64
mlp_dim = 10240
mlp_activations = ('gelu', 'linear')
dropout_rate = 0.0

text_enc/partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

text_enc/utils.RestoreCheckpointConfig:
path = %T5_CHECKPOINT_PATH
mode = 'specific'
dtype = 'bfloat16'

text_enc/sample_script.setup_text_enc:
model=%TEXT_ENC
restore_checkpoint_cfg=@text_enc/utils.RestoreCheckpointConfig()
partitioner=@text_enc/partitioning.PjitPartitioner()
batch_size=1
seq_len=%TXT_SEQLEN_SINGLE
vocab = %VOCABULARY

sample_script.sample:
base_setupdata = @base_model/sample_script.DiffusionModelSetupData()
sr256_setupdata = @sr256/sample_script.DiffusionModelSetupData()
sr1024_setupdata = None
out_dir = %SAVE_DIR
gen_per_prompt = %GEN_PER_PROMPT
prompt_file = %PROMPT_TEXT_FILE
batch_size = %GLOBAL_BATCH_SIZE
max_images = %MAX_GENERATE
text_enc_infer = @text_enc/sample_script.setup_text_enc()
noise_conditioning_aug = %NOISE_COND_AUG
resume_from = %RESUME_FROM
Loading

0 comments on commit de12e17

Please sign in to comment.