Diffusion Transformer Training Pipeline (NVIDIA#10843)
* diffusion training

Signed-off-by: Zeeshan Patel <zeeshanp@berkeley.edu>

* fixing issues with data module

Signed-off-by: Zeeshan Patel <zeeshanp@berkeley.edu>

* added dit llama support, cleaned up dit code

Signed-off-by: Zeeshan Patel <zeeshanp@berkeley.edu>

* fixed code formatting

Signed-off-by: Zeeshan Patel <zeeshanp@berkeley.edu>

* added dit llama models

Signed-off-by: Zeeshan Patel <zeeshanp@berkeley.edu>

---------

Signed-off-by: Zeeshan Patel <zeeshanp@berkeley.edu>
zpx01 authored Oct 13, 2024
1 parent 44aa545 commit ce21ffb
Showing 22 changed files with 3,088 additions and 15 deletions.
19 changes: 18 additions & 1 deletion nemo/collections/diffusion/data/diffusion_energon_datamodule.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal

import logging
from typing import Any, Dict, Literal

from megatron.energon import DefaultTaskEncoder, get_train_dataset
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
@@ -127,3 +129,18 @@ def val_dataloader(self) -> EVAL_DATALOADERS:
if self.use_train_split_for_val:
return self.train_dataloader()
return super().val_dataloader()

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Load the state of the data module from a checkpoint.
This method is called when loading a checkpoint. It restores the state of the data module,
including the state of the dataloader and the number of consumed samples.
Parameters:
state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module.
"""
try:
super().load_state_dict(state_dict)
except Exception as e:
logging.warning(f"datamodule.load_state_dict failed {e}")
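
A minimal, self-contained sketch (not part of the commit) of the tolerant restore pattern above: if the saved dataloader state is missing or incompatible, the failure is logged as a warning and checkpoint resume continues instead of aborting.

import logging
from typing import Any, Dict

class _TolerantRestore:
    def _strict_load(self, state_dict: Dict[str, Any]) -> None:
        # Stand-in for super().load_state_dict(); insists on a key that may be absent.
        _ = state_dict["dataloader_state"]

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        try:
            self._strict_load(state_dict)
        except Exception as e:
            logging.warning(f"datamodule.load_state_dict failed {e}")

_TolerantRestore().load_state_dict({})  # logs a warning; resume continues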
120 changes: 106 additions & 14 deletions nemo/collections/diffusion/data/diffusion_taskencoder.py
@@ -11,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import torch
import torch.nn.functional as F
from einops import rearrange
from megatron.core import parallel_state
from megatron.energon import DefaultTaskEncoder, SkipSample
from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys
@@ -66,10 +69,22 @@ class BasicDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
Cooker(cook),
]

def __init__(self, *args, max_frames: int = None, text_embedding_padding_size: int = 512, **kwargs):
def __init__(
self,
*args,
max_frames: int = None,
text_embedding_padding_size: int = 512,
seq_length: int = None,
patch_spatial: int = 2,
patch_temporal: int = 1,
**kwargs,
):
super().__init__(*args, **kwargs)
self.max_frames = max_frames
self.text_embedding_padding_size = text_embedding_padding_size
self.seq_length = seq_length
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal

def encode_sample(self, sample: dict) -> dict:
video_latent = sample['pth']
@@ -80,42 +95,119 @@ def encode_sample(self, sample: dict) -> dict:
raise SkipSample()

info = sample['json']
_, T, H, W = video_latent.shape
C, T, H, W = video_latent.shape
seq_len = (
video_latent.shape[-1]
* video_latent.shape[-2]
* video_latent.shape[-3]
// self.patch_spatial**2
// self.patch_temporal
)
is_image = T == 1

if self.seq_length is not None and seq_len > self.seq_length:
raise SkipSample()

if self.max_frames is not None:
video_latent = video_latent[:, : self.max_frames, :, :]

tpcp_size = parallel_state.get_tensor_model_parallel_world_size()
if parallel_state.get_context_parallel_world_size() > 1:
tpcp_size *= parallel_state.get_context_parallel_world_size() * 2
if (T * H * W) % tpcp_size != 0:
print(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}')
warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}')
raise SkipSample()

seq_len = video_latent.shape[-1] * video_latent.shape[-2] * video_latent.shape[-3]
loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)
video_latent = rearrange(
video_latent,
'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)',
ph=self.patch_spatial,
pw=self.patch_spatial,
pt=self.patch_temporal,
)

if is_image:
t5_text_embeddings = torch.from_numpy(sample['pickle']).to(torch.bfloat16)
else:
t5_text_embeddings = torch.from_numpy(sample['pickle'][0]).to(torch.bfloat16)
t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]

t5_text_embeddings = F.pad(
t5_text_embeddings,
(
0,
0,
0,
self.text_embedding_padding_size - t5_text_embeddings_seq_length % self.text_embedding_padding_size,
),
)
if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
else:
t5_text_embeddings = F.pad(
t5_text_embeddings,
(
0,
0,
0,
self.text_embedding_padding_size - t5_text_embeddings_seq_length,
),
)
t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)

if is_image:
h, w = info['image_height'], info['image_width']
fps = torch.tensor([30] * 1, dtype=torch.bfloat16)
num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16)
else:
h, w = info['height'], info['width']
fps = torch.tensor([info['framerate']] * 1, dtype=torch.bfloat16)
num_frames = torch.tensor([info['num_frames']] * 1, dtype=torch.bfloat16)
image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16)

pos_ids = rearrange(
pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial),
'T H W d -> (T H W) d',
)

if self.seq_length is not None:
pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
loss_mask[:seq_len] = 1
video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
else:
loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)

return dict(
video=video_latent,
t5_text_embeddings=t5_text_embeddings,
t5_text_mask=t5_text_mask,
image_size=image_size,
fps=fps,
num_frames=num_frames,
loss_mask=loss_mask,
seq_len_q=torch.tensor(seq_len, dtype=torch.int32),
seq_len_kv=torch.tensor(t5_text_embeddings_seq_length, dtype=torch.int32),
pos_ids=pos_ids,
latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
)


class PosID3D:
def __init__(self, *, max_t=32, max_h=128, max_w=128):
self.max_t = max_t
self.max_h = max_h
self.max_w = max_w
self.generate_pos_id()

def generate_pos_id(self):
self.grid = torch.stack(
torch.meshgrid(
torch.arange(self.max_t, device='cpu'),
torch.arange(self.max_h, device='cpu'),
torch.arange(self.max_w, device='cpu'),
),
dim=-1,
)

def get_pos_id_3d(self, *, t, h, w):
if t > self.max_t or h > self.max_h or w > self.max_w:
self.max_t = max(self.max_t, t)
self.max_h = max(self.max_h, h)
self.max_w = max(self.max_w, w)
self.generate_pos_id()
return self.grid[:t, :h, :w]


pos_id_3d = PosID3D()
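
A minimal sketch (not part of the commit) of the patchification and sequence-length bookkeeping in encode_sample above, applied to a dummy latent; the shape here is purely illustrative and the pos_id_3d singleton defined above is reused.

import torch
from einops import rearrange

C, T, H, W = 16, 8, 32, 32               # hypothetical latent shape
patch_spatial, patch_temporal = 2, 1
latent = torch.randn(C, T, H, W)

seq_len = T * H * W // patch_spatial**2 // patch_temporal
tokens = rearrange(
    latent,
    'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)',
    ph=patch_spatial, pw=patch_spatial, pt=patch_temporal,
)
assert tokens.shape == (seq_len, patch_spatial**2 * patch_temporal * C)

pos_ids = rearrange(
    pos_id_3d.get_pos_id_3d(t=T // patch_temporal, h=H // patch_spatial, w=W // patch_spatial),
    'T H W d -> (T H W) d',
)
assert pos_ids.shape == (seq_len, 3)     # one (t, h, w) triple per token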
13 changes: 13 additions & 0 deletions nemo/collections/diffusion/models/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions nemo/collections/diffusion/models/dit/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
161 changes: 161 additions & 0 deletions nemo/collections/diffusion/models/dit/dit_embeddings.py
@@ -0,0 +1,161 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from typing import Dict, Literal, Optional

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, get_3d_sincos_pos_embed
from einops import rearrange
from einops.layers.torch import Rearrange
from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rotary_pos_embedding import get_pos_emb_on_this_cp_rank
from megatron.core.transformer.module import MegatronModule
from torch import nn


class ParallelTimestepEmbedding(TimestepEmbedding):
"""
ParallelTimestepEmbedding is a subclass of TimestepEmbedding that initializes
the embedding layers with an optional random seed for synchronization.
Args:
in_channels (int): Number of input channels.
time_embed_dim (int): Dimension of the time embedding.
seed (int, optional): Random seed for initializing the embedding layers.
If None, no specific seed is set.
Attributes:
linear_1 (nn.Module): First linear layer for the embedding.
linear_2 (nn.Module): Second linear layer for the embedding.
Methods:
__init__(in_channels, time_embed_dim, seed=None): Initializes the embedding layers.
"""

def __init__(self, in_channels: int, time_embed_dim: int, seed=None):
super().__init__(in_channels=in_channels, time_embed_dim=time_embed_dim)
if seed is not None:
with torch.random.fork_rng():
torch.manual_seed(seed)
self.linear_1.reset_parameters()
self.linear_2.reset_parameters()

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Computes the positional embeddings for the input tensor.
Args:
x (torch.Tensor): Input tensor of shape (B, T, H, W, C).
Returns:
torch.Tensor: Positional embeddings of shape (B, T, H, W, C).
"""
return super().forward(x.to(torch.bfloat16, non_blocking=True))
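
A standalone sketch (not part of the commit) of the seeded re-initialization pattern above: fork_rng scopes the manual seed, so the global RNG stream is left untouched while the layer is reset deterministically, which is what keeps the embedding weights identical across ranks when the same seed is passed.

import torch

layer = torch.nn.Linear(8, 8)
with torch.random.fork_rng():
    torch.manual_seed(1234)
    layer.reset_parameters()
reference = layer.weight.clone()

with torch.random.fork_rng():
    torch.manual_seed(1234)
    layer.reset_parameters()
assert torch.equal(layer.weight, reference)   # same seed -> identical parameters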


def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim):
"""
Adjusts the positional embeddings tensor to the current context parallel rank.
Args:
pos_emb (torch.Tensor): The positional embeddings tensor.
seq_dim (int): The sequence dimension index in the positional embeddings tensor.
Returns:
torch.Tensor: The adjusted positional embeddings tensor for the current context parallel rank.
"""
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
cp_idx = torch.tensor([cp_rank], device="cpu", pin_memory=True).cuda(non_blocking=True)
pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], cp_size, -1, *pos_emb.shape[(seq_dim + 1) :])
pos_emb = pos_emb.index_select(seq_dim, cp_idx)
pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
return pos_emb
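
A standalone sketch (not part of the commit) of the reshape-and-select logic above, simulating cp_size and cp_rank on CPU instead of querying parallel_state: the sequence dimension is split into cp_size contiguous chunks and each rank keeps its own chunk.

import torch

def _slice_for_cp_rank(pos_emb: torch.Tensor, seq_dim: int, cp_size: int, cp_rank: int) -> torch.Tensor:
    cp_idx = torch.tensor([cp_rank])
    pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], cp_size, -1, *pos_emb.shape[(seq_dim + 1) :])
    pos_emb = pos_emb.index_select(seq_dim, cp_idx)
    return pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])

full = torch.arange(8, dtype=torch.float32).reshape(1, 8, 1)   # (batch, seq, dim)
chunk = _slice_for_cp_rank(full, seq_dim=1, cp_size=4, cp_rank=2)
assert chunk.squeeze().tolist() == [4.0, 5.0]                  # rank 2 owns the third chunk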


class SinCosPosEmb3D(MegatronModule):
"""
SinCosPosEmb3D is a 3D sine-cosine positional embedding module.
Args:
model_channels (int): Number of channels in the model.
h (int): Length of the height dimension.
w (int): Length of the width dimension.
t (int): Length of the temporal dimension.
spatial_interpolation_scale (float, optional): Scale factor for spatial interpolation. Default is 1.0.
temporal_interpolation_scale (float, optional): Scale factor for temporal interpolation. Default is 1.0.
Methods:
forward(pos_ids: torch.Tensor) -> torch.Tensor:
Computes the positional embeddings for the input tensor.
Args:
pos_ids (torch.Tensor): Input tensor of shape (B S 3).
Returns:
torch.Tensor: Positional embeddings of shape (B S D).
"""

def __init__(
self,
config,
h: int,
w: int,
t: int,
spatial_interpolation_scale=1.0,
temporal_interpolation_scale=1.0,
):
super().__init__(config=config)
self.h = h
self.w = w
self.t = t
# h w t
param = get_3d_sincos_pos_embed(
config.hidden_size, [h, w], t, spatial_interpolation_scale, temporal_interpolation_scale
)
param = rearrange(param, "t hw c -> (t hw) c")
self.pos_embedding = torch.nn.Embedding(param.shape[0], config.hidden_size)
self.pos_embedding.weight = torch.nn.Parameter(torch.tensor(param), requires_grad=False)

def forward(self, pos_ids: torch.Tensor):
# pos_ids: t h w
pos_id = pos_ids[..., 0] * self.h * self.w + pos_ids[..., 1] * self.w + pos_ids[..., 2]
return self.pos_embedding(pos_id)
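
A quick standalone check (not part of the commit) of the index flattening above: position (t, h, w) maps to row t*H*W + h*W + w of the (T*H*W, C) embedding table, matching the 't hw c -> (t hw) c' flattening used when the table is built.

import torch

T, H, W, C = 2, 3, 4, 8
table = torch.arange(T * H * W * C, dtype=torch.float32).reshape(T * H * W, C)
emb = torch.nn.Embedding(T * H * W, C)
emb.weight = torch.nn.Parameter(table, requires_grad=False)

pos = torch.tensor([[1, 2, 3]])                                # a single (t, h, w) id
flat = pos[..., 0] * H * W + pos[..., 1] * W + pos[..., 2]
assert torch.equal(emb(flat), table[1 * H * W + 2 * W + 3].unsqueeze(0))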


class FactorizedLearnable3DEmbedding(MegatronModule):
def __init__(
self,
config,
t: int,
h: int,
w: int,
**kwargs,
):
super().__init__(config=config)
self.emb_t = torch.nn.Embedding(t, config.hidden_size)
self.emb_h = torch.nn.Embedding(h, config.hidden_size)
self.emb_w = torch.nn.Embedding(w, config.hidden_size)

if config.perform_initialization:
config.init_method(self.emb_t.weight)
config.init_method(self.emb_h.weight)
config.init_method(self.emb_w.weight)

def forward(self, pos_ids: torch.Tensor):
return self.emb_t(pos_ids[..., 0]) + self.emb_h(pos_ids[..., 1]) + self.emb_w(pos_ids[..., 2])
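
A minimal standalone sketch (not part of the commit) of the factorized 3D lookup above: three 1D embedding tables are indexed by the t, h, and w columns of the flattened position ids and summed, yielding one vector per token.

import torch

hidden, T, H, W = 64, 4, 8, 8
emb_t, emb_h, emb_w = (torch.nn.Embedding(n, hidden) for n in (T, H, W))

pos_ids = torch.stack(
    torch.meshgrid(torch.arange(T), torch.arange(H), torch.arange(W), indexing='ij'),
    dim=-1,
).reshape(-1, 3)                                               # (T*H*W, 3), as produced by PosID3D
pos_emb = emb_t(pos_ids[..., 0]) + emb_h(pos_ids[..., 1]) + emb_w(pos_ids[..., 2])
assert pos_emb.shape == (T * H * W, hidden)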