Fix get_parameters when using main params optimizer (#6764) (#6787)
* fix get param

* change name

---------

Signed-off-by: ericharper <complex451@gmail.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
github-actions[bot] and ericharper authored Jun 2, 2023
1 parent d5819e9 commit ef74006
Showing 2 changed files with 9 additions and 7 deletions.
@@ -240,14 +240,16 @@ def _vocab_size_with_padding(self, orig_vocab_size, make_vocab_size_divisible_by
         )
         return after
 
-    def _get_parameters(self):
+    def get_parameters_with_grad(self):
         """
-        private method to load all the trainable parameters from optimizer param groups
+        Get all parameters with grad from optimizer param groups
         """
         params = []
         for param_group in self._optimizer_param_groups:
             for param in param_group['params']:
-                if param.requires_grad:  # (@adithyare) adapter training with pp>1 can result in params with no grads
+                if (
+                    param.grad is not None
+                ):  # (@adithyare) adapter training with pp>1 can result in params with no grads
                     params.append(param)
         return params

@@ -272,9 +274,9 @@ def configure_gradient_clipping(self, *args, **kwargs):
         else:
             if self.megatron_amp_o2:
                 # grep fp32 master parameters for gradient clipping
-                parameters = self._optimizer.get_parameters()
+                parameters = self._optimizer.get_parameters_with_grad()
             else:
-                parameters = self._get_parameters()
+                parameters = self.get_parameters_with_grad()
             grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
 
         self.log('grad_norm', grad_norm, rank_zero_only=True, batch_size=1)
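The substance of the fix is the filter change: under pipeline parallelism (pp>1) with adapters, a parameter can be trainable (requires_grad=True) yet never receive a gradient, and such grad-less parameters should not reach gradient clipping. A minimal stand-alone sketch of that distinction, in plain PyTorch rather than NeMo code:

import torch

# Two trainable parameters; only one participates in the backward pass,
# mimicking an adapter weight on a pipeline stage that never gets gradients.
used = torch.nn.Parameter(torch.randn(4))
unused = torch.nn.Parameter(torch.randn(4))

(used * 2.0).sum().backward()

print([p.requires_grad for p in (used, unused)])    # [True, True]
print([p.grad is not None for p in (used, unused)])  # [True, False]

# The commit's filter keeps only parameters that actually hold a gradient,
# which is what clip_grad_norm_fp32 needs downstream.
params_with_grad = [p for p in (used, unused) if p.grad is not None]
assert len(params_with_grad) == 1 and params_with_grad[0] is used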
4 changes: 2 additions & 2 deletions nemo/core/optim/optimizer_with_main_params.py
@@ -488,11 +488,11 @@ def async_master_grads_allreudce(self):
     def fp32_grad_accumulation(self):
         return self._fp32_grad_accum
 
-    def get_parameters(self):
+    def get_parameters_with_grad(self):
         params = []
         for param_group in self.optimizer.param_groups:
             for param in param_group['params']:
-                if param.requires_grad:  # (@adithyare) added to enable pp>1 training for adapters
+                if param.grad is not None:  # (@adithyare) added to enable pp>1 training for adapters
                     params.append(param)
         return params

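As a usage sketch, here is the same filter applied to a plain torch optimizer, with torch.nn.utils.clip_grad_norm_ standing in for NeMo's clip_grad_norm_fp32; the function and model names below are illustrative, not from the repo:

import torch

def get_parameters_with_grad(optimizer):
    # Same filter as the renamed method in the diff, on a plain optimizer.
    return [
        param
        for param_group in optimizer.param_groups
        for param in param_group['params']
        if param.grad is not None
    ]

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(3, 8)).sum().backward()

# clip_grad_norm_ here plays the role of clip_grad_norm_fp32 in the diff.
grad_norm = torch.nn.utils.clip_grad_norm_(
    get_parameters_with_grad(optimizer), max_norm=1.0
)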
