diff --git a/nemo/collections/nlp/modules/common/megatron/kerple_relative_position_embedding.py b/nemo/collections/nlp/modules/common/megatron/kerple_relative_position_embedding.py
index b156429b4377..54276d6fa21e 100644
--- a/nemo/collections/nlp/modules/common/megatron/kerple_relative_position_embedding.py
+++ b/nemo/collections/nlp/modules/common/megatron/kerple_relative_position_embedding.py
@@ -18,8 +18,8 @@
 import torch
 
 from nemo.collections.nlp.modules.common.megatron.alibi_relative_position_embedding import (
-    build_slopes,
     build_relative_position,
+    build_slopes,
 )
 
 __all__ = ['KERPLERelativePositionEmbedding']
@@ -67,7 +67,7 @@ def __init__(
         self.kerple_b = torch.nn.Parameter(build_slopes(num_attention_heads, num_attention_heads_kerple))
         self.kerple_a = torch.zeros_like(self.kerple_b)
         self.kerple_p = torch.ones_like(self.kerple_b)
-        
+
         # cache the relative position bias. shape (num_attention_heads, max_seq_len, max_seq_len)
         self.relative_position = build_relative_position(max_seq_len, max_seq_len, num_attention_heads)
 
@@ -85,4 +85,4 @@ def forward(self, query_seq_length, key_seq_length):
            relative_position = torch.tril(relative_position)
 
        # shape (1, num_heads, query_length, key_length)
-       return - self.kerple_b * torch.log(1 + self.kerple_a * relative_position.unsqueeze(0).pow(self.kerple_p))
+       return -self.kerple_b * torch.log(1 + self.kerple_a * relative_position.unsqueeze(0).pow(self.kerple_p))
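
All three hunks are formatting-only: the import is reordered, a trailing-whitespace line is cleaned up, and the space after the unary minus in the `return` is dropped. The returned KERPLE bias itself is unchanged. For reference, the sketch below illustrates what that `return` expression evaluates to. It is a minimal standalone sketch, not NeMo's implementation: the tensor shapes, the |i - j| relative-position construction, and the example sizes are assumptions made for illustration; only the bias formula and the parameter initializations mirror the diff.

```python
import torch

def kerple_log_bias(kerple_a, kerple_b, kerple_p, relative_position):
    # relative_position: (num_heads, q_len, k_len) non-negative distances |i - j|
    # returns: (1, num_heads, q_len, k_len) bias to be added to attention logits,
    # matching the expression in the changed `return` line above
    return -kerple_b * torch.log(1 + kerple_a * relative_position.unsqueeze(0).pow(kerple_p))

# Illustrative sizes (assumptions, not taken from the NeMo config)
num_heads, q_len, k_len = 4, 8, 8

kerple_b = torch.rand(num_heads, 1, 1)   # stands in for build_slopes(...) in __init__
kerple_a = torch.zeros_like(kerple_b)    # initialized to zero, as in the diff
kerple_p = torch.ones_like(kerple_b)     # initialized to one, as in the diff

# Simple |i - j| relative positions, expanded across heads (assumed construction)
rel_pos = (torch.arange(q_len)[:, None] - torch.arange(k_len)[None, :]).abs()
rel_pos = rel_pos.expand(num_heads, q_len, k_len).float()

bias = kerple_log_bias(kerple_a, kerple_b, kerple_p, rel_pos)
print(bias.shape)  # torch.Size([1, 4, 8, 8])
```

Note that with `kerple_a` initialized to zero as shown in the diff, `log(1 + 0) = 0`, so the bias is identically zero until `kerple_a` moves away from zero.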