From 3b4f37af4626130fc4c9c5c09671a209d6e284c5 Mon Sep 17 00:00:00 2001
From: Sara Rabhi
Date: Wed, 28 Jun 2023 17:19:01 -0400
Subject: [PATCH] Enable `rpe` methods in bert-like models (#6898)

* enable rpe in bert model

Signed-off-by: sararb

* expose position_embedding_type to config

Signed-off-by: sararb

---------

Signed-off-by: sararb
---
 examples/nlp/language_modeling/conf/megatron_bert_config.yaml | 1 +
 .../nlp/models/language_modeling/megatron/bert_model.py       | 2 ++
 .../nlp/models/language_modeling/megatron_bert_model.py       | 1 +
 3 files changed, 4 insertions(+)

diff --git a/examples/nlp/language_modeling/conf/megatron_bert_config.yaml b/examples/nlp/language_modeling/conf/megatron_bert_config.yaml
index a7e3364d41b4..4e53ded4a453 100644
--- a/examples/nlp/language_modeling/conf/megatron_bert_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_bert_config.yaml
@@ -50,6 +50,7 @@ model:
   # model architecture
   encoder_seq_length: 512
   max_position_embeddings: ${.encoder_seq_length}
+  position_embedding_type: 'learned_absolute' # Position embedding type. Options: ['learned_absolute', 'rope', 'alibi', 'kerple', 'xpos', 'sandwich']; xpos and sandwich are experimental.
   num_layers: 12
   hidden_size: 768
   ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
diff --git a/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py
index 132f900298a6..cbbef2d56a15 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py
@@ -188,6 +188,7 @@ def __init__(
         add_binary_head=True,
         megatron_legacy=False,
         sequence_parallel=False,
+        position_embedding_type='learned_absolute',
     ):
         super(BertModel, self).__init__()
         # args = get_args()
@@ -234,6 +235,7 @@ def __init__(
             onnx_safe=onnx_safe,
             megatron_legacy=megatron_legacy,
             sequence_parallel=sequence_parallel,
+            position_embedding_type=position_embedding_type,
         )

         self.initialize_word_embeddings(
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
index cac1a50e98ae..ab0459b2966c 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
@@ -182,6 +182,7 @@ def model_provider_func(self, pre_process, post_process):
             add_binary_head=cfg.bert_binary_head,
             megatron_legacy=cfg.get('megatron_legacy', False),
             sequence_parallel=self.cfg.get('sequence_parallel', False),
+            position_embedding_type=self.cfg.get("position_embedding_type", "learned_absolute"),
         )

         return model
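
Usage note (not part of the patch): a minimal sketch of how the new key would be set once this patch is applied, assuming the stock megatron_bert_config.yaml shown above. Selecting 'alibi' here is just one of the options named in the config comment; every other value is taken verbatim from the diff:

    model:
      # model architecture
      encoder_seq_length: 512
      max_position_embeddings: ${.encoder_seq_length}
      position_embedding_type: 'alibi'  # replaces the default 'learned_absolute' with ALiBi relative embeddings

Equivalently, under NeMo's usual Hydra-style config handling the key can be overridden at launch time (e.g. `model.position_embedding_type=alibi` on the pretraining script's command line). If the key is omitted, `self.cfg.get("position_embedding_type", "learned_absolute")` in megatron_bert_model.py falls back to the previous behavior, so existing configs keep working unchanged.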