From 7ddeceabde0a351ca43c89f2b91ec0066216f7ba Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Wed, 28 Sep 2022 15:02:41 -0700
Subject: [PATCH] Fix decoding bug for megatron enc-dec models with O2 (#4989) (#5031)

* Change inspect

Signed-off-by: MaximumEntropy

* Fix cross attention RPE

Signed-off-by: MaximumEntropy

Signed-off-by: MaximumEntropy
Co-authored-by: Oleksii Kuchaiev

Signed-off-by: MaximumEntropy
Co-authored-by: Sandeep Subramanian
Co-authored-by: Oleksii Kuchaiev
Signed-off-by: Hainan Xu
---
 .../language_modeling/megatron_lm_encoder_decoder_model.py     | 3 ++-
 .../nlp/modules/common/megatron/token_level_encoder_decoder.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index 183345ef5851..3871de40e8c7 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -480,7 +480,8 @@ def _kwargs_to_arg_idx(self):
         Computed on first call, and then cached.
         """
         # build mapping of kwargs to arg index at first run
-        args_name = inspect.getfullargspec(self.enc_dec_model.forward)[0][1:]
+        module = self.enc_dec_model.forward if not self.megatron_amp_o2 else self.enc_dec_model.module.forward
+        args_name = inspect.getfullargspec(module)[0][1:]
         kwargs_to_arg_idx = {k: v for k, v in zip(args_name, range(len(args_name)))}

         return kwargs_to_arg_idx

diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
index b0e788d0b77a..d0d7e2119da8 100644
--- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
+++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
@@ -472,7 +472,7 @@ def forward(
                 )
                 if not self.decoder_cfg.relative_position_bias_self_attention_only:
                     decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding(
-                        query_seq_length=dec_input_ids.size(1), key_seq_length=enc_input_ids.size(1),
+                        query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length,
                     )
                 else:
                     decoder_cross_attention_relative_position_bias = None
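
For context on the first hunk: with megatron_amp_o2 the encoder-decoder model is wrapped by an O2 wrapper (such as a Float16Module-style class) that keeps the original model under a .module attribute and declares its own forward as *inputs, **kwargs. Inspecting the wrapper's forward with inspect.getfullargspec therefore returns no named arguments, so the kwarg-to-positional-index map comes back empty and decoding breaks. A minimal, self-contained sketch of that behavior, using placeholder class and argument names rather than NeMo's actual ones:

import inspect


class EncDecModel:
    """Stand-in for the real encoder-decoder module; argument names are illustrative."""

    def forward(self, enc_input_ids, dec_input_ids, enc_attn_mask=None):
        return enc_input_ids, dec_input_ids, enc_attn_mask


class Float16Wrapper:
    """Stand-in for an AMP O2 wrapper that hides the wrapped signature behind *inputs/**kwargs."""

    def __init__(self, module):
        self.module = module

    def forward(self, *inputs, **kwargs):
        return self.module.forward(*inputs, **kwargs)


wrapped = Float16Wrapper(EncDecModel())

# Inspecting the wrapper's forward yields no named arguments after 'self' ...
print(inspect.getfullargspec(wrapped.forward)[0][1:])         # []
# ... while the inner module still exposes the real names, which is what the
# patch selects when megatron_amp_o2 is enabled.
print(inspect.getfullargspec(wrapped.module.forward)[0][1:])  # ['enc_input_ids', 'dec_input_ids', 'enc_attn_mask']

The second hunk switches key_seq_length to the precomputed enc_seq_length, presumably because during decoding only the cached encoder output may be supplied and enc_input_ids can be None, so enc_input_ids.size(1) is not a reliable way to obtain the encoder sequence length for the cross-attention relative position bias.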