From ce21ffbe5e7d62196ba8755457f88e2b2cddecf4 Mon Sep 17 00:00:00 2001 From: Zeeshan Patel Date: Sun, 13 Oct 2024 01:26:33 -0700 Subject: [PATCH] Diffusion Transformer Training Pipeline (#10843) * diffusion training Signed-off-by: Zeeshan Patel * fixing issues with data module Signed-off-by: Zeeshan Patel * added dit llama support, cleaned up dit code Signed-off-by: Zeeshan Patel * fixed code formatting Signed-off-by: Zeeshan Patel * added dit llama models Signed-off-by: Zeeshan Patel --------- Signed-off-by: Zeeshan Patel --- .../data/diffusion_energon_datamodule.py | 19 +- .../diffusion/data/diffusion_taskencoder.py | 120 +++- nemo/collections/diffusion/models/__init__.py | 13 + .../diffusion/models/dit/__init__.py | 13 + .../diffusion/models/dit/dit_embeddings.py | 161 ++++++ .../diffusion/models/dit/dit_layer_spec.py | 532 ++++++++++++++++++ .../diffusion/models/dit/dit_model.py | 359 ++++++++++++ .../diffusion/models/dit_llama/__init__.py | 13 + .../models/dit_llama/dit_llama_layer_spec.py | 173 ++++++ .../models/dit_llama/dit_llama_model.py | 60 ++ nemo/collections/diffusion/models/model.py | 423 ++++++++++++++ nemo/collections/diffusion/readme.rst | 190 +++++++ .../collections/diffusion/sampler/__init__.py | 13 + .../diffusion/sampler/batch_ops.py | 104 ++++ .../diffusion/sampler/context_parallel.py | 51 ++ .../diffusion/sampler/edm/__init__.py | 13 + nemo/collections/diffusion/sampler/edm/edm.py | 135 +++++ .../diffusion/sampler/edm/edm_pipeline.py | 434 ++++++++++++++ nemo/collections/diffusion/scripts/train.sh | 29 + nemo/collections/diffusion/train.py | 201 +++++++ nemo/collections/diffusion/vae/__init__.py | 13 + .../diffusion/vae/diffusers_vae.py | 34 ++ 22 files changed, 3088 insertions(+), 15 deletions(-) create mode 100644 nemo/collections/diffusion/models/__init__.py create mode 100644 nemo/collections/diffusion/models/dit/__init__.py create mode 100644 nemo/collections/diffusion/models/dit/dit_embeddings.py create mode 100644 nemo/collections/diffusion/models/dit/dit_layer_spec.py create mode 100644 nemo/collections/diffusion/models/dit/dit_model.py create mode 100644 nemo/collections/diffusion/models/dit_llama/__init__.py create mode 100644 nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py create mode 100644 nemo/collections/diffusion/models/dit_llama/dit_llama_model.py create mode 100644 nemo/collections/diffusion/models/model.py create mode 100644 nemo/collections/diffusion/readme.rst create mode 100644 nemo/collections/diffusion/sampler/__init__.py create mode 100644 nemo/collections/diffusion/sampler/batch_ops.py create mode 100644 nemo/collections/diffusion/sampler/context_parallel.py create mode 100644 nemo/collections/diffusion/sampler/edm/__init__.py create mode 100644 nemo/collections/diffusion/sampler/edm/edm.py create mode 100644 nemo/collections/diffusion/sampler/edm/edm_pipeline.py create mode 100644 nemo/collections/diffusion/scripts/train.sh create mode 100644 nemo/collections/diffusion/train.py create mode 100644 nemo/collections/diffusion/vae/__init__.py create mode 100644 nemo/collections/diffusion/vae/diffusers_vae.py diff --git a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py index fe17b4eecb5f..f18c828d9d45 100644 --- a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py +++ b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal + +import logging +from typing import Any, Dict, Literal from megatron.energon import DefaultTaskEncoder, get_train_dataset from pytorch_lightning.utilities.types import EVAL_DATALOADERS @@ -127,3 +129,18 @@ def val_dataloader(self) -> EVAL_DATALOADERS: if self.use_train_split_for_val: return self.train_dataloader() return super().val_dataloader() + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load the state of the data module from a checkpoint. + + This method is called when loading a checkpoint. It restores the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Parameters: + state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. + """ + try: + super().load_state_dict(state_dict) + except Exception as e: + logging.warning(f"datamodule.load_state_dict failed {e}") diff --git a/nemo/collections/diffusion/data/diffusion_taskencoder.py b/nemo/collections/diffusion/data/diffusion_taskencoder.py index 3285c63b2461..57e4e4ec8673 100644 --- a/nemo/collections/diffusion/data/diffusion_taskencoder.py +++ b/nemo/collections/diffusion/data/diffusion_taskencoder.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import warnings import torch import torch.nn.functional as F +from einops import rearrange from megatron.core import parallel_state from megatron.energon import DefaultTaskEncoder, SkipSample from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys @@ -66,10 +69,22 @@ class BasicDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin): Cooker(cook), ] - def __init__(self, *args, max_frames: int = None, text_embedding_padding_size: int = 512, **kwargs): + def __init__( + self, + *args, + max_frames: int = None, + text_embedding_padding_size: int = 512, + seq_length: int = None, + patch_spatial: int = 2, + patch_temporal: int = 1, + **kwargs, + ): super().__init__(*args, **kwargs) self.max_frames = max_frames self.text_embedding_padding_size = text_embedding_padding_size + self.seq_length = seq_length + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal def encode_sample(self, sample: dict) -> dict: video_latent = sample['pth'] @@ -80,9 +95,19 @@ def encode_sample(self, sample: dict) -> dict: raise SkipSample() info = sample['json'] - _, T, H, W = video_latent.shape + C, T, H, W = video_latent.shape + seq_len = ( + video_latent.shape[-1] + * video_latent.shape[-2] + * video_latent.shape[-3] + // self.patch_spatial**2 + // self.patch_temporal + ) is_image = T == 1 + if seq_len > self.seq_length: + raise SkipSample() + if self.max_frames is not None: video_latent = video_latent[:, : self.max_frames, :, :] @@ -90,11 +115,16 @@ def encode_sample(self, sample: dict) -> dict: if parallel_state.get_context_parallel_world_size() > 1: tpcp_size *= parallel_state.get_context_parallel_world_size() * 2 if (T * H * W) % tpcp_size != 0: - print(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}') + warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}') raise SkipSample() - seq_len = video_latent.shape[-1] * video_latent.shape[-2] * video_latent.shape[-3] - loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) + video_latent = rearrange( + video_latent, + 'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)', + ph=self.patch_spatial, + pw=self.patch_spatial, + pt=self.patch_temporal, + ) if is_image: t5_text_embeddings = torch.from_numpy(sample['pickle']).to(torch.bfloat16) @@ -102,20 +132,82 @@ def encode_sample(self, sample: dict) -> dict: t5_text_embeddings = torch.from_numpy(sample['pickle'][0]).to(torch.bfloat16) t5_text_embeddings_seq_length = t5_text_embeddings.shape[0] - t5_text_embeddings = F.pad( - t5_text_embeddings, - ( - 0, - 0, - 0, - self.text_embedding_padding_size - t5_text_embeddings_seq_length % self.text_embedding_padding_size, - ), - ) + if t5_text_embeddings_seq_length > self.text_embedding_padding_size: + t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size] + else: + t5_text_embeddings = F.pad( + t5_text_embeddings, + ( + 0, + 0, + 0, + self.text_embedding_padding_size - t5_text_embeddings_seq_length, + ), + ) t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16) + if is_image: + h, w = info['image_height'], info['image_width'] + fps = torch.tensor([30] * 1, dtype=torch.bfloat16) + num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16) + else: + h, w = info['height'], info['width'] + fps = torch.tensor([info['framerate']] * 1, dtype=torch.bfloat16) + num_frames = torch.tensor([info['num_frames']] * 1, dtype=torch.bfloat16) + image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16) + + pos_ids = rearrange( + pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial), + 'T H W d -> (T H W) d', + ) + + if self.seq_length is not None: + pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len)) + loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16) + loss_mask[:seq_len] = 1 + video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len)) + else: + loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) + return dict( video=video_latent, t5_text_embeddings=t5_text_embeddings, t5_text_mask=t5_text_mask, + image_size=image_size, + fps=fps, + num_frames=num_frames, loss_mask=loss_mask, + seq_len_q=torch.tensor(seq_len, dtype=torch.int32), + seq_len_kv=torch.tensor(t5_text_embeddings_seq_length, dtype=torch.int32), + pos_ids=pos_ids, + latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), ) + + +class PosID3D: + def __init__(self, *, max_t=32, max_h=128, max_w=128): + self.max_t = max_t + self.max_h = max_h + self.max_w = max_w + self.generate_pos_id() + + def generate_pos_id(self): + self.grid = torch.stack( + torch.meshgrid( + torch.arange(self.max_t, device='cpu'), + torch.arange(self.max_h, device='cpu'), + torch.arange(self.max_w, device='cpu'), + ), + dim=-1, + ) + + def get_pos_id_3d(self, *, t, h, w): + if t > self.max_t or h > self.max_h or w > self.max_w: + self.max_t = max(self.max_t, t) + self.max_h = max(self.max_h, h) + self.max_w = max(self.max_w, w) + self.generate_pos_id() + return self.grid[:t, :h, :w] + + +pos_id_3d = PosID3D() diff --git a/nemo/collections/diffusion/models/__init__.py b/nemo/collections/diffusion/models/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/diffusion/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/diffusion/models/dit/__init__.py b/nemo/collections/diffusion/models/dit/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/diffusion/models/dit/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/diffusion/models/dit/dit_embeddings.py b/nemo/collections/diffusion/models/dit/dit_embeddings.py new file mode 100644 index 000000000000..ec8d095cbbd4 --- /dev/null +++ b/nemo/collections/diffusion/models/dit/dit_embeddings.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Dict, Literal, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.models.embeddings import TimestepEmbedding, get_3d_sincos_pos_embed +from einops import rearrange +from einops.layers.torch import Rearrange +from megatron.core import parallel_state +from megatron.core.models.common.embeddings.rotary_pos_embedding import get_pos_emb_on_this_cp_rank +from megatron.core.transformer.module import MegatronModule +from torch import nn + + +class ParallelTimestepEmbedding(TimestepEmbedding): + """ + ParallelTimestepEmbedding is a subclass of TimestepEmbedding that initializes + the embedding layers with an optional random seed for syncronization. + + Args: + in_channels (int): Number of input channels. + time_embed_dim (int): Dimension of the time embedding. + seed (int, optional): Random seed for initializing the embedding layers. + If None, no specific seed is set. + + Attributes: + linear_1 (nn.Module): First linear layer for the embedding. + linear_2 (nn.Module): Second linear layer for the embedding. + + Methods: + __init__(in_channels, time_embed_dim, seed=None): Initializes the embedding layers. + """ + + def __init__(self, in_channels: int, time_embed_dim: int, seed=None): + super().__init__(in_channels=in_channels, time_embed_dim=time_embed_dim) + if seed is not None: + with torch.random.fork_rng(): + torch.manual_seed(seed) + self.linear_1.reset_parameters() + self.linear_2.reset_parameters() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the positional embeddings for the input tensor. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, H, W, C). + + Returns: + torch.Tensor: Positional embeddings of shape (B, T, H, W, C). + """ + return super().forward(x.to(torch.bfloat16, non_blocking=True)) + + +def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim): + """ + Adjusts the positional embeddings tensor to the current context parallel rank. + + Args: + pos_emb (torch.Tensor): The positional embeddings tensor. + seq_dim (int): The sequence dimension index in the positional embeddings tensor. + + Returns: + torch.Tensor: The adjusted positional embeddings tensor for the current context parallel rank. + """ + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cp_idx = torch.tensor([cp_rank], device="cpu", pin_memory=True).cuda(non_blocking=True) + pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], cp_size, -1, *pos_emb.shape[(seq_dim + 1) :]) + pos_emb = pos_emb.index_select(seq_dim, cp_idx) + pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :]) + return pos_emb + + +class SinCosPosEmb3D(MegatronModule): + """ + SinCosPosEmb3D is a 3D sine-cosine positional embedding module. + + Args: + model_channels (int): Number of channels in the model. + h (int): Length of the height dimension. + w (int): Length of the width dimension. + t (int): Length of the temporal dimension. + spatial_interpolation_scale (float, optional): Scale factor for spatial interpolation. Default is 1.0. + temporal_interpolation_scale (float, optional): Scale factor for temporal interpolation. Default is 1.0. + + Methods: + forward(pos_ids: torch.Tensor) -> torch.Tensor: + Computes the positional embeddings for the input tensor. + + Args: + pos_ids (torch.Tensor): Input tensor of shape (B S 3). + + Returns: + torch.Tensor: Positional embeddings of shape (B S D). + """ + + def __init__( + self, + config, + h: int, + w: int, + t: int, + spatial_interpolation_scale=1.0, + temporal_interpolation_scale=1.0, + ): + super().__init__(config=config) + self.h = h + self.w = w + self.t = t + # h w t + param = get_3d_sincos_pos_embed( + config.hidden_size, [h, w], t, spatial_interpolation_scale, temporal_interpolation_scale + ) + param = rearrange(param, "t hw c -> (t hw) c") + self.pos_embedding = torch.nn.Embedding(param.shape[0], config.hidden_size) + self.pos_embedding.weight = torch.nn.Parameter(torch.tensor(param), requires_grad=False) + + def forward(self, pos_ids: torch.Tensor): + # pos_ids: t h w + pos_id = pos_ids[..., 0] * self.h * self.w + pos_ids[..., 1] * self.w + pos_ids[..., 2] + return self.pos_embedding(pos_id) + + +class FactorizedLearnable3DEmbedding(MegatronModule): + def __init__( + self, + config, + t: int, + h: int, + w: int, + **kwargs, + ): + super().__init__(config=config) + self.emb_t = torch.nn.Embedding(t, config.hidden_size) + self.emb_h = torch.nn.Embedding(h, config.hidden_size) + self.emb_w = torch.nn.Embedding(w, config.hidden_size) + + if config.perform_initialization: + config.init_method(self.emb_t.weight) + config.init_method(self.emb_h.weight) + config.init_method(self.emb_w.weight) + + def forward(self, pos_ids: torch.Tensor): + return self.emb_t(pos_ids[..., 0]) + self.emb_h(pos_ids[..., 1]) + self.emb_w(pos_ids[..., 2]) diff --git a/nemo/collections/diffusion/models/dit/dit_layer_spec.py b/nemo/collections/diffusion/models/dit/dit_layer_spec.py new file mode 100644 index 000000000000..672dcff3ba00 --- /dev/null +++ b/nemo/collections/diffusion/models/dit/dit_layer_spec.py @@ -0,0 +1,532 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from dataclasses import dataclass +from typing import Literal, Union + +import torch +import torch.nn as nn +from einops import rearrange +from megatron.core.jit import jit_fuser +from megatron.core.transformer.attention import ( + CrossAttention, + CrossAttentionSubmodules, + SelfAttention, + SelfAttentionSubmodules, +) +from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TENorm, + TERowParallelLinear, +) +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_block import TransformerConfig +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.utils import make_viewless_tensor + + +@dataclass +class DiTWithAdaLNSubmodules(TransformerLayerSubmodules): + temporal_self_attention: Union[ModuleSpec, type] = IdentityOp + full_self_attention: Union[ModuleSpec, type] = IdentityOp + + +@dataclass +class STDiTWithAdaLNSubmodules(TransformerLayerSubmodules): + spatial_self_attention: Union[ModuleSpec, type] = IdentityOp + temporal_self_attention: Union[ModuleSpec, type] = IdentityOp + full_self_attention: Union[ModuleSpec, type] = IdentityOp + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size: int, config, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class AdaLN(MegatronModule): + """ + Adaptive Layer Normalization Module for DiT. + """ + + def __init__(self, config: TransformerConfig, n_adaln_chunks=9, norm=nn.LayerNorm): + super().__init__(config) + if norm == TENorm: + self.ln = norm(config, config.hidden_size, config.layernorm_epsilon) + else: + self.ln = norm(config.hidden_size, elementwise_affine=False, eps=self.config.layernorm_epsilon) + self.n_adaln_chunks = n_adaln_chunks + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(config.hidden_size, self.n_adaln_chunks * config.hidden_size, bias=False) + ) + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + + setattr(self.adaLN_modulation[-1].weight, "sequence_parallel", config.sequence_parallel) + + def forward(self, timestep_emb): + return self.adaLN_modulation(timestep_emb).chunk(self.n_adaln_chunks, dim=-1) + + @jit_fuser + def modulate(self, x, shift, scale): + return x * (1 + scale) + shift + + @jit_fuser + def scale_add(self, residual, x, gate): + return residual + gate * x + + @jit_fuser + def modulated_layernorm(self, x, shift, scale): + # Optional Input Layer norm + input_layernorm_output = self.ln(x).type_as(x) + + # DiT block specific + return self.modulate(input_layernorm_output, shift, scale) + + # @jit_fuser + def scaled_modulated_layernorm(self, residual, x, gate, shift, scale): + hidden_states = self.scale_add(residual, x, gate) + shifted_pre_mlp_layernorm_output = self.modulated_layernorm(hidden_states, shift, scale) + return hidden_states, shifted_pre_mlp_layernorm_output + + +class STDiTLayerWithAdaLN(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + Spatial-Temporal DiT with Adapative Layer Normalization. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute", + ): + def _replace_no_cp_submodules(submodules): + modified_submods = copy.deepcopy(submodules) + modified_submods.cross_attention = IdentityOp + modified_submods.spatial_self_attention = IdentityOp + return modified_submods + + # Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init. + modified_submods = _replace_no_cp_submodules(submodules) + super().__init__( + config=config, submodules=modified_submods, layer_number=layer_number, hidden_dropout=hidden_dropout + ) + + # Override Spatial Self Attention and Cross Attention to disable CP. + # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes. + sa_cp_override_config = copy.deepcopy(config) + sa_cp_override_config.context_parallel_size = 1 + sa_cp_override_config.tp_comm_overlap = False + self.spatial_self_attention = build_module( + submodules.spatial_self_attention, config=sa_cp_override_config, layer_number=layer_number + ) + self.cross_attention = build_module( + submodules.cross_attention, + config=sa_cp_override_config, + layer_number=layer_number, + ) + + self.temporal_self_attention = build_module( + submodules.temporal_self_attention, + config=self.config, + layer_number=layer_number, + ) + + self.full_self_attention = build_module( + submodules.full_self_attention, + config=self.config, + layer_number=layer_number, + ) + + self.adaLN = AdaLN(config=self.config, n_adaln_chunks=3) + + def forward( + self, + hidden_states, + attention_mask, + context=None, + context_mask=None, + rotary_pos_emb=None, + inference_params=None, + packed_seq_params=None, + ): + # timestep embedding + timestep_emb = attention_mask + + # ******************************************** spatial self attention ****************************************************** + + shift_sa, scale_sa, gate_sa = self.adaLN(timestep_emb) + + # adaLN with scale + shift + pre_spatial_attn_layernorm_output_ada = self.adaLN.modulated_layernorm( + hidden_states, shift=shift_sa, scale=scale_sa + ) + + attention_output, _ = self.spatial_self_attention( + pre_spatial_attn_layernorm_output_ada, + attention_mask=None, + # packed_seq_params=packed_seq_params['self_attention'], + ) + + # ******************************************** full self attention ************************************************* + + shift_full, scale_full, gate_full = self.adaLN(timestep_emb) + + # adaLN with scale + shift + hidden_states, pre_full_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm( + residual=hidden_states, + x=attention_output, + gate=gate_sa, + shift=shift_full, + scale=scale_full, + ) + + attention_output, _ = self.full_self_attention( + pre_full_attn_layernorm_output_ada, + attention_mask=None, + # packed_seq_params=packed_seq_params['self_attention'], + ) + + # ******************************************** cross attention ***************************************************** + + shift_ca, scale_ca, gate_ca = self.adaLN(timestep_emb) + + # adaLN with scale + shift + hidden_states, pre_cross_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm( + residual=hidden_states, + x=attention_output, + gate=gate_full, + shift=shift_ca, + scale=scale_ca, + ) + + attention_output, _ = self.cross_attention( + pre_cross_attn_layernorm_output_ada, + attention_mask=context_mask, + key_value_states=context, + # packed_seq_params=packed_seq_params['cross_attention'], + ) + + # ******************************************** temporal self attention ********************************************* + + shift_ta, scale_ta, gate_ta = self.adaLN(timestep_emb) + + hidden_states, pre_temporal_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm( + residual=hidden_states, + x=attention_output, + gate=gate_ca, + shift=shift_ta, + scale=scale_ta, + ) + + attention_output, _ = self.temporal_self_attention( + pre_temporal_attn_layernorm_output_ada, + attention_mask=None, + # packed_seq_params=packed_seq_params['self_attention'], + ) + + # ******************************************** mlp ***************************************************************** + + shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb) + + hidden_states, pre_mlp_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm( + residual=hidden_states, + x=attention_output, + gate=gate_ta, + shift=shift_mlp, + scale=scale_mlp, + ) + + mlp_output, _ = self.mlp(pre_mlp_layernorm_output_ada) + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + + return output, context + + +class DiTLayerWithAdaLN(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + DiT with Adapative Layer Normalization. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute", + ): + def _replace_no_cp_submodules(submodules): + modified_submods = copy.deepcopy(submodules) + modified_submods.cross_attention = IdentityOp + # modified_submods.temporal_self_attention = IdentityOp + return modified_submods + + # Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init. + modified_submods = _replace_no_cp_submodules(submodules) + super().__init__( + config=config, submodules=modified_submods, layer_number=layer_number, hidden_dropout=hidden_dropout + ) + + # Override Cross Attention to disable CP. + # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes. + if submodules.cross_attention != IdentityOp: + cp_override_config = copy.deepcopy(config) + cp_override_config.context_parallel_size = 1 + cp_override_config.tp_comm_overlap = False + self.cross_attention = build_module( + submodules.cross_attention, + config=cp_override_config, + layer_number=layer_number, + ) + else: + self.cross_attention = None + + self.full_self_attention = build_module( + submodules.full_self_attention, + config=self.config, + layer_number=layer_number, + ) + + self.adaLN = AdaLN(config=self.config, n_adaln_chunks=9 if self.cross_attention else 6) + + def forward( + self, + hidden_states, + attention_mask, + context=None, + context_mask=None, + rotary_pos_emb=None, + inference_params=None, + packed_seq_params=None, + ): + # timestep embedding + timestep_emb = attention_mask + + # ******************************************** full self attention ****************************************************** + if self.cross_attention: + shift_full, scale_full, gate_full, shift_ca, scale_ca, gate_ca, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN(timestep_emb) + ) + else: + shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb) + + # adaLN with scale + shift + pre_full_attn_layernorm_output_ada = self.adaLN.modulated_layernorm( + hidden_states, shift=shift_full, scale=scale_full + ) + + attention_output, _ = self.full_self_attention( + pre_full_attn_layernorm_output_ada, + attention_mask=None, + packed_seq_params=None if packed_seq_params is None else packed_seq_params['self_attention'], + ) + + if self.cross_attention: + # ******************************************** cross attention ****************************************************** + # adaLN with scale + shift + hidden_states, pre_cross_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm( + residual=hidden_states, + x=attention_output, + gate=gate_full, + shift=shift_ca, + scale=scale_ca, + ) + + attention_output, _ = self.cross_attention( + pre_cross_attn_layernorm_output_ada, + attention_mask=context_mask, + key_value_states=context, + packed_seq_params=None if packed_seq_params is None else packed_seq_params['cross_attention'], + ) + + # ******************************************** mlp ****************************************************** + hidden_states, pre_mlp_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm( + residual=hidden_states, + x=attention_output, + gate=gate_ca if self.cross_attention else gate_full, + shift=shift_mlp, + scale=scale_mlp, + ) + + mlp_output, _ = self.mlp(pre_mlp_layernorm_output_ada) + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + + return output, context + + +def get_stdit_adaln_block_with_transformer_engine_spec() -> ModuleSpec: + params = {"attn_mask_type": AttnMaskType.padding} + return ModuleSpec( + module=STDiTLayerWithAdaLN, + submodules=STDiTWithAdaLNSubmodules( + spatial_self_attention=ModuleSpec( + module=SelfAttention, + params=params, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + temporal_self_attention=ModuleSpec( + module=SelfAttention, + params=params, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + full_self_attention=ModuleSpec( + module=SelfAttention, + params=params, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + cross_attention=ModuleSpec( + module=CrossAttention, + params=params, + submodules=CrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) + + +def get_dit_adaln_block_with_transformer_engine_spec() -> ModuleSpec: + params = {"attn_mask_type": AttnMaskType.padding} + return ModuleSpec( + module=DiTLayerWithAdaLN, + submodules=DiTWithAdaLNSubmodules( + full_self_attention=ModuleSpec( + module=SelfAttention, + params=params, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=RMSNorm, + k_layernorm=RMSNorm, + ), + ), + cross_attention=ModuleSpec( + module=CrossAttention, + params=params, + submodules=CrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=RMSNorm, + k_layernorm=RMSNorm, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) + + +def get_official_dit_adaln_block_with_transformer_engine_spec() -> ModuleSpec: + params = {"attn_mask_type": AttnMaskType.no_mask} + return ModuleSpec( + module=DiTLayerWithAdaLN, + submodules=DiTWithAdaLNSubmodules( + full_self_attention=ModuleSpec( + module=SelfAttention, + params=params, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) diff --git a/nemo/collections/diffusion/models/dit/dit_model.py b/nemo/collections/diffusion/models/dit/dit_model.py new file mode 100644 index 000000000000..0c1c1abc82f2 --- /dev/null +++ b/nemo/collections/diffusion/models/dit/dit_model.py @@ -0,0 +1,359 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict, Literal, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.embeddings import Timesteps +from einops import rearrange, repeat +from megatron.core import parallel_state, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_sharded_tensor_for_checkpoint +from torch import Tensor + +from nemo.collections.diffusion.models.dit import dit_embeddings +from nemo.collections.diffusion.models.dit.dit_embeddings import ParallelTimestepEmbedding +from nemo.collections.diffusion.models.dit.dit_layer_spec import ( + get_dit_adaln_block_with_transformer_engine_spec as DiTLayerWithAdaLNspec, +) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class RMSNorm(nn.Module): + def __init__(self, channel: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(channel)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size, spatial_patch_size, temporal_patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False + ) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False)) + + def forward(self, x_BT_HW_D, emb_B_D): + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + T = x_BT_HW_D.shape[0] // emb_B_D.shape[0] + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) + x_BT_HW_D = self.linear(x_BT_HW_D) + return x_BT_HW_D + + +class DiTCrossAttentionModel(VisionModule): + """ + DiTCrossAttentionModel is a VisionModule that implements a DiT model with a cross-attention block. + Attributes: + config (TransformerConfig): Configuration for the transformer. + pre_process (bool): Whether to apply pre-processing steps. + post_process (bool): Whether to apply post-processing steps. + fp16_lm_cross_entropy (bool): Whether to use fp16 for cross-entropy loss. + parallel_output (bool): Whether to use parallel output. + position_embedding_type (Literal["learned_absolute", "rope"]): Type of position embedding. + max_img_h (int): Maximum image height. + max_img_w (int): Maximum image width. + max_frames (int): Maximum number of frames. + patch_spatial (int): Spatial patch size. + patch_temporal (int): Temporal patch size. + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + transformer_decoder_layer_spec (DiTLayerWithAdaLNspec): Specification for the transformer decoder layer. + add_encoder (bool): Whether to add an encoder. + add_decoder (bool): Whether to add a decoder. + share_embeddings_and_output_weights (bool): Whether to share embeddings and output weights. + concat_padding_mask (bool): Whether to concatenate padding mask. + pos_emb_cls (str): Class of position embedding. + model_type (ModelType): Type of the model. + decoder (TransformerBlock): Transformer decoder block. + t_embedder (torch.nn.Sequential): Time embedding layer. + x_embedder (nn.Conv3d): Convolutional layer for input embedding. + pos_embedder (dit_embeddings.SinCosPosEmb3D): Position embedding layer. + final_layer_linear (torch.nn.Linear): Final linear layer. + affline_norm (RMSNorm): Affine normalization layer. + Methods: + forward(x: Tensor, timesteps: Tensor, crossattn_emb: Tensor, packed_seq_params: PackedSeqParams = None, pos_ids: Tensor = None, **kwargs) -> Tensor: + Forward pass of the model. + set_input_tensor(input_tensor: Tensor) -> None: + Sets input tensor to the model. + sharded_state_dict(prefix: str = 'module.', sharded_offsets: tuple = (), metadata: Optional[Dict] = None) -> ShardedStateDict: + Sharded state dict implementation for backward-compatibility. + tie_embeddings_weights_state_dict(tensor, sharded_state_dict: ShardedStateDict, output_layer_weight_key: str, first_stage_word_emb_key: str) -> None: + Ties the embedding and output weights in a given sharded state dict. + """ + + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + position_embedding_type: Literal["learned_absolute", "rope"] = "rope", + max_img_h: int = 80, + max_img_w: int = 80, + max_frames: int = 34, + patch_spatial: int = 1, + patch_temporal: int = 1, + in_channels: int = 16, + out_channels: int = 16, + transformer_decoder_layer_spec=DiTLayerWithAdaLNspec, + pos_embedder=dit_embeddings.SinCosPosEmb3D, + **kwargs, + ): + super(DiTCrossAttentionModel, self).__init__(config=config) + + self.config: TransformerConfig = config + + self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = True + self.add_decoder = True + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.position_embedding_type = position_embedding_type + self.share_embeddings_and_output_weights = False + self.concat_padding_mask = True + self.pos_emb_cls = 'sincos' + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + # Transformer decoder + self.decoder = TransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + pre_process=self.pre_process, + post_process=False, + post_layer_norm=False, + ) + + self.t_embedder = torch.nn.Sequential( + Timesteps(self.config.hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0), + dit_embeddings.ParallelTimestepEmbedding(self.config.hidden_size, self.config.hidden_size, seed=1234), + ) + + if self.pre_process: + self.x_embedder = torch.nn.Linear(in_channels * patch_spatial**2, self.config.hidden_size) + + self.pos_embedder = pos_embedder( + config, + t=max_frames // patch_temporal, + h=max_img_h // patch_spatial, + w=max_img_w // patch_spatial, + ) + self.fps_embedder = nn.Sequential( + Timesteps(num_channels=256, flip_sin_to_cos=False, downscale_freq_shift=1), + ParallelTimestepEmbedding(256, 256), + ) + + if self.post_process: + self.final_layer_linear = torch.nn.Linear( + self.config.hidden_size, + patch_spatial**2 * patch_temporal * out_channels, + ) + + self.affline_norm = RMSNorm(self.config.hidden_size) + + def forward( + self, + x: Tensor, + timesteps: Tensor, + crossattn_emb: Tensor, + packed_seq_params: PackedSeqParams = None, + pos_ids: Tensor = None, + **kwargs, + ) -> Tensor: + """Forward pass. + + Args: + x (Tensor): vae encoded data (b s c) + encoder_decoder_attn_mask (Tensor): cross-attention mask between encoder and decoder + inference_params (InferenceParams): relevant arguments for inferencing + + Returns: + Tensor: loss tensor + """ + B = x.shape[0] + fps = kwargs.get( + 'fps', + torch.tensor( + [ + 30, + ] + * B, + dtype=torch.bfloat16, + ), + ).view(-1) + if self.pre_process: + # transpose to match + x_B_S_D = self.x_embedder(x) + if isinstance(self.pos_embedder, dit_embeddings.SinCosPosEmb3D): + pos_emb = None + x_B_S_D += self.pos_embedder(pos_ids) + else: + pos_emb = self.pos_embedder(pos_ids) + pos_emb = rearrange(pos_emb, "B S D -> S B D") + x_S_B_D = rearrange(x_B_S_D, "B S D -> S B D") + else: + # intermediate stage of pipeline + x_S_B_D = None ### should it take encoder_hidden_states + + timesteps_B_D = self.t_embedder(timesteps.flatten()).to(torch.bfloat16) # (b d_text_embedding) + + affline_emb_B_D = timesteps_B_D + fps_B_D = self.fps_embedder(fps) + fps_B_D = nn.functional.pad(fps_B_D, (0, self.config.hidden_size - fps_B_D.shape[1])) + affline_emb_B_D += fps_B_D + + crossattn_emb = rearrange(crossattn_emb, 'B S D -> S B D') + + if self.config.sequence_parallel: + if self.pre_process: + x_S_B_D = tensor_parallel.scatter_to_sequence_parallel_region(x_S_B_D) + crossattn_emb = tensor_parallel.scatter_to_sequence_parallel_region(crossattn_emb) + # `scatter_to_sequence_parallel_region` returns a view, which prevents + # the original tensor from being garbage collected. Clone to facilitate GC. + # Has a small runtime cost (~0.5%). + if self.config.clone_scatter_output_in_embedding: + if self.pre_process: + x_S_B_D = x_S_B_D.clone() + crossattn_emb = crossattn_emb.clone() + + x_S_B_D = self.decoder( + hidden_states=x_S_B_D, + attention_mask=affline_emb_B_D, + context=crossattn_emb, + context_mask=None, + rotary_pos_emb=pos_emb, + packed_seq_params=packed_seq_params, + ) + + if not self.post_process: + return x_S_B_D + + if self.config.sequence_parallel: + x_S_B_D = tensor_parallel.gather_from_sequence_parallel_region(x_S_B_D) + + x_S_B_D = self.final_layer_linear(x_S_B_D) + return rearrange(x_S_B_D, "S B D -> B S D") + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' + self.decoder.set_input_tensor(input_tensor[0]) + + def sharded_state_dict( + self, prefix: str = 'module.', sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Sharded state dict implementation for GPTModel backward-compatibility (removing extra state). + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. + + Returns: + ShardedStateDict: sharded state dict for the GPTModel + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + for param_name, param in self.t_embedder.named_parameters(): + weight_key = f'{prefix}t_embedder.{param_name}' + self.tie_embeddings_weights_state_dict(param, sharded_state_dict, weight_key, weight_key) + + for param_name, param in self.affline_norm.named_parameters(): + weight_key = f'{prefix}affline_norm.{param_name}' + self.tie_embeddings_weights_state_dict(param, sharded_state_dict, weight_key, weight_key) + + return sharded_state_dict + + def tie_embeddings_weights_state_dict( + self, + tensor, + sharded_state_dict: ShardedStateDict, + output_layer_weight_key: str, + first_stage_word_emb_key: str, + ) -> None: + """Ties the embedding and output weights in a given sharded state dict. + + Args: + sharded_state_dict (ShardedStateDict): state dict with the weight to tie + output_layer_weight_key (str): key of the output layer weight in the state dict. + This entry will be replaced with a tied version + first_stage_word_emb_key (str): this must be the same as the + ShardedTensor.key of the first stage word embeddings. + + Returns: None, acts in-place + """ + if self.pre_process and parallel_state.get_tensor_model_parallel_rank() == 0: + # Output layer is equivalent to the embedding already + return + + # Replace the default output layer with a one sharing the weights with the embedding + del sharded_state_dict[output_layer_weight_key] + last_stage_word_emb_replica_id = ( + 0, # copy of first stage embedding + parallel_state.get_tensor_model_parallel_rank() + + parallel_state.get_pipeline_model_parallel_rank() + * parallel_state.get_pipeline_model_parallel_world_size(), + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + + sharded_state_dict[output_layer_weight_key] = make_sharded_tensor_for_checkpoint( + tensor=tensor, + key=first_stage_word_emb_key, + replica_id=last_stage_word_emb_replica_id, + allow_shape_mismatch=False, + ) diff --git a/nemo/collections/diffusion/models/dit_llama/__init__.py b/nemo/collections/diffusion/models/dit_llama/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/diffusion/models/dit_llama/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py b/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py new file mode 100644 index 000000000000..80bed5878e1b --- /dev/null +++ b/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Literal + +from megatron.core.transformer.attention import ( + CrossAttention, + CrossAttentionSubmodules, + SelfAttention, + SelfAttentionSubmodules, +) +from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TERowParallelLinear, +) +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_block import TransformerConfig +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.utils import make_viewless_tensor + +from nemo.collections.diffusion.models.dit.dit_layer_spec import AdaLN + + +class MoviegGenLayer(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + DiT with Adapative Layer Normalization. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute", + ): + def _replace_no_cp_submodules(submodules): + modified_submods = copy.deepcopy(submodules) + modified_submods.cross_attention = IdentityOp + # modified_submods.temporal_self_attention = IdentityOp + return modified_submods + + # Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init. + modified_submods = _replace_no_cp_submodules(submodules) + super().__init__( + config=config, submodules=modified_submods, layer_number=layer_number, hidden_dropout=hidden_dropout + ) + + # Override Cross Attention to disable CP. + # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes. + cp_override_config = copy.deepcopy(config) + cp_override_config.context_parallel_size = 1 + cp_override_config.tp_comm_overlap = False + self.cross_attention = build_module( + submodules.cross_attention, + config=cp_override_config, + layer_number=layer_number, + ) + + self.adaLN = AdaLN(config=self.config, n_adaln_chunks=6) # , norm=TENorm) + + def forward( + self, + hidden_states, + attention_mask, + context=None, + context_mask=None, + rotary_pos_emb=None, + inference_params=None, + packed_seq_params=None, + ): + # timestep embedding + timestep_emb = attention_mask + factorized_pos_emb = rotary_pos_emb + hidden_states = hidden_states + factorized_pos_emb + + # ******************************************** full self attention ****************************************************** + shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb) + + # adaLN with scale + shift + pre_full_attn_layernorm_output_ada = self.adaLN.modulated_layernorm( + hidden_states, shift=shift_full, scale=scale_full + ) + + attention_output, _ = self.self_attention( + pre_full_attn_layernorm_output_ada, + attention_mask=None, + packed_seq_params=None if packed_seq_params is None else packed_seq_params['self_attention'], + ) + + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=attention_output, gate=gate_full) + + # ******************************************** cross attention ****************************************************** + attention_output, _ = self.cross_attention( + hidden_states, + attention_mask=context_mask, + key_value_states=context, + packed_seq_params=None if packed_seq_params is None else packed_seq_params['cross_attention'], + ) + + # ******************************************** mlp ****************************************************** + pre_mlp_layernorm_output_ada = self.adaLN.modulated_layernorm( + attention_output, shift=shift_mlp, scale=scale_mlp + ) + + mlp_output, _ = self.mlp(pre_mlp_layernorm_output_ada) + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + + return output, context + + +def get_dit_llama_spec() -> ModuleSpec: + params = {"attn_mask_type": AttnMaskType.padding} + return ModuleSpec( + module=MoviegGenLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params=params, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + cross_attention=ModuleSpec( + module=CrossAttention, + params=params, + submodules=CrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) diff --git a/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py b/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py new file mode 100644 index 000000000000..bfa79e366cac --- /dev/null +++ b/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Literal + +from megatron.core.transformer.transformer_config import TransformerConfig + +from nemo.collections.diffusion.models.dit import dit_embeddings +from nemo.collections.diffusion.models.dit.dit_model import DiTCrossAttentionModel +from nemo.collections.diffusion.models.dit_llama.dit_llama_layer_spec import get_dit_llama_spec + + +class DiTLlamaModel(DiTCrossAttentionModel): + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + position_embedding_type: Literal["learned_absolute", "rope"] = "rope", + max_img_h: int = 80, + max_img_w: int = 80, + max_frames: int = 34, + patch_spatial: int = 1, + patch_temporal: int = 1, + in_channels: int = 16, + out_channels: int = 16, + **kwargs, + ): + super().__init__( + config=config, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=fp16_lm_cross_entropy, + parallel_output=parallel_output, + position_embedding_type=position_embedding_type, + max_img_h=max_img_h, + max_img_w=max_img_w, + max_frames=max_frames, + patch_spatial=patch_spatial, + patch_temporal=patch_temporal, + in_channels=in_channels, + out_channels=out_channels, + transformer_decoder_layer_spec=get_dit_llama_spec, + pos_embedder=dit_embeddings.FactorizedLearnable3DEmbedding, + **kwargs, + ) diff --git a/nemo/collections/diffusion/models/model.py b/nemo/collections/diffusion/models/model.py new file mode 100644 index 000000000000..8cc6be860585 --- /dev/null +++ b/nemo/collections/diffusion/models/model.py @@ -0,0 +1,423 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import warnings +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import wandb +from einops import rearrange +from megatron.core import parallel_state +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import nn +from typing_extensions import override + +from nemo.collections.diffusion.models.dit_llama.dit_llama_model import DiTLlamaModel +from nemo.collections.diffusion.sampler.edm.edm_pipeline import EDMPipeline +from nemo.collections.llm.gpt.model.base import GPTModel +from nemo.lightning import io +from nemo.lightning.megatron_parallel import MaskedTokenLossReduction, MegatronLossReduction +from nemo.lightning.pytorch.optim import OptimizerModule + +from .dit.dit_model import DiTCrossAttentionModel + + +def dit_forward_step(model, batch) -> torch.Tensor: + return model(**batch) + + +def dit_data_step(module, dataloader_iter): + batch = next(dataloader_iter)[0] + batch = get_batch_on_this_cp_rank(batch) + batch = {k: v.to(device='cuda', non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} + + cu_seqlens = batch['seq_len_q'].cumsum(dim=0).to(torch.int32) + zero = torch.zeros(1, dtype=torch.int32, device="cuda") + cu_seqlens = torch.cat((zero, cu_seqlens)) + + cu_seqlens_kv = batch['seq_len_kv'].cumsum(dim=0).to(torch.int32) + cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv)) + + batch['packed_seq_params'] = { + 'self_attention': PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + qkv_format='sbhd', + ), + 'cross_attention': PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens_kv, + qkv_format='sbhd', + ), + } + + return batch + + +def get_batch_on_this_cp_rank(data: Dict): + """Split the data for context parallelism.""" + from megatron.core import mpu + + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + + t = 16 + if cp_size > 1: + assert t % cp_size == 0, "t must divisibly by cp_size" + num_valid_tokens_in_ub = None + if 'loss_mask' in data and data['loss_mask'] is not None: + num_valid_tokens_in_ub = data['loss_mask'].sum() + + for key, value in data.items(): + if (value is not None) and (key in ['video', 'video_latent', 'noise_latent', 'pos_ids']): + if len(value.shape) > 5: + value = value.squeeze(0) + B, C, T, H, W = value.shape + # TODO: sequence packing + data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() + loss_mask = data["loss_mask"] + data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[ + :, cp_rank, ... + ].contiguous() + data['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub + return data + + +@dataclass +class DiTConfig(TransformerConfig, io.IOMixin): + """ + Config for DiT-S model + """ + + crossattn_emb_size: int = 1024 + add_bias_linear: bool = False + gated_linear_unit: bool = False + + num_layers: int = 12 + hidden_size: int = 384 + max_img_h: int = 80 + max_img_w: int = 80 + max_frames: int = 34 + patch_spatial: int = 2 + num_attention_heads: int = 6 + layernorm_epsilon = 1e-6 + normalization = "RMSNorm" + add_bias_linear = False + qk_layernorm_per_head = True + layernorm_zero_centered_gamma = False + + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True + share_embeddings_and_output_weights: bool = True + + # max_position_embeddings: int = 5400 + hidden_dropout: float = 0 + attention_dropout: float = 0 + + bf16: bool = True + params_dtype: torch.dtype = torch.bfloat16 + + vae_module: str = 'nemo.collections.diffusion.vae.diffusers_vae.AutoencoderKLVAE' + vae_path: str = None + sigma_data: float = 0.5 + + in_channels: int = 16 + + data_step_fn = dit_data_step + forward_step_fn = dit_forward_step + + @override + def configure_model(self, tokenizer=None) -> DiTCrossAttentionModel: + vp_size = self.virtual_pipeline_model_parallel_size + if vp_size: + p_size = self.pipeline_model_parallel_size + assert ( + self.num_layers // p_size + ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages." + + if isinstance(self, DiTLlama30BConfig): + model = DiTLlamaModel + else: + model = DiTCrossAttentionModel + return model( + self, + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + max_img_h=self.max_img_h, + max_img_w=self.max_img_w, + max_frames=self.max_frames, + patch_spatial=self.patch_spatial, + ) + + def configure_vae(self): + return dynamic_import(self.vae_module)(self.vae_path) + + +@dataclass +class DiTBConfig(DiTConfig): + num_layers: int = 12 + hidden_size: int = 768 + num_attention_heads: int = 12 + + +@dataclass +class DiTLConfig(DiTConfig): + num_layers: int = 24 + hidden_size: int = 1024 + num_attention_heads: int = 16 + + +@dataclass +class DiTXLConfig(DiTConfig): + num_layers: int = 28 + hidden_size: int = 1152 + num_attention_heads: int = 16 + + +@dataclass +class DiT7BConfig(DiTConfig): + num_layers: int = 32 + hidden_size: int = 3072 + num_attention_heads: int = 24 + + +@dataclass +class DiTLlama30BConfig(DiTConfig): + num_layers: int = 48 + hidden_size: int = 6144 + ffn_hidden_size: int = 16384 + num_attention_heads: int = 48 + num_query_groups: int = 8 + gated_linear_unit: int = True + bias_activation_fusion: int = True + activation_func: Callable = F.silu + normalization: str = "RMSNorm" + layernorm_epsilon: float = 1e-5 + max_frames: int = 128 + max_img_h: int = 240 + max_img_w: int = 240 + patch_spatial: int = 2 + + init_method_std: float = 0.01 + add_bias_linear: bool = False + seq_length: int = 256 + + bias_activation_fusion: bool = True + masked_softmax_fusion: bool = True + persist_layer_norm: bool = True + bias_dropout_fusion: bool = True + + +@dataclass +class DiTLlama5BConfig(DiTLlama30BConfig): + num_layers: int = 32 + hidden_size: int = 3072 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 24 + + +class DiTModel(GPTModel): + def __init__( + self, + config: Optional[DiTConfig] = None, + optim: Optional[OptimizerModule] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + tokenizer: Optional[Any] = None, + ): + super().__init__(config or DiTConfig(), optim=optim, model_transform=model_transform) + + self.vae = None + + self._training_loss_reduction = None + self._validation_loss_reduction = None + + self.diffusion_pipeline = EDMPipeline(net=self, sigma_data=self.config.sigma_data) + + self._noise_generator = None + self.seed = 42 + + self.vae = None + + def data_step(self, dataloader_iter) -> Dict[str, Any]: + return self.config.data_step_fn(dataloader_iter) + + def forward(self, *args, **kwargs): + return self.module.forward(*args, **kwargs) + + def forward_step(self, batch) -> torch.Tensor: + if parallel_state.is_pipeline_last_stage(): + output_batch, loss = self.diffusion_pipeline.training_step(batch, 0) + loss = torch.mean(loss, dim=-1) + return loss + else: + output_tensor = self.diffusion_pipeline.training_step(batch, 0) + return output_tensor + + def training_step(self, batch, batch_idx=None) -> torch.Tensor: + # In mcore the loss-function is part of the forward-pass (when labels are provided) + return self.forward_step(batch) + + def on_validation_start(self): + if self.vae is None: + if self.config.vae_path is None: + warnings.warn('vae_path not specified skipping validation') + return None + self.vae = self.config.configure_vae() + self.vae.to('cuda') + + def on_validation_end(self): + if self.vae is not None: + self.vae.to('cpu') + + def validation_step(self, batch, batch_idx=None) -> torch.Tensor: + # In mcore the loss-function is part of the forward-pass (when labels are provided) + state_shape = batch['video'].shape + sample = self.diffusion_pipeline.generate_samples_from_batch( + batch, + guidance=7, + state_shape=state_shape, + num_steps=35, + is_negative_prompt=True if 'neg_t5_text_embeddings' in batch else False, + ) + + # TODO visualize more than 1 sample + sample = sample[0, None] + C, T, H, W = batch['latent_shape'][0] + seq_len_q = batch['seq_len_q'][0] + + sample = rearrange( + sample[:, :seq_len_q], + 'B (T H W) (ph pw pt C) -> B C (T pt) (H ph) (W pw)', + ph=self.config.patch_spatial, + pw=self.config.patch_spatial, + C=C, + T=T, + H=H // self.config.patch_spatial, + W=W // self.config.patch_spatial, + ) + + video = (1.0 + self.vae.decode(sample / self.config.sigma_data)).clamp(0, 2) / 2 # [B, 3, T, H, W] + + video = (video * 255).to(torch.uint8).cpu().numpy().astype(np.uint8) + + T = video.shape[2] + if T == 1: + image = rearrange(video, 'b c t h w -> (b t h) w c') + result = image + else: + # result = wandb.Video(video, fps=float(batch['fps'])) # (batch, time, channel, height width) + result = video + + # wandb is on the last rank for megatron, first rank for nemo + wandb_rank = 0 + + if parallel_state.get_data_parallel_src_rank() == wandb_rank: + if torch.distributed.get_rank() == wandb_rank: + gather_list = [None for _ in range(parallel_state.get_data_parallel_world_size())] + else: + gather_list = None + torch.distributed.gather_object( + result, gather_list, wandb_rank, group=parallel_state.get_data_parallel_group() + ) + if gather_list is not None: + videos = [] + for video in gather_list: + if len(video.shape) == 3: + videos.append(wandb.Image(video)) + else: + videos.append(wandb.Video(video, fps=30)) + wandb.log({'prediction': videos}, step=self.global_step) + + return None + + @property + def training_loss_reduction(self) -> MaskedTokenLossReduction: + if not self._training_loss_reduction: + self._training_loss_reduction = MaskedTokenLossReduction() + + return self._training_loss_reduction + + @property + def validation_loss_reduction(self) -> MaskedTokenLossReduction: + if not self._validation_loss_reduction: + self._validation_loss_reduction = DummyLossReduction() + + return self._validation_loss_reduction + + def on_validation_model_zero_grad(self) -> None: + ''' + Small hack to avoid first validation on resume. + This will NOT work if the gradient accumulation step should be performed at this point. + https://github.com/Lightning-AI/pytorch-lightning/discussions/18110 + ''' + super().on_validation_model_zero_grad() + if self.trainer.ckpt_path is not None and getattr(self, '_restarting_skip_val_flag', True): + self.trainer.sanity_checking = True + self._restarting_skip_val_flag = False + + +class DummyLossReduction(MegatronLossReduction): + def __init__(self, validation_step: bool = False, val_drop_last: bool = True) -> None: + super().__init__() + self.validation_step = validation_step + self.val_drop_last = val_drop_last + + def forward( + self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + return torch.tensor(0.0, device=torch.cuda.current_device()), { + "avg": torch.tensor(0.0, device=torch.cuda.current_device()) + } + + def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: + return torch.tensor(0.0, device=torch.cuda.current_device()) + + +def dynamic_import(full_path): + """ + Dynamically import a class or function from a given full path. + + :param full_path: The full path to the class or function (e.g., "package.module.ClassName") + :return: The imported class or function + :raises ImportError: If the module or attribute cannot be imported + :raises AttributeError: If the attribute does not exist in the module + """ + try: + # Split the full path into module path and attribute name + module_path, attribute_name = full_path.rsplit('.', 1) + except ValueError as e: + raise ImportError( + f"Invalid full path '{full_path}'. It should contain both module and attribute names." + ) from e + + # Import the module + try: + module = importlib.import_module(module_path) + except ImportError as e: + raise ImportError(f"Cannot import module '{module_path}'.") from e + + # Retrieve the attribute from the module + try: + attribute = getattr(module, attribute_name) + except AttributeError as e: + raise AttributeError(f"Module '{module_path}' does not have an attribute '{attribute_name}'.") from e + + return attribute diff --git a/nemo/collections/diffusion/readme.rst b/nemo/collections/diffusion/readme.rst new file mode 100644 index 000000000000..871527948708 --- /dev/null +++ b/nemo/collections/diffusion/readme.rst @@ -0,0 +1,190 @@ +Diffusion Training Framework +============= + +Overview +-------- + +The NeMo Diffusion Training Framework provides a scalable training platform for diffusion models with transformer backbones. Our new features streamline the training process, allowing developers to efficiently train state-of-the-art models with ease. + + +Some of the features we currently support include: + +- Energon Dataloader for Webscale Dataloading +- Model and Data Parallelism +- Model Architectures: DiT 30B parameters or even more + + +Features Status +--------------- + +We support image diffusion training. Video training incoming. + + ++---------------------------+------------------+ +| Parallelism | Status | ++===========================+==================+ +| FSDP | ✅ Supported | ++---------------------------+------------------+ +| CP+TP+SP+distopt | ✅ Supported | ++---------------------------+------------------+ +| CP+TP+SP+PP+distopt | ✅ Supported | ++---------------------------+------------------+ +| CP+TP+SP+FSDP | 🕒 Coming Soon | ++---------------------------+------------------+ + + +**Legend:** +- **FSDP**: Fully Sharded Data Parallelism +- **CP**: Context Parallelism +- **TP**: Tensor Parallelism +- **SP**: Sequence Parallelism +- **PP**: Pipeline Parallelism +- **distop**: mcore distributed optmizer + ++--------------+-------------------+-----------------+ +| Model Size | Modality | Status | ++==============+===================+=================+ +| DiT 30B+ | 256px image | ✅ Supported | ++--------------+-------------------+-----------------+ +| DiT 30B+ | 256px image+video | 🕒 Coming Soon | ++--------------+-------------------+-----------------+ +| DiT 30B+ | 768px image+video | 🕒 Coming Soon | ++--------------+-------------------+-----------------+ + + +Energon Dataloader for Webscale Dataloading +------------------------------------------- + +Webscale Dataloading +^^^^^^^^^^^^^^^^^^^^ + +Megatron-Energon is an optimized multi-modal dataloader for large-scale deep learning with Megatron. Energon allows for distributed loading of large training training data for multi-modal model training. Energon allows for blending many datasets together and distributing the dataloading workflow across multiple cluster nodes/processes while ensuring reproducibility and resumability. + +Dataloader Checkpointing +^^^^^^^^^^^^^^^^^^^^^^^^ + +One of Energon's key features is its ability to save and restore its state. This functionality is crucial for long-running training processes, making the dataloader robust and recoverable after interruptions. By allowing checkpointing of the dataloader status, Energon ensures that training can be resumed from where it left off, saving time and computational resources in case of unexpected shutdowns or planned pauses in the training process. This makes it especially useful for large scale training as it requires several training jobs for end-to-end training. + +Parallel Configuration +^^^^^^^^^^^^^^^^^^^^^^ + +Energon's architecture allows it to efficiently distribute data across multiple processing units, ensuring that each GPU or node receives a balanced workload. This parallelization not only increases the overall throughput of data processing but also helps in maintaining high utilization of available computational resources. + + +Mixed Image-Video Training (comming soon) +------------------------------ + +Our dataloader provides support for mixed image-video training by using the NeMo packed sequence feature to pack together images and videos of varying length into the same microbatch. The sequence packing mechanism uses the THD attention kernel, which allows us to increase the model FLOPs utilization (MFU) and efficiently process data with varying length. + + +.. image:: assets/mixed_training.png + :alt: Mixed image-video dataloading strategy + :width: 300px + :align: center + +Model and Data Parallelism +-------------------------- +NeMo provides support for training models using tensor parallelism, sequence parallelism, pipeline parallelism, and context parallelism. To support pipeline parallelism with conditional diffusion training, we duplicate the conditional embeddings across the pipeline stages, and perform an all-reduce during the backward pass. This approach uses more compute, but it has a lower communication cost than sending the conditional embeddings through different pipeline stages. + +.. image:: assets/pipeline_conditioning.png + :alt: Conditioning mechanism for pipeline parallelism + :width: 300px + :align: center + +Model Architectures +------------------- + +DiT +^^^ +We implement an efficient version of the diffusion transformer (DiT) [1]_. Our DiT is slightly modified from the original paper as we use cross attention and adaptive layernorm together in the same architecture. We also use a QK-layernorm for training stability. Our framework allows for customizing the DiT architecture while maintaining its scalability, enabling training large DiT models on long sequence lengths. + + + +Data preparation +-------------------------- + +We expect data to be in this webdataset format. For more information about webdataset and energon dataset, please refer to https://github.com/NVIDIA/Megatron-Energon + +Here we demonstrate a step by step example of how to prepare a dummy image dataset. + +.. code-block:: bash + + torchrun --nproc-per-node 2 nemo/collections/diffusion/data/prepare_energon_dataset.py --factory prepare_dummy_image_dataset + +this will generate a folder a tar files. .pth contains image/video latent representations encode by image/video tokenizer, .json contains metadata including text caption, resolution, aspection ratio, and .pickle contains text embeddings encoded by language model like T5. + +.. code-block:: bash + + shard_000.tar + ├── samples/sample_0000.pth + ├── samples/sample_0000.pickle + ├── samples/sample_0000.json + ├── samples/sample_0001.pth + ├── samples/sample_0001.pickle + ├── samples/sample_0001.json + └── ... + shard_001.tar + +The following is a sample command to prepare prepare webdataset into energon dataset: + +.. code-block:: bash + + # energon prepare . --num-workers 192 + Found 369057 tar files in total. The first and last ones are: + - 0.tar + - 99999.tar + If you want to exclude some of them, cancel with ctrl+c and specify an exclude filter in the command line. + Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1": 1,0,0 + Indexing shards [####################################] 369057/369057 + Sample 0, keys: + - .json + - .pickle + - .pth + Sample 1, keys: + - .json + - .pickle + - .pth + Found the following part types in the dataset: .json, .pth, .pickle + Do you want to create a dataset.yaml interactively? [Y/n]: Y + The following dataset classes are available: + 0. CaptioningWebdataset + 1. CrudeWebdataset + 2. ImageClassificationWebdataset + 3. ImageWebdataset + 4. InterleavedWebdataset + 5. MultiChoiceVQAWebdataset + 6. OCRWebdataset + 7. SimilarityInterleavedWebdataset + 8. TextWebdataset + 9. VQAOCRWebdataset + 10. VQAWebdataset + 11. VidQAWebdataset + Please enter a number to choose a class: 1 + The dataset you selected uses the following sample type: + + class CrudeSample(dict): + """Generic sample type to be processed later.""" + + CrudeWebdataset does not need a field map. You will need to provide a `Cooker` for your dataset samples in your `TaskEncoder`. + Furthermore, you might want to add `subflavors` in your meta dataset specification. + +training +-------------------------- + +To launch training on one node + +.. code-block:: bash + + torchrun --nproc-per-node 8 nemo/collections/diffusion/train.py --yes --factory pretrain_xl + +To launch training on multiple nodes using Slurm + +.. code-block:: bash + + sbatch nemo/collections/diffusion/scripts/train.sh --factory pretrain_xl + + +Citations +--------- + +.. [1] William Peebles and Saining Xie, "Scalable Diffusion Models with Transformers," *arXiv preprint arXiv:2212.09748*, 2022. \ No newline at end of file diff --git a/nemo/collections/diffusion/sampler/__init__.py b/nemo/collections/diffusion/sampler/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/diffusion/sampler/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/diffusion/sampler/batch_ops.py b/nemo/collections/diffusion/sampler/batch_ops.py new file mode 100644 index 000000000000..956dfbee36e5 --- /dev/null +++ b/nemo/collections/diffusion/sampler/batch_ops.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import Tensor + + +def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + """ + Broadcasts two tensors to have the same shape by adding singleton dimensions where necessary. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + tuple[Tensor, Tensor]: A tuple containing the two tensors with broadcasted shapes. + + Raises: + AssertionError: If the dimensions of the tensors do not match at any axis within their common dimensions. + """ + ndims1 = x.ndim + ndims2 = y.ndim + + common_ndims = min(ndims1, ndims2) + for axis in range(common_ndims): + assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) + + if ndims1 < ndims2: + x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) + elif ndims2 < ndims1: + y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) + + return x, y + + +def batch_add(x: Tensor, y: Tensor) -> Tensor: + """ + Adds two tensors element-wise after broadcasting them to a common shape. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + Tensor: The element-wise sum of the input tensors after broadcasting. + """ + x, y = common_broadcast(x, y) + return x + y + + +def batch_mul(x: Tensor, y: Tensor) -> Tensor: + """ + Multiplies two tensors element-wise after broadcasting them to a common shape. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + Tensor: The element-wise product of the input tensors after broadcasting. + """ + x, y = common_broadcast(x, y) + return x * y + + +def batch_sub(x: Tensor, y: Tensor) -> Tensor: + """ + Subtracts two tensors element-wise after broadcasting them to a common shape. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + Tensor: The result of element-wise subtraction of the input tensors. + """ + x, y = common_broadcast(x, y) + return x - y + + +def batch_div(x: Tensor, y: Tensor) -> Tensor: + """ + Divides two tensors element-wise after broadcasting them to a common shape. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + Tensor: The result of element-wise division of `x` by `y` after broadcasting. + """ + x, y = common_broadcast(x, y) + return x / y diff --git a/nemo/collections/diffusion/sampler/context_parallel.py b/nemo/collections/diffusion/sampler/context_parallel.py new file mode 100644 index 000000000000..f389b7ba2656 --- /dev/null +++ b/nemo/collections/diffusion/sampler/context_parallel.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import Tensor +from torch.distributed import ProcessGroup, all_gather, get_world_size + + +def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Concatenates tensors from multiple processes along a specified dimension. + + This function gathers tensors from all processes in the given process group + and concatenates them along the specified dimension. + + Args: + x (Tensor): The input tensor to be gathered and concatenated. + seq_dim (int): The dimension along which to concatenate the gathered tensors. + cp_group (ProcessGroup): The process group containing all the processes involved in the gathering. + + Returns: + Tensor: A tensor resulting from the concatenation of tensors from all processes. + + Raises: + RuntimeError: If the gathering of tensors fails. + """ + # Number of processes in the group + world_size = get_world_size(cp_group) + + # List to hold tensors from each rank + gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)] + + # Attempt to gather tensors from all ranks + try: + all_gather(gathered_tensors, x, group=cp_group) + except RuntimeError as e: + raise RuntimeError(f"Gathering failed: {e}") + + # Concatenate tensors along the specified dimension + return torch.cat(gathered_tensors, dim=seq_dim) diff --git a/nemo/collections/diffusion/sampler/edm/__init__.py b/nemo/collections/diffusion/sampler/edm/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/diffusion/sampler/edm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/diffusion/sampler/edm/edm.py b/nemo/collections/diffusion/sampler/edm/edm.py new file mode 100644 index 000000000000..eb47728af40a --- /dev/null +++ b/nemo/collections/diffusion/sampler/edm/edm.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from statistics import NormalDist +from typing import Callable, Tuple + +import numpy as np +import torch +from torch import nn +from tqdm import tqdm + + +class EDMScaling: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class EDMSDE: + def __init__( + self, + p_mean: float = -1.2, + p_std: float = 1.2, + sigma_max: float = 80.0, + sigma_min: float = 0.002, + ): + self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) + self.sigma_max = sigma_max + self.sigma_min = sigma_min + self._generator = np.random + + def sample_t(self, batch_size: int) -> torch.Tensor: + cdf_vals = self._generator.uniform(size=(batch_size)) + samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] + log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") + return torch.exp(log_sigma) + + def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return x0, sigma + + +class EDMSampler(nn.Module): + """ + Elucidating the Design Space of Diffusion-Based Generative Models (EDM) + # https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/generate.py#L25 + + Attributes: + None + + Methods: + forward(x0_fn: Callable, x_sigma_max: torch.Tensor, num_steps: int = 35, sigma_min: float = 0.002, + sigma_max: float = 80, rho: float = 7, S_churn: float = 0, S_min: float = 0, + S_max: float = float("inf"), S_noise: float = 1) -> torch.Tensor: + Performs the forward pass for the EDM sampling process. + + Parameters: + x0_fn (Callable): A function that takes in a tensor and returns a denoised tensor. + x_sigma_max (torch.Tensor): The initial noise level tensor. + num_steps (int, optional): The number of sampling steps. Default is 35. + sigma_min (float, optional): The minimum noise level. Default is 0.002. + sigma_max (float, optional): The maximum noise level. Default is 80. + rho (float, optional): The rho parameter for time step discretization. Default is 7. + S_churn (float, optional): The churn parameter for noise increase. Default is 0. + S_min (float, optional): The minimum value for the churn parameter. Default is 0. + S_max (float, optional): The maximum value for the churn parameter. Default is float("inf"). + S_noise (float, optional): The noise scale for the churn parameter. Default is 1. + + Returns: + torch.Tensor: The sampled tensor after the EDM process. + """ + + @torch.no_grad() + def forward( + self, + x0_fn: Callable, + x_sigma_max: torch.Tensor, + num_steps: int = 35, + sigma_min: float = 0.002, + sigma_max: float = 80, + rho: float = 7, + S_churn: float = 0, + S_min: float = 0, + S_max: float = float("inf"), + S_noise: float = 1, + ) -> torch.Tensor: + # Time step discretization. + in_dtype = x_sigma_max.dtype + _ones = torch.ones(x_sigma_max.shape[0], dtype=in_dtype, device=x_sigma_max.device) + step_indices = torch.arange(num_steps, dtype=torch.float64, device=x_sigma_max.device) + t_steps = ( + sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + x_next = x_sigma_max.to(torch.float64) + for i, (t_cur, t_next) in enumerate( + tqdm(zip(t_steps[:-1], t_steps[1:], strict=False), total=len(t_steps) - 1) + ): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + t_hat = t_cur + gamma * t_cur + x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur) + + # Euler step. + denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + return x_next.to(in_dtype) diff --git a/nemo/collections/diffusion/sampler/edm/edm_pipeline.py b/nemo/collections/diffusion/sampler/edm/edm_pipeline.py new file mode 100644 index 000000000000..6e1be1f6f2a6 --- /dev/null +++ b/nemo/collections/diffusion/sampler/edm/edm_pipeline.py @@ -0,0 +1,434 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.distributed +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from nemo.collections.diffusion.sampler.batch_ops import batch_mul +from nemo.collections.diffusion.sampler.context_parallel import cat_outputs_cp +from nemo.collections.diffusion.sampler.edm.edm import EDMSDE, EDMSampler, EDMScaling + + +class EDMPipeline: + """ + EDMPipeline is a class that implements a diffusion model pipeline for video generation. It includes methods for + initializing the pipeline, encoding and decoding video data, performing training steps, denoising, and generating + samples. + Attributes: + p_mean: Mean for SDE process. + p_std: Standard deviation for SDE process. + sigma_max: Maximum noise level. + sigma_min: Minimum noise level. + _noise_generator: Generator for noise. + _noise_level_generator: Generator for noise levels. + sde: SDE process. + sampler: Sampler for the diffusion model. + scaling: Scaling for EDM. + input_data_key: Key for input video data. + input_image_key: Key for input image data. + tensor_kwargs: Tensor keyword arguments. + loss_reduce: Method for reducing loss. + loss_scale: Scale factor for loss. + aesthetic_finetuning: Aesthetic finetuning parameter. + camera_sample_weight: Camera sample weight parameter. + loss_mask_enabled: Flag for enabling loss mask. + Methods: + noise_level_generator: Returns the noise level generator. + _initialize_generators: Initializes noise and noise-level generators. + encode: Encodes input tensor using the video tokenizer. + decode: Decodes latent tensor using video tokenizer. + training_step: Performs a single training step for the diffusion model. + denoise: Performs denoising on the input noise data, noise level, and condition. + compute_loss_with_epsilon_and_sigma: Computes the loss for training. + get_per_sigma_loss_weights: Returns loss weights per sigma noise level. + get_condition_uncondition: Returns conditioning and unconditioning for classifier-free guidance. + get_x0_fn_from_batch: Creates a function to generate denoised predictions with the sampler. + generate_samples_from_batch: Generates samples based on input data batch. + _normalize_video_databatch_inplace: Normalizes video data in-place on a CUDA device to [-1, 1]. + draw_training_sigma_and_epsilon: Draws training noise (epsilon) and noise levels (sigma). + random_dropout_input: Applies random dropout to the input tensor. + get_data_and_condition: Retrieves data and conditioning for model input. + """ + + def __init__( + self, + net, + vae=None, + p_mean=0.0, + p_std=1.0, + sigma_max=80, + sigma_min=0.0002, + sigma_data=0.5, + seed=1234, + ): + """ + Initializes the EDM pipeline with the given parameters. + + Args: + net: The DiT model. + vae: The Video Tokenizer (optional). + p_mean (float): Mean for the SDE. + p_std (float): Standard deviation for the SDE. + sigma_max (float): Maximum sigma value for the SDE. + sigma_min (float): Minimum sigma value for the SDE. + sigma_data (float): Sigma value for EDM scaling. + seed (int): Random seed for reproducibility. + + Attributes: + vae: The Video Tokenizer. + net: The DiT model. + p_mean (float): Mean for the SDE. + p_std (float): Standard deviation for the SDE. + sigma_max (float): Maximum sigma value for the SDE. + sigma_min (float): Minimum sigma value for the SDE. + sigma_data (float): Sigma value for EDM scaling. + seed (int): Random seed for reproducibility. + _noise_generator: Placeholder for noise generator. + _noise_level_generator: Placeholder for noise level generator. + sde: Instance of EDMSDE initialized with p_mean, p_std, sigma_max, and sigma_min. + sampler: Instance of EDMSampler. + scaling: Instance of EDMScaling initialized with sigma_data. + input_data_key (str): Key for input data. + input_image_key (str): Key for input images. + tensor_kwargs (dict): Tensor keyword arguments for device and dtype. + loss_reduce (str): Method to reduce loss ('mean' or other). + loss_scale (float): Scale factor for loss. + """ + self.vae = vae + self.net = net + + self.p_mean = p_mean + self.p_std = p_std + self.sigma_max = sigma_max + self.sigma_min = sigma_min + self.sigma_data = sigma_data + + self.seed = seed + self._noise_generator = None + self._noise_level_generator = None + + self.sde = EDMSDE(p_mean, p_std, sigma_max, sigma_min) + self.sampler = EDMSampler() + self.scaling = EDMScaling(sigma_data) + + self.input_data_key = 'video' + self.input_image_key = 'images_1024' + self.tensor_kwargs = {"device": "cuda", "dtype": torch.bfloat16} + self.loss_reduce = 'mean' + self.loss_scale = 1.0 + + @property + def noise_level_generator(self): + """ + Generates noise levels for the EDM pipeline. + + Returns: + Callable: A function or generator that produces noise levels. + """ + return self._noise_level_generator + + def _initialize_generators(self): + """ + Initializes the random number generators for noise and noise level. + + This method sets up two generators: + 1. A PyTorch generator for noise, seeded with a combination of the base seed and the data parallel rank. + 2. A NumPy generator for noise levels, seeded similarly but without considering context parallel rank. + + Returns: + None + """ + noise_seed = self.seed + 100 * parallel_state.get_data_parallel_rank(with_context_parallel=True) + noise_level_seed = self.seed + 100 * parallel_state.get_data_parallel_rank(with_context_parallel=False) + self._noise_generator = torch.Generator(device='cuda') + self._noise_generator.manual_seed(noise_seed) + self._noise_level_generator = np.random.default_rng(noise_level_seed) + self.sde._generator = self._noise_level_generator + + def training_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Performs a single training step for the diffusion model. + + This method is responsible for executing one iteration of the model's training. It involves: + 1. Adding noise to the input data using the SDE process. + 2. Passing the noisy data through the network to generate predictions. + 3. Computing the loss based on the difference between the predictions and the original data. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + + Returns: + A tuple with the output batch and the computed loss. + """ + # Get the input data to noise and denoise~(image, video) and the corresponding conditioner. + x0_from_data_batch, x0, condition = self.get_data_and_condition(data_batch) + + # Sample pertubation noise levels and N(0, 1) noises + sigma, epsilon = self.draw_training_sigma_and_epsilon(x0.size(), condition) + + if parallel_state.is_pipeline_last_stage(): + output_batch, pred_mse, edm_loss = self.compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + + return output_batch, edm_loss + else: + net_output = self.compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + return net_output + + def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: dict[str, torch.Tensor]): + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. + condition (dict[str, torch.Tensor]): conditional information + + Returns: + Predicted clean data (x0) and noise (eps_pred). + """ + + xt = xt.to(**self.tensor_kwargs) + sigma = sigma.to(**self.tensor_kwargs) + # get precondition for the network + c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) + + net_output = self.net( + x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + timesteps=c_noise, # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + **condition, + ) + + if not parallel_state.is_pipeline_last_stage(): + return net_output + + x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) + + # get noise prediction based on sde + eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) + + return x0_pred, eps_pred + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: dict[str, torch.Tensor], + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + """ + Computes the loss for training. + + Args: + data_batch: Batch of input data. + x0_from_data_batch: Raw input tensor. + x0: Latent tensor. + condition: Conditional input data. + epsilon: Noise tensor. + sigma: Noise level tensor. + + Returns: + The computed loss. + """ + # Get the mean and stand deviation of the marginal probability distribution. + mean, std = self.sde.marginal_prob(x0, sigma) + # Generate noisy observations + xt = mean + batch_mul(std, epsilon) # corrupted data + + if parallel_state.is_pipeline_last_stage(): + # make prediction + x0_pred, eps_pred = self.denoise(xt, sigma, condition) + # loss weights for different noise levels + weights_per_sigma = self.get_per_sigma_loss_weights(sigma=sigma) + pred_mse = (x0 - x0_pred) ** 2 + edm_loss = batch_mul(pred_mse, weights_per_sigma) + + output_batch = { + "x0": x0, + "xt": xt, + "sigma": sigma, + "weights_per_sigma": weights_per_sigma, + "condition": condition, + "model_pred": {"x0_pred": x0_pred, "eps_pred": eps_pred}, + "mse_loss": pred_mse.mean(), + "edm_loss": edm_loss.mean(), + } + return output_batch, pred_mse, edm_loss + else: + # make prediction + x0_pred = self.denoise(xt, sigma, condition) + return x0_pred.contiguous() + + def get_per_sigma_loss_weights(self, sigma: torch.Tensor): + """ + Args: + sigma (tensor): noise level + + Returns: + loss weights per sigma noise level + """ + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + def get_condition_uncondition(self, data_batch: Dict): + """Returns conditioning and unconditioning for classifier-free guidance.""" + _, _, condition = self.get_data_and_condition(data_batch, dropout_rate=0.0) + + if 'neg_t5_text_embeddings' in data_batch: + data_batch['t5_text_embeddings'] = data_batch['neg_t5_text_embeddings'] + data_batch["t5_text_mask"] = data_batch["neg_t5_text_mask"] + _, _, uncondition = self.get_data_and_condition(data_batch, dropout_rate=1.0) + else: + _, _, uncondition = self.get_data_and_condition(data_batch, dropout_rate=1.0) + + return condition, uncondition + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + ) -> Callable: + """ + Creates a function to generate denoised predictions with the sampler. + + Args: + data_batch: Batch of input data. + guidance: Guidance scale factor. + is_negative_prompt: Whether to use negative prompts. + + Returns: + A callable to predict clean data (x0). + """ + condition, uncondition = self.get_condition_uncondition(data_batch) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0, _ = self.denoise(noise_x, sigma, condition) + uncond_x0, _ = self.denoise(noise_x, sigma, uncondition) + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + state_shape: Tuple | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + ) -> Tensor: + """ + Generates samples based on input data batch. + + Args: + data_batch: Batch of input data. + guidance: Guidance scale factor. + state_shape: Shape of the state. + is_negative_prompt: Whether to use negative prompts. + num_steps: Number of steps for sampling. + solver_option: SDE Solver option. + + Returns: + Generated samples from diffusion model. + """ + cp_enabled = parallel_state.get_context_parallel_world_size() > 1 + + if self._noise_generator is None: + self._initialize_generators() + x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) + + state_shape = list(state_shape) + state_shape[1] //= parallel_state.get_context_parallel_world_size() + x_sigma_max = ( + torch.randn(state_shape, **self.tensor_kwargs, generator=self._noise_generator) * self.sde.sigma_max + ) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + + if cp_enabled: + cp_group = parallel_state.get_context_parallel_group() + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=cp_group) + + return samples + + def draw_training_sigma_and_epsilon(self, x0_size: int, condition: Any) -> torch.Tensor: + """ + Draws training noise (epsilon) and noise levels (sigma). + + Args: + x0_size: Shape of the input tensor. + condition: Conditional input (unused). + + Returns: + Noise level (sigma) and noise (epsilon). + """ + del condition + batch_size = x0_size[0] + if self._noise_generator is None: + self._initialize_generators() + epsilon = torch.randn(x0_size, **self.tensor_kwargs, generator=self._noise_generator) + return self.sde.sample_t(batch_size).to(**self.tensor_kwargs), epsilon + + def random_dropout_input(self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None) -> torch.Tensor: + """ + Applies random dropout to the input tensor. + + Args: + in_tensor: Input tensor. + dropout_rate: Dropout probability (optional). + + Returns: + Conditioning with random dropout applied. + """ + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + return batch_mul( + torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor), + in_tensor, + ) + + def get_data_and_condition(self, data_batch: dict[str, Tensor], dropout_rate=0.2) -> Tuple[Tensor]: + """ + Retrieves data and conditioning for model input. + + Args: + data_batch: Batch of input data. + dropout_rate: Dropout probability for conditioning. + + Returns: + Raw data, latent data, and conditioning information. + """ + # Latent state + raw_state = data_batch["video"] * self.sigma_data + # assume data is already encoded + latent_state = raw_state + + # Condition + data_batch['crossattn_emb'] = self.random_dropout_input( + data_batch['t5_text_embeddings'], dropout_rate=dropout_rate + ) + + return raw_state, latent_state, data_batch diff --git a/nemo/collections/diffusion/scripts/train.sh b/nemo/collections/diffusion/scripts/train.sh new file mode 100644 index 000000000000..2150458e9376 --- /dev/null +++ b/nemo/collections/diffusion/scripts/train.sh @@ -0,0 +1,29 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/bin/bash +# example slurm script for training diffusion + +#SBATCH -p your_partition -A your_account -t 24:00:00 --nodes=16 --exclusive --mem=0 --overcommit --gpus-per-node 8 --ntasks-per-node=8 --dependency=singleton + +export WANDB_PROJECT=xxx +export WANDB_RUN_ID=xxx +export WANDB_RESUME=allow +export NVTE_FUSED_ATTN=0 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +DIR=`pwd` + +srun -l --container-image nvcr.io/nvidia/nemo:dev --container-mounts "/home:/home" --no-container-mount-home --mpi=pmix bash -c "cd ${DIR} ; python -u nemo/collections/diffusion/train.py --yes $*" diff --git a/nemo/collections/diffusion/train.py b/nemo/collections/diffusion/train.py new file mode 100644 index 000000000000..43a0a5dcb536 --- /dev/null +++ b/nemo/collections/diffusion/train.py @@ -0,0 +1,201 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import WandbLogger + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule +from nemo.collections.diffusion.data.diffusion_taskencoder import BasicDiffusionTaskEncoder +from nemo.collections.diffusion.models.model import ( + DiT7BConfig, + DiTConfig, + DiTLConfig, + DiTLlama5BConfig, + DiTLlama30BConfig, + DiTModel, + DiTXLConfig, +) +from nemo.lightning.pytorch.callbacks import ModelCheckpoint, PreemptionCallback +from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform +from nemo.lightning.pytorch.strategies.utils import RestoreConfig + + +@run.cli.factory +@run.autoconvert +def multimodal_datamodule() -> pl.LightningDataModule: + data_module = DiffusionDataModule( + seq_length=2048, + task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048), + micro_batch_size=1, + global_batch_size=32, + ) + return data_module + + +@run.cli.factory +@run.autoconvert +def peft(args) -> ModelTransform: + return llm.peft.LoRA( + target_modules=['linear_qkv', 'linear_proj'], # , 'linear_fc1', 'linear_fc2'], + dim=args.lora_dim, + ) + + +@run.cli.factory(target=llm.train) +def pretrain() -> run.Partial: + return run.Partial( + llm.train, + model=run.Config( + DiTModel, + config=run.Config(DiTConfig), + ), + data=multimodal_datamodule(), + trainer=run.Config( + nl.Trainer, + devices='auto', + num_nodes=int(os.environ.get('SLURM_NNODES', 1)), + accelerator="gpu", + strategy=run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + sequence_parallel=False, + pipeline_dtype=torch.bfloat16, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + ), + ), + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + num_sanity_val_steps=0, + limit_val_batches=1, + val_check_interval=1000, + max_epochs=10000, + log_every_n_steps=1, + callbacks=[ + run.Config( + ModelCheckpoint, + monitor='reduced_train_loss', + filename='{epoch}-{step}', + every_n_train_steps=1000, + save_top_k=-1, + ), + run.Config(PreemptionCallback), + ], + ), + log=nl.NeMoLogger(wandb=(WandbLogger() if "WANDB_API_KEY" in os.environ else None)), + optim=run.Config( + nl.MegatronOptimizerModule, + config=run.Config( + OptimizerConfig, + lr=1e-4, + bf16=True, + params_dtype=torch.bfloat16, + use_distributed_optimizer=True, + weight_decay=0, + ), + ), + tokenizer=None, + resume=run.Config( + nl.AutoResume, + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_past_end=True, + ), + model_transform=None, + ) + + +@run.cli.factory(target=llm.train) +def pretrain_xl() -> run.Partial: + recipe = pretrain() + recipe.model.config = run.Config(DiTXLConfig) + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_l() -> run.Partial: + recipe = pretrain() + recipe.model.config = run.Config(DiTLConfig) + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_7b() -> run.Partial: + recipe = pretrain() + recipe.model.config = run.Config(DiT7BConfig) + recipe.data.global_batch_size = 4608 + recipe.data.micro_batch_size = 9 + recipe.data.num_workers = 15 + recipe.data.use_train_split_for_val = True + recipe.data.seq_length = 260 + recipe.data.task_encoder.seq_length = 260 + recipe.trainer.val_check_interval = 1000 + recipe.log.log_dir = 'nemo_experiments/dit7b' + recipe.optim.lr_scheduler = run.Config(nl.lr_scheduler.WarmupHoldPolicyScheduler, warmup_steps=100, hold_steps=1e9) + recipe.optim.config.weight_decay = 0.1 + recipe.optim.config.adam_beta1 = 0.9 + recipe.optim.config.adam_beta2 = 0.95 + + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama5b() -> run.Partial: + recipe = pretrain_7b() + recipe.data.micro_batch_size = 12 + recipe.model.config = run.Config(DiTLlama5BConfig) + recipe.log.log_dir = 'nemo_experiments/ditllama5b' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama30b() -> run.Partial: + recipe = pretrain_ditllama5b() + recipe.model.config = run.Config(DiTLlama30BConfig) + recipe.data.global_batch_size = 9216 + recipe.data.micro_batch_size = 6 + recipe.log.log_dir = 'nemo_experiments/ditllama30b' + return recipe + + +@run.cli.factory(target=llm.train) +def dreambooth() -> run.Partial: + recipe = pretrain() + recipe.optim.config.lr = 1e-6 + recipe.data = multimodal_datamodule() + recipe.model.config = run.Config(DiTConfig) + + recipe.trainer.max_steps = 1000 + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.sequence_parallel = True + + recipe.resume.restore_config = run.Config(RestoreConfig) + recipe.resume.resume_if_exists = False + + return recipe + + +if __name__ == "__main__": + run.cli.main(llm.train, default_factory=dreambooth) diff --git a/nemo/collections/diffusion/vae/__init__.py b/nemo/collections/diffusion/vae/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/diffusion/vae/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/diffusion/vae/diffusers_vae.py b/nemo/collections/diffusion/vae/diffusers_vae.py new file mode 100644 index 000000000000..19a056d4a682 --- /dev/null +++ b/nemo/collections/diffusion/vae/diffusers_vae.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from diffusers import AutoencoderKL +from einops import rearrange + + +class AutoencoderKLVAE(torch.nn.Module): + def __init__(self, path): + super().__init__() + self.vae = AutoencoderKL.from_pretrained(path, torch_dtype=torch.bfloat16) + + @torch.no_grad() + def decode(self, x): + B, C, T, H, W = x.shape + if T == 1: + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = x / self.vae.config.scaling_factor + out = self.vae.decode(x, return_dict=False)[0] + if T == 1: + return rearrange(out, '(b t) c h w -> b c t h w', t=1) + return out