Skip to content

Commit

Permalink
Fix trainer builder when exp_manager is not in config (#9293)
Browse files Browse the repository at this point in the history
* fix

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* rollback changes

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

---------

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Co-authored-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
  • Loading branch information
yaoyu-33 and yaoyu-33 authored May 29, 2024
1 parent a1173eb commit cff6b95
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=Fal
if use_te:
activation = 'gelu' if not glu else 'geglu'
# TODO: more parameters to be confirmed, dropout, seq_length
self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,)
self.net = LayerNormMLP(
hidden_size=dim,
ffn_hidden_size=inner_dim,
activation=activation,
)
else:
norm = nn.LayerNorm(dim)
project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
Expand Down Expand Up @@ -264,7 +268,7 @@ def __init__(
self.query_dim = query_dim
self.dim_head = dim_head

self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
self.heads = heads

self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ def __init__(
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList(
[torch.nn.Linear(self.ch, self.temb_ch), torch.nn.Linear(self.temb_ch, self.temb_ch),]
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)

# downsampling
Expand Down Expand Up @@ -669,7 +672,11 @@ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
]
)

self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1,)
self.conv_out = nn.Conv2d(
mid_channels,
out_channels,
kernel_size=1,
)

def forward(self, x):
x = self.conv_in(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,14 @@ class AttentionPool2d(nn.Module):
"""

def __init__(
self, spacial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int = None,
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
Expand Down Expand Up @@ -332,7 +336,10 @@ def __init__(
self.emb_layers = None
self.exchange_temb_dims = False
else:
self.emb_layers = nn.Sequential(nn.SiLU(), linear(emb_channels, self.emb_out_channels),)
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(emb_channels, self.emb_out_channels),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, act="silu", gn_groups=resblock_gn_groups),
nn.Dropout(p=dropout),
Expand Down Expand Up @@ -400,7 +407,12 @@ class AttentionBlock(nn.Module):
"""

def __init__(
self, channels, num_heads=1, num_head_channels=-1, use_checkpoint=False, use_new_attention_order=False,
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
Expand Down Expand Up @@ -451,7 +463,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])


Expand Down Expand Up @@ -653,7 +665,10 @@ def __init__(
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(
map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)),)
map(
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
range(len(num_attention_blocks)),
)
)
logging.info(
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
Expand All @@ -674,7 +689,9 @@ def __init__(
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim),
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)

self.time_embeddings = torch.Tensor(build_timestep_embedding(model_channels, timesteps))
Expand All @@ -691,15 +708,19 @@ def __init__(
self.label_emb = nn.Sequential(
Timestep(model_channels),
nn.Sequential(
linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim),
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
),
)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.adm_in_channels = adm_in_channels
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim),
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
Expand Down Expand Up @@ -810,26 +831,28 @@ def __init__(
use_scale_shift_norm=use_scale_shift_norm,
resblock_gn_groups=resblock_gn_groups,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth_middle,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth_middle,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
)
),
ResBlock(
ch,
Expand Down Expand Up @@ -1123,9 +1146,15 @@ def te_fp8_key_mapping(self, unet_dict):
# norm_to_q.layer_norm_{weight|bias} -> norm.{weight|bias}
# norm_to_q.weight -> to_q.weight
new_key = key.replace('attn1.norm.', 'attn1.norm_to_q.layer_norm_')
new_key = new_key.replace('attn1.to_q.weight', 'attn1.norm_to_q.weight',)
new_key = new_key.replace(
'attn1.to_q.weight',
'attn1.norm_to_q.weight',
)
new_key = new_key.replace('attn2.norm.', 'attn2.norm_to_q.layer_norm_')
new_key = new_key.replace('attn2.to_q.weight', 'attn2.norm_to_q.weight',)
new_key = new_key.replace(
'attn2.to_q.weight',
'attn2.norm_to_q.weight',
)

### LayerNormMLP
# ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias}
Expand Down Expand Up @@ -1214,7 +1243,10 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from
unexpected_keys = list(set(loaded_keys) - set(expected_keys))

def _find_mismatched_keys(
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes,
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
Expand All @@ -1234,7 +1266,10 @@ def _find_mismatched_keys(
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes,
state_dict,
model_state_dict,
original_loaded_keys,
ignore_mismatched_sizes,
)
error_msgs = self._load_state_dict_into_model(state_dict)
return missing_keys, unexpected_keys, mismatched_keys, error_msgs
Expand Down Expand Up @@ -1329,9 +1364,14 @@ def _forward(self, x, timesteps=None, context=None, y=None, **kwargs):
return self.out(h)

def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
with transformer_engine.pytorch.fp8_autocast(
enabled=self.use_te_fp8, fp8_recipe=self.fp8_recipe,
) if self.use_te_fp8 else nullcontext():
with (
transformer_engine.pytorch.fp8_autocast(
enabled=self.use_te_fp8,
fp8_recipe=self.fp8_recipe,
)
if self.use_te_fp8
else nullcontext()
):
out = self._forward(x, timesteps, context, y, **kwargs)
return out

Expand Down Expand Up @@ -1387,7 +1427,9 @@ def __init__(

time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim),
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)

self.input_blocks = nn.ModuleList(
Expand Down Expand Up @@ -1489,11 +1531,15 @@ def __init__(
elif pool == "attention":
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch), nn.SiLU(), AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
normalization(ch),
nn.SiLU(),
AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
)
elif pool == "spatial":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048), nn.ReLU(), nn.Linear(2048, self.out_channels),
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
self.out = nn.Sequential(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2

elif schedule == "cosine":
timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
Expand Down Expand Up @@ -169,7 +169,10 @@ def backward(ctx, *output_grads):
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True,
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
Expand Down Expand Up @@ -319,15 +322,23 @@ def interpolate_fn(x, xp, yp):
start_idx = torch.where(
torch.eq(x_idx, 0),
torch.tensor(1, device=x.device),
torch.where(torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
torch.eq(x_idx, 0),
torch.tensor(0, device=x.device),
torch.where(torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/nlp/parts/megatron_trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _plugins(self) -> list:
use_dist_ckpt = not self.cfg.model.get('fsdp', False) and (
self.cfg.model.get('mcore_gpt', False) or self.cfg.model.get('mcore_bert', False)
)
async_save = self.cfg.exp_manager.get('checkpoint_callback_params', {}).get('async_save', False)
async_save = self.cfg.get('exp_manager', {}).get('checkpoint_callback_params', {}).get('async_save', False)
if use_dist_ckpt:
checkpoint_io = DistributedCheckpointIO.from_config(self.cfg.model, async_save)
if async_save:
Expand All @@ -171,7 +171,7 @@ def _callbacks(self, callbacks: Optional[list]) -> list:
if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar:
callbacks.append(CustomProgressBar())

if self.cfg.exp_manager.get('checkpoint_callback_params', {}).get('async_save', False):
if self.cfg.get('exp_manager', {}).get('checkpoint_callback_params', {}).get('async_save', False):
callbacks.append(AsyncFinalizerCallback())
return callbacks

Expand Down

0 comments on commit cff6b95

Please sign in to comment.