From 5b4086ed47a61ca4498d5ad3faee56f8f1eae11f Mon Sep 17 00:00:00 2001 From: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Date: Wed, 29 May 2024 09:36:00 -0700 Subject: [PATCH] Fix trainer builder when exp_manager is not in config (#9293) * fix Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * rollback changes Signed-off-by: yaoyu-33 --------- Signed-off-by: yaoyu-33 Signed-off-by: yaoyu-33 Co-authored-by: yaoyu-33 Signed-off-by: Boxiang Wang --- .../modules/stable_diffusion/attention.py | 8 +- .../diffusionmodules/model.py | 11 +- .../diffusionmodules/openaimodel.py | 124 ++++++++++++------ .../stable_diffusion/diffusionmodules/util.py | 19 ++- .../nlp/parts/megatron_trainer_builder.py | 4 +- 5 files changed, 117 insertions(+), 49 deletions(-) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py index c70b59d394817..2eeed97db7810 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py @@ -122,7 +122,11 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=Fal if use_te: activation = 'gelu' if not glu else 'geglu' # TODO: more parameters to be confirmed, dropout, seq_length - self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,) + self.net = LayerNormMLP( + hidden_size=dim, + ffn_hidden_size=inner_dim, + activation=activation, + ) else: norm = nn.LayerNorm(dim) project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) @@ -264,7 +268,7 @@ def __init__( self.query_dim = query_dim self.dim_head = dim_head - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py index 644efafaf06a5..5b874f5f10adc 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py @@ -233,7 +233,10 @@ def __init__( # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList( - [torch.nn.Linear(self.ch, self.temb_ch), torch.nn.Linear(self.temb_ch, self.temb_ch),] + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] ) # downsampling @@ -669,7 +672,11 @@ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): ] ) - self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1,) + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) def forward(self, x): x = self.conv_in(x) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 3e301f0b8fc19..30ff0e1a9ff30 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -115,10 +115,14 @@ class AttentionPool2d(nn.Module): """ def __init__( - self, spacial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int = None, + self, + 
spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -332,7 +336,10 @@ def __init__( self.emb_layers = None self.exchange_temb_dims = False else: - self.emb_layers = nn.Sequential(nn.SiLU(), linear(emb_channels, self.emb_out_channels),) + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear(emb_channels, self.emb_out_channels), + ) self.out_layers = nn.Sequential( normalization(self.out_channels, act="silu", gn_groups=resblock_gn_groups), nn.Dropout(p=dropout), @@ -400,7 +407,12 @@ class AttentionBlock(nn.Module): """ def __init__( - self, channels, num_heads=1, num_head_channels=-1, use_checkpoint=False, use_new_attention_order=False, + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, ): super().__init__() self.channels = channels @@ -451,7 +463,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -653,7 +665,10 @@ def __init__( if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) assert all( - map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)),) + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) ) logging.info( f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" @@ -674,7 +689,9 @@ def __init__( self.predict_codebook_ids = n_embed is not None time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), ) self.time_embeddings = torch.Tensor(build_timestep_embedding(model_channels, timesteps)) @@ -691,7 +708,9 @@ def __init__( self.label_emb = nn.Sequential( Timestep(model_channels), nn.Sequential( - linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), ), ) elif self.num_classes == "sequential": @@ -699,7 +718,9 @@ def __init__( self.adm_in_channels = adm_in_channels self.label_emb = nn.Sequential( nn.Sequential( - linear(adm_in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), ) ) else: @@ -810,26 +831,28 @@ def __init__( use_scale_shift_norm=use_scale_shift_norm, resblock_gn_groups=resblock_gn_groups, ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth_middle, - context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, - use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, - use_flash_attention=use_flash_attention, - use_te=self.use_te_fp8, - lora_network_alpha=lora_network_alpha, + ( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + use_flash_attention=use_flash_attention, + use_te=self.use_te_fp8, + lora_network_alpha=lora_network_alpha, + ) ), ResBlock( ch, @@ -1123,9 +1146,15 @@ def te_fp8_key_mapping(self, unet_dict): # norm_to_q.layer_norm_{weight|bias} -> norm.{weight|bias} # norm_to_q.weight -> to_q.weight new_key = key.replace('attn1.norm.', 'attn1.norm_to_q.layer_norm_') - new_key = new_key.replace('attn1.to_q.weight', 'attn1.norm_to_q.weight',) + new_key = new_key.replace( + 'attn1.to_q.weight', + 'attn1.norm_to_q.weight', + ) new_key = new_key.replace('attn2.norm.', 'attn2.norm_to_q.layer_norm_') - new_key = new_key.replace('attn2.to_q.weight', 'attn2.norm_to_q.weight',) + new_key = new_key.replace( + 'attn2.to_q.weight', + 'attn2.norm_to_q.weight', + ) ### LayerNormMLP # ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias} @@ -1214,7 +1243,10 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from unexpected_keys = list(set(loaded_keys) - set(expected_keys)) def _find_mismatched_keys( - state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: @@ -1234,7 +1266,10 @@ def _find_mismatched_keys( if state_dict is not None: # Whole checkpoint mismatched_keys = 
_find_mismatched_keys( - state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, ) error_msgs = self._load_state_dict_into_model(state_dict) return missing_keys, unexpected_keys, mismatched_keys, error_msgs @@ -1329,9 +1364,14 @@ def _forward(self, x, timesteps=None, context=None, y=None, **kwargs): return self.out(h) def forward(self, x, timesteps=None, context=None, y=None, **kwargs): - with transformer_engine.pytorch.fp8_autocast( - enabled=self.use_te_fp8, fp8_recipe=self.fp8_recipe, - ) if self.use_te_fp8 else nullcontext(): + with ( + transformer_engine.pytorch.fp8_autocast( + enabled=self.use_te_fp8, + fp8_recipe=self.fp8_recipe, + ) + if self.use_te_fp8 + else nullcontext() + ): out = self._forward(x, timesteps, context, y, **kwargs) return out @@ -1387,7 +1427,9 @@ def __init__( time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), ) self.input_blocks = nn.ModuleList( @@ -1489,11 +1531,15 @@ def __init__( elif pool == "attention": assert num_head_channels != -1 self.out = nn.Sequential( - normalization(ch), nn.SiLU(), AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), + normalization(ch), + nn.SiLU(), + AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), ) elif pool == "spatial": self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), nn.ReLU(), nn.Linear(2048, self.out_channels), + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), ) elif pool == "spatial_v2": self.out = nn.Sequential( diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py index 53f9669a0b8f3..69700a43614ec 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py @@ -44,7 +44,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": - betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 elif schedule == "cosine": timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s @@ -169,7 +169,10 @@ def backward(ctx, *output_grads): shallow_copies = [x.view_as(x) for x in ctx.input_tensors] output_tensors = ctx.run_function(*shallow_copies) input_grads = torch.autograd.grad( - output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True, + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, ) del ctx.input_tensors del ctx.input_params @@ -319,7 +322,11 @@ def interpolate_fn(x, xp, yp): start_idx = torch.where( torch.eq(x_idx, 0), torch.tensor(1, device=x.device), - torch.where(torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), ) end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) start_x = torch.gather(sorted_all_x, dim=2, 
index=start_idx.unsqueeze(2)).squeeze(2) @@ -327,7 +334,11 @@ def interpolate_fn(x, xp, yp): start_idx2 = torch.where( torch.eq(x_idx, 0), torch.tensor(0, device=x.device), - torch.where(torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), ) y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index f6336f6bcc712..194168008dc4f 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -146,7 +146,7 @@ def _plugins(self) -> list: use_dist_ckpt = not self.cfg.model.get('fsdp', False) and ( self.cfg.model.get('mcore_gpt', False) or self.cfg.model.get('mcore_bert', False) ) - async_save = self.cfg.exp_manager.get('checkpoint_callback_params', {}).get('async_save', False) + async_save = self.cfg.get('exp_manager', {}).get('checkpoint_callback_params', {}).get('async_save', False) if use_dist_ckpt: checkpoint_io = DistributedCheckpointIO.from_config(self.cfg.model, async_save) if async_save: @@ -171,7 +171,7 @@ def _callbacks(self, callbacks: Optional[list]) -> list: if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar: callbacks.append(CustomProgressBar()) - if self.cfg.exp_manager.get('checkpoint_callback_params', {}).get('async_save', False): + if self.cfg.get('exp_manager', {}).get('checkpoint_callback_params', {}).get('async_save', False): callbacks.append(AsyncFinalizerCallback()) return callbacks
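
Note on the functional change: the stable_diffusion hunks above are formatting-only reshuffles (isort/black), while the two megatron_trainer_builder.py hunks replace attribute access on self.cfg.exp_manager with self.cfg.get('exp_manager', {}), so async_save resolves to False instead of erroring when exp_manager is absent from the config. A minimal sketch of that lookup pattern, assuming an OmegaConf DictConfig as NeMo configs use; resolve_async_save and the sample configs below are illustrative, not part of the patch:

    from omegaconf import OmegaConf

    def resolve_async_save(cfg) -> bool:
        # Chained .get() with defaults: every level of the config may be absent
        # without raising, and the lookup falls back to False.
        return cfg.get('exp_manager', {}).get('checkpoint_callback_params', {}).get('async_save', False)

    cfg_without_exp_manager = OmegaConf.create({'model': {'mcore_gpt': True}})
    cfg_with_exp_manager = OmegaConf.create(
        {'exp_manager': {'checkpoint_callback_params': {'async_save': True}}}
    )

    print(resolve_async_save(cfg_without_exp_manager))  # False
    print(resolve_async_save(cfg_with_exp_manager))     # True

The pre-patch form, self.cfg.exp_manager.get(...), fails on the first config because the exp_manager attribute lookup itself errors (or yields None, depending on struct settings) when the key is missing, which is the case this patch guards against.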