diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index 183345ef5851..3871de40e8c7 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -480,7 +480,8 @@ def _kwargs_to_arg_idx(self):
         Computed on first call, and then cached.
         """
         # build mapping of kwargs to arg index at first run
-        args_name = inspect.getfullargspec(self.enc_dec_model.forward)[0][1:]
+        module = self.enc_dec_model.forward if not self.megatron_amp_o2 else self.enc_dec_model.module.forward
+        args_name = inspect.getfullargspec(module)[0][1:]
         kwargs_to_arg_idx = {k: v for k, v in zip(args_name, range(len(args_name)))}

         return kwargs_to_arg_idx
diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
index b0e788d0b77a..d0d7e2119da8 100644
--- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
+++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
@@ -472,7 +472,7 @@ def forward(
             )
             if not self.decoder_cfg.relative_position_bias_self_attention_only:
                 decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding(
-                    query_seq_length=dec_input_ids.size(1), key_seq_length=enc_input_ids.size(1),
+                    query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length,
                 )
             else:
                 decoder_cross_attention_relative_position_bias = None
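
A minimal sketch (not NeMo code) of why the first hunk unwraps the model before introspection: with `megatron_amp_o2` enabled the model is wrapped in a float16 container whose `forward(*args, **kwargs)` has a generic signature, so `inspect.getfullargspec` must be pointed at the inner module's `forward`. The class and argument names below are illustrative assumptions.

```python
import inspect


class InnerModel:
    # Stand-in for the real enc-dec module with a concrete forward signature.
    def forward(self, enc_input_ids, dec_input_ids, enc_attn_mask):
        ...


class Float16Wrapper:
    # Stand-in for the AMP O2 wrapper; exposes the wrapped model as `.module`.
    def __init__(self, module):
        self.module = module

    def forward(self, *args, **kwargs):
        # Generic signature: getfullargspec() here yields no useful arg names.
        return self.module.forward(*args, **kwargs)


model = Float16Wrapper(InnerModel())
megatron_amp_o2 = True

# Mirrors the fix: pick the unwrapped forward when the O2 wrapper is in use.
fwd = model.module.forward if megatron_amp_o2 else model.forward
args_name = inspect.getfullargspec(fwd)[0][1:]  # drop `self`
kwargs_to_arg_idx = {k: i for i, k in enumerate(args_name)}
print(kwargs_to_arg_idx)  # {'enc_input_ids': 0, 'dec_input_ids': 1, 'enc_attn_mask': 2}
```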