Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diffusion Transformer Training Pipeline #10843

Merged
merged 6 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion nemo/collections/diffusion/data/diffusion_energon_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal

import logging
from typing import Any, Dict, Literal

from megatron.energon import DefaultTaskEncoder, get_train_dataset
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
Expand Down Expand Up @@ -127,3 +129,18 @@ def val_dataloader(self) -> EVAL_DATALOADERS:
if self.use_train_split_for_val:
return self.train_dataloader()
return super().val_dataloader()

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Load the state of the data module from a checkpoint.

This method is called when loading a checkpoint. It restores the state of the data module,
including the state of the dataloader and the number of consumed samples.

Parameters:
state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module.
"""
try:
super().load_state_dict(state_dict)
except Exception as e:
logging.warning(f"datamodule.load_state_dict failed {e}")
120 changes: 106 additions & 14 deletions nemo/collections/diffusion/data/diffusion_taskencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import torch
import torch.nn.functional as F
from einops import rearrange
from megatron.core import parallel_state
from megatron.energon import DefaultTaskEncoder, SkipSample
from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys
Expand Down Expand Up @@ -66,10 +69,22 @@ class BasicDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
Cooker(cook),
]

def __init__(self, *args, max_frames: int = None, text_embedding_padding_size: int = 512, **kwargs):
def __init__(
self,
*args,
max_frames: int = None,
text_embedding_padding_size: int = 512,
seq_length: int = None,
patch_spatial: int = 2,
patch_temporal: int = 1,
**kwargs,
):
super().__init__(*args, **kwargs)
self.max_frames = max_frames
self.text_embedding_padding_size = text_embedding_padding_size
self.seq_length = seq_length
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal

def encode_sample(self, sample: dict) -> dict:
video_latent = sample['pth']
Expand All @@ -80,42 +95,119 @@ def encode_sample(self, sample: dict) -> dict:
raise SkipSample()

info = sample['json']
_, T, H, W = video_latent.shape
C, T, H, W = video_latent.shape
seq_len = (
video_latent.shape[-1]
* video_latent.shape[-2]
* video_latent.shape[-3]
// self.patch_spatial**2
// self.patch_temporal
)
is_image = T == 1

if seq_len > self.seq_length:
raise SkipSample()

if self.max_frames is not None:
video_latent = video_latent[:, : self.max_frames, :, :]

tpcp_size = parallel_state.get_tensor_model_parallel_world_size()
if parallel_state.get_context_parallel_world_size() > 1:
tpcp_size *= parallel_state.get_context_parallel_world_size() * 2
if (T * H * W) % tpcp_size != 0:
print(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}')
warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}')
raise SkipSample()

seq_len = video_latent.shape[-1] * video_latent.shape[-2] * video_latent.shape[-3]
loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)
video_latent = rearrange(
video_latent,
'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)',
ph=self.patch_spatial,
pw=self.patch_spatial,
pt=self.patch_temporal,
)

if is_image:
t5_text_embeddings = torch.from_numpy(sample['pickle']).to(torch.bfloat16)
else:
t5_text_embeddings = torch.from_numpy(sample['pickle'][0]).to(torch.bfloat16)
t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]

t5_text_embeddings = F.pad(
t5_text_embeddings,
(
0,
0,
0,
self.text_embedding_padding_size - t5_text_embeddings_seq_length % self.text_embedding_padding_size,
),
)
if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
else:
t5_text_embeddings = F.pad(
t5_text_embeddings,
(
0,
0,
0,
self.text_embedding_padding_size - t5_text_embeddings_seq_length,
),
)
t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)

if is_image:
h, w = info['image_height'], info['image_width']
fps = torch.tensor([30] * 1, dtype=torch.bfloat16)
num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16)
else:
h, w = info['height'], info['width']
fps = torch.tensor([info['framerate']] * 1, dtype=torch.bfloat16)
num_frames = torch.tensor([info['num_frames']] * 1, dtype=torch.bfloat16)
image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16)

pos_ids = rearrange(
pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial),
'T H W d -> (T H W) d',
)

if self.seq_length is not None:
pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
loss_mask[:seq_len] = 1
video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
else:
loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)

return dict(
video=video_latent,
t5_text_embeddings=t5_text_embeddings,
t5_text_mask=t5_text_mask,
image_size=image_size,
fps=fps,
num_frames=num_frames,
loss_mask=loss_mask,
seq_len_q=torch.tensor(seq_len, dtype=torch.int32),
seq_len_kv=torch.tensor(t5_text_embeddings_seq_length, dtype=torch.int32),
pos_ids=pos_ids,
latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
)


class PosID3D:
def __init__(self, *, max_t=32, max_h=128, max_w=128):
self.max_t = max_t
self.max_h = max_h
self.max_w = max_w
self.generate_pos_id()

def generate_pos_id(self):
self.grid = torch.stack(
torch.meshgrid(
torch.arange(self.max_t, device='cpu'),
torch.arange(self.max_h, device='cpu'),
torch.arange(self.max_w, device='cpu'),
),
dim=-1,
)

def get_pos_id_3d(self, *, t, h, w):
if t > self.max_t or h > self.max_h or w > self.max_w:
self.max_t = max(self.max_t, t)
self.max_h = max(self.max_h, h)
self.max_w = max(self.max_w, w)
self.generate_pos_id()
return self.grid[:t, :h, :w]


pos_id_3d = PosID3D()
13 changes: 13 additions & 0 deletions nemo/collections/diffusion/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions nemo/collections/diffusion/models/dit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
161 changes: 161 additions & 0 deletions nemo/collections/diffusion/models/dit/dit_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'math' is not used.
from typing import Dict, Literal, Optional

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'Dict' is not used.
Import of 'Literal' is not used.
Import of 'Optional' is not used.

import numpy as np

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'np' is not used.
import torch
import torch.nn.functional as F

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'F' is not used.
from diffusers.models.embeddings import TimestepEmbedding, get_3d_sincos_pos_embed
from einops import rearrange
from einops.layers.torch import Rearrange

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'Rearrange' is not used.
from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rotary_pos_embedding import get_pos_emb_on_this_cp_rank

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'get_pos_emb_on_this_cp_rank' is not used.
from megatron.core.transformer.module import MegatronModule
from torch import nn

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'nn' is not used.


class ParallelTimestepEmbedding(TimestepEmbedding):
"""
ParallelTimestepEmbedding is a subclass of TimestepEmbedding that initializes
the embedding layers with an optional random seed for syncronization.

