Enable RPE methods in BERT-like models (#6898)
* Enable RPE in the BERT model

Signed-off-by: sararb <sara.rabhi@gmail.com>

* Expose position_embedding_type in the config

Signed-off-by: sararb <sara.rabhi@gmail.com>

---------

Signed-off-by: sararb <sara.rabhi@gmail.com>
sararb authored Jun 28, 2023
1 parent 7e20750 commit 3b4f37a
Showing 3 changed files with 4 additions and 0 deletions.
File 1 of 3 (BERT model config, YAML):
@@ -50,6 +50,7 @@ model:
# model architecture
encoder_seq_length: 512
max_position_embeddings: ${.encoder_seq_length}
+position_embedding_type: 'learned_absolute' # Position embedding type. Options: ['learned_absolute', 'rope', 'alibi', 'kerple', 'xpos', 'sandwich']. xpos and sandwich are experimental.
num_layers: 12
hidden_size: 768
ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
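As a quick illustration (not part of the commit), the new YAML knob could be flipped from Python with OmegaConf, the config library NeMo's Hydra configs build on. The filename and the validation step here are assumptions:

from omegaconf import OmegaConf

# Load the BERT training config (filename is an assumption).
cfg = OmegaConf.load("megatron_bert_config.yaml")

# Switch from the default absolute embeddings to rotary embeddings.
cfg.model.position_embedding_type = "rope"

# Hypothetical guard against typos before launching a run.
allowed = {"learned_absolute", "rope", "alibi", "kerple", "xpos", "sandwich"}
assert cfg.model.position_embedding_type in allowed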
File 2 of 3 (BertModel class, Python):
@@ -188,6 +188,7 @@ def __init__(
add_binary_head=True,
megatron_legacy=False,
sequence_parallel=False,
+position_embedding_type='learned_absolute',
):
super(BertModel, self).__init__()
# args = get_args()
@@ -234,6 +235,7 @@ def __init__(
onnx_safe=onnx_safe,
megatron_legacy=megatron_legacy,
sequence_parallel=sequence_parallel,
+position_embedding_type=position_embedding_type,
)

self.initialize_word_embeddings(
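The two hunks above follow a common pattern: a new keyword argument with a backward-compatible default is added to the outer model's __init__ and forwarded verbatim to the inner language model, so existing callers see no change. A standalone sketch of that pattern, with illustrative class names rather than NeMo's actual implementation:

class LanguageModel:
    def __init__(self, hidden_size, position_embedding_type="learned_absolute"):
        self.hidden_size = hidden_size
        self.position_embedding_type = position_embedding_type

class BertLikeModel:
    def __init__(self, hidden_size, position_embedding_type="learned_absolute"):
        # Forward the option unchanged, so callers that never pass it
        # get exactly the old behaviour.
        self.language_model = LanguageModel(
            hidden_size,
            position_embedding_type=position_embedding_type,
        )

model = BertLikeModel(hidden_size=768, position_embedding_type="alibi")
print(model.language_model.position_embedding_type)  # alibi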
File 3 of 3 (model provider, Python):
@@ -182,6 +182,7 @@ def model_provider_func(self, pre_process, post_process):
add_binary_head=cfg.bert_binary_head,
megatron_legacy=cfg.get('megatron_legacy', False),
sequence_parallel=self.cfg.get('sequence_parallel', False),
+position_embedding_type=self.cfg.get("position_embedding_type", "learned_absolute"),
)

return model
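The cfg.get call above is what keeps older configs working: a config written before this commit has no position_embedding_type key, and the lookup falls back to 'learned_absolute' instead of raising. A minimal sketch of that fallback using OmegaConf:

from omegaconf import OmegaConf

# An old-style config that predates the new key.
cfg = OmegaConf.create({"num_layers": 12})

# The missing key falls back to the previous default behaviour.
pet = cfg.get("position_embedding_type", "learned_absolute")
print(pet)  # learned_absolute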
