Commit

Fix decoding bug for megatron enc-dec models with O2 (NVIDIA#4989) (NVIDIA#5031)

* Change inspect

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix cross attention RPE

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
3 people authored and Hainan Xu committed Nov 29, 2022
1 parent 05a26ca commit 7ddecea
Showing 2 changed files with 3 additions and 2 deletions.
@@ -480,7 +480,8 @@ def _kwargs_to_arg_idx(self):
         Computed on first call, and then cached.
         """
         # build mapping of kwargs to arg index at first run
-        args_name = inspect.getfullargspec(self.enc_dec_model.forward)[0][1:]
+        module = self.enc_dec_model.forward if not self.megatron_amp_o2 else self.enc_dec_model.module.forward
+        args_name = inspect.getfullargspec(module)[0][1:]
         kwargs_to_arg_idx = {k: v for k, v in zip(args_name, range(len(args_name)))}

         return kwargs_to_arg_idx
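
The first hunk switches the introspection target from the model's own `forward` to `module.forward` when `megatron_amp_o2` is set. The diff does not show the O2 wrapper, so the sketch below rests on an assumption: that the wrapper's `forward` accepts only generic `*inputs, **kwargs` and delegates to `.module`. Under that assumption, inspecting the wrapper yields no named arguments, while inspecting the wrapped module recovers them. `EncDecModel`, `HalfPrecisionWrapper`, and the argument names are hypothetical stand-ins, not NeMo classes.

```python
import inspect


class EncDecModel:
    """Hypothetical stand-in for the underlying encoder-decoder model."""

    def forward(self, enc_input_ids, enc_attn_mask, dec_input_ids, dec_attn_mask):
        return None


class HalfPrecisionWrapper:
    """Hypothetical O2-style wrapper (assumption: generic forward signature)."""

    def __init__(self, module):
        self.module = module

    def forward(self, *inputs, **kwargs):
        # Named arguments of the wrapped model are not visible on this signature.
        return self.module.forward(*inputs, **kwargs)


def kwargs_to_arg_idx(model, megatron_amp_o2):
    """Map keyword names to positional indices, mirroring the patched logic."""
    forward = model.module.forward if megatron_amp_o2 else model.forward
    # getfullargspec keeps the bound `self` in the list, hence the [1:] slice.
    args_name = inspect.getfullargspec(forward)[0][1:]
    return {name: idx for idx, name in enumerate(args_name)}


wrapped = HalfPrecisionWrapper(EncDecModel())

# Inspecting the wrapper directly finds no named arguments at all.
names_on_wrapper = inspect.getfullargspec(wrapped.forward)[0][1:]

# Inspecting the wrapped module recovers the real argument order.
mapping = kwargs_to_arg_idx(wrapped, megatron_amp_o2=True)
```

With an empty mapping, any caller that converts keyword arguments to positional ones by looking up their index has nothing to look up, which would explain why the bug surfaces only when O2 is enabled.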
@@ -472,7 +472,7 @@ def forward(
                 )
             if not self.decoder_cfg.relative_position_bias_self_attention_only:
                 decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding(
-                    query_seq_length=dec_input_ids.size(1), key_seq_length=enc_input_ids.size(1),
+                    query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length,
                 )
             else:
                 decoder_cross_attention_relative_position_bias = None
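
The second hunk takes the cross-attention key length from `enc_seq_length` instead of `enc_input_ids.size(1)`. The diff does not show how `enc_seq_length` is computed, so the following is a sketch of one plausible motivation: during decoding the encoder may already have been run, leaving only its cached output and no `enc_input_ids`. The helper names, the batch-first nested-list layout, and the offset convention are all assumptions made for illustration.

```python
def resolve_enc_seq_length(enc_input_ids=None, enc_output=None):
    """Hypothetical helper: take the encoder length from whichever input exists.

    Assumes batch-first layouts: ids are [batch, seq], outputs are [batch, seq, hidden].
    """
    if enc_input_ids is not None:
        return len(enc_input_ids[0])
    if enc_output is not None:
        return len(enc_output[0])
    raise ValueError("need either enc_input_ids or enc_output")


def relative_positions(query_seq_length, key_seq_length):
    """Return a [query, key] grid of (key index - query index) offsets."""
    return [
        [k - q for k in range(key_seq_length)]
        for q in range(query_seq_length)
    ]


# Decoding step: encoder ids are gone, only a cached output of length 5 remains.
cached_enc_output = [[[0.0] * 4 for _ in range(5)]]  # batch=1, seq=5, hidden=4
enc_seq_length = resolve_enc_seq_length(enc_input_ids=None, enc_output=cached_enc_output)

# Two decoder positions attending over five encoder positions.
grid = relative_positions(query_seq_length=2, key_seq_length=enc_seq_length)
```

A relative position embedding would then look up a bias for each offset in the grid; the point here is only that the key length must be obtainable without touching `enc_input_ids`.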