Args:
in_channels (int): Number of input channels.
time_embed_dim (int): Dimension of the time embedding.
seed (int, optional): Random seed for initializing the embedding layers.
If None, no specific seed is set.

Attributes:
linear_1 (nn.Module): First linear layer for the embedding.
linear_2 (nn.Module): Second linear layer for the embedding.

Methods:
__init__(in_channels, time_embed_dim, seed=None): Initializes the embedding layers.
"""

def __init__(self, in_channels: int, time_embed_dim: int, seed=None):
super().__init__(in_channels=in_channels, time_embed_dim=time_embed_dim)
if seed is not None:
with torch.random.fork_rng():
torch.manual_seed(seed)
self.linear_1.reset_parameters()
self.linear_2.reset_parameters()

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Computes the positional embeddings for the input tensor.

Args:
x (torch.Tensor): Input tensor of shape (B, T, H, W, C).

Returns:
torch.Tensor: Positional embeddings of shape (B, T, H, W, C).
"""
return super().forward(x.to(torch.bfloat16, non_blocking=True))


def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim):
"""
Adjusts the positional embeddings tensor to the current context parallel rank.

Args:
pos_emb (torch.Tensor): The positional embeddings tensor.
seq_dim (int): The sequence dimension index in the positional embeddings tensor.

Returns:
torch.Tensor: The adjusted positional embeddings tensor for the current context parallel rank.
"""
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
cp_idx = torch.tensor([cp_rank], device="cpu", pin_memory=True).cuda(non_blocking=True)
pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], cp_size, -1, *pos_emb.shape[(seq_dim + 1) :])
pos_emb = pos_emb.index_select(seq_dim, cp_idx)
pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
return pos_emb


class SinCosPosEmb3D(MegatronModule):
"""
SinCosPosEmb3D is a 3D sine-cosine positional embedding module.

Args:
model_channels (int): Number of channels in the model.
h (int): Length of the height dimension.
w (int): Length of the width dimension.
t (int): Length of the temporal dimension.
spatial_interpolation_scale (float, optional): Scale factor for spatial interpolation. Default is 1.0.
temporal_interpolation_scale (float, optional): Scale factor for temporal interpolation. Default is 1.0.

Methods:
forward(pos_ids: torch.Tensor) -> torch.Tensor:
Computes the positional embeddings for the input tensor.

Args:
pos_ids (torch.Tensor): Input tensor of shape (B S 3).

Returns:
torch.Tensor: Positional embeddings of shape (B S D).
"""

def __init__(
self,
config,
h: int,
w: int,
t: int,
spatial_interpolation_scale=1.0,
temporal_interpolation_scale=1.0,
):
super().__init__(config=config)
self.h = h
self.w = w
self.t = t
# h w t
param = get_3d_sincos_pos_embed(
config.hidden_size, [h, w], t, spatial_interpolation_scale, temporal_interpolation_scale
)
param = rearrange(param, "t hw c -> (t hw) c")
self.pos_embedding = torch.nn.Embedding(param.shape[0], config.hidden_size)
self.pos_embedding.weight = torch.nn.Parameter(torch.tensor(param), requires_grad=False)

def forward(self, pos_ids: torch.Tensor):
# pos_ids: t h w
pos_id = pos_ids[..., 0] * self.h * self.w + pos_ids[..., 1] * self.w + pos_ids[..., 2]
return self.pos_embedding(pos_id)


class FactorizedLearnable3DEmbedding(MegatronModule):
def __init__(
self,
config,
t: int,
h: int,
w: int,
**kwargs,
):
super().__init__(config=config)
self.emb_t = torch.nn.Embedding(t, config.hidden_size)
self.emb_h = torch.nn.Embedding(h, config.hidden_size)
self.emb_w = torch.nn.Embedding(w, config.hidden_size)

if config.perform_initialization:
config.init_method(self.emb_t.weight)
config.init_method(self.emb_h.weight)
config.init_method(self.emb_w.weight)

def forward(self, pos_ids: torch.Tensor):
return self.emb_t(pos_ids[..., 0]) + self.emb_h(pos_ids[..., 1]) + self.emb_w(pos_ids[..., 2])
Loading
Loading