From ef740068e75bf55aac14c1432707fc4ef136bb04 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 2 Jun 2023 17:53:33 -0600
Subject: [PATCH] Fix get_parameters when using main params optimizer (#6764)
 (#6787)

* fix get param

* change name

---------

Signed-off-by: ericharper
Co-authored-by: Eric Harper
---
 .../models/language_modeling/megatron_base_model.py | 12 +++++++-----
 nemo/core/optim/optimizer_with_main_params.py       |  4 ++--
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 1237491fa39c..2aaedbe5a806 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -240,14 +240,16 @@ def _vocab_size_with_padding(self, orig_vocab_size, make_vocab_size_divisible_by
         )
         return after
 
-    def _get_parameters(self):
+    def get_parameters_with_grad(self):
         """
-        private method to load all the trainable parameters from optimizer param groups
+        Get all parameters with grad from optimizer param groups
         """
         params = []
         for param_group in self._optimizer_param_groups:
             for param in param_group['params']:
-                if param.requires_grad:  # (@adithyare) adapter training with pp>1 can result in params with no grads
+                if (
+                    param.grad is not None
+                ):  # (@adithyare) adapter training with pp>1 can result in params with no grads
                     params.append(param)
         return params
 
@@ -272,9 +274,9 @@ def configure_gradient_clipping(self, *args, **kwargs):
         else:
             if self.megatron_amp_o2:
                 # grep fp32 master parameters for gradient clipping
-                parameters = self._optimizer.get_parameters()
+                parameters = self._optimizer.get_parameters_with_grad()
             else:
-                parameters = self._get_parameters()
+                parameters = self.get_parameters_with_grad()
             grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
 
         self.log('grad_norm', grad_norm, rank_zero_only=True, batch_size=1)
diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py
index c9790ee2a139..44d54a0e63ff 100644
--- a/nemo/core/optim/optimizer_with_main_params.py
+++ b/nemo/core/optim/optimizer_with_main_params.py
@@ -488,11 +488,11 @@ def async_master_grads_allreudce(self):
     def fp32_grad_accumulation(self):
         return self._fp32_grad_accum
 
-    def get_parameters(self):
+    def get_parameters_with_grad(self):
         params = []
         for param_group in self.optimizer.param_groups:
             for param in param_group['params']:
-                if param.requires_grad:  # (@adithyare) added to enable pp>1 training for adapters
+                if param.grad is not None:  # (@adithyare) added to enable pp>1 training for adapters
                     params.append(param)
         return params
 
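
Editor's note (illustration, not part of the patch): below is a minimal PyTorch sketch of the behavior the patch relies on, namely that a parameter can have requires_grad=True while its .grad stays None, which is what happens for adapter training with pipeline parallelism > 1. The toy Linear modules, the plain SGD optimizer, and the final print are assumptions for demonstration only; in NeMo the filtered list is the one handed to clip_grad_norm_fp32.

# Sketch only: shows why filtering on `param.grad is not None` differs
# from filtering on `param.requires_grad` before gradient clipping.
import torch

# Hypothetical stand-ins for a model's optimizer param groups:
# `used` participates in the backward pass, `untouched` does not.
used = torch.nn.Linear(4, 4)
untouched = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(
    [{'params': list(used.parameters()) + list(untouched.parameters())}], lr=0.1
)

# Backward only reaches `used`, so `untouched`'s params keep
# requires_grad=True but end up with .grad of None.
used(torch.randn(2, 4)).sum().backward()

def get_parameters_with_grad(optimizer):
    """Collect only parameters that actually received a gradient."""
    params = []
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            if param.grad is not None:
                params.append(param)
    return params

with_grad = get_parameters_with_grad(optimizer)
requiring_grad = [
    p for g in optimizer.param_groups for p in g['params'] if p.requires_grad
]
# Prints 2 and 4: clipping must be fed the first list, since the other
# two parameters never received a gradient on this (toy) rank.
print(len(with_grad), len(requiring_grad))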