fix mobilebert
zheyuye committed Jul 2, 2020
1 parent 4bc3a96 commit a84f782
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions src/gluonnlp/models/mobilebert.py
@@ -211,13 +211,11 @@ def __init__(self,
             is_last_ffn = (ffn_idx == (num_stacked_ffn - 1))
             # only apply dropout on last ffn layer if use bottleneck
             dropout = float(hidden_dropout_prob * (not use_bottleneck) * is_last_ffn)
-            activation_dropout = float(activation_dropout_prob * (not use_bottleneck)
-                                       * is_last_ffn)
             self.stacked_ffn.add(
                 PositionwiseFFN(units=real_units,
                                 hidden_size=hidden_size,
                                 dropout=dropout,
-                                activation_dropout=activation_dropout,
+                                activation_dropout=activation_dropout_prob,
                                 weight_initializer=weight_initializer,
                                 bias_initializer=bias_initializer,
                                 activation=activation,
@@ -343,7 +341,6 @@ def __init__(self,
                     num_heads=num_heads,
                     attention_dropout_prob=attention_dropout_prob,
                     hidden_dropout_prob=hidden_dropout_prob,
-                    activation_dropout_prob=hidden_dropout_prob,
                     num_stacked_ffn=num_stacked_ffn,
                     bottleneck_strategy=bottleneck_strategy,
                     layer_norm_eps=layer_norm_eps,
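
Taken together, the two hunks above stop suppressing activation dropout in the stacked feed-forward blocks: PositionwiseFFN now receives activation_dropout_prob directly, and the encoder no longer overrides that probability with hidden_dropout_prob. A minimal sketch of the effective per-FFN values before and after this commit, using illustrative probabilities that are not part of the diff:

    # Sketch only: reproduces the dropout arithmetic from the hunks above.
    num_stacked_ffn = 4
    use_bottleneck = True
    hidden_dropout_prob = 0.1
    activation_dropout_prob = 0.1

    for ffn_idx in range(num_stacked_ffn):
        is_last_ffn = (ffn_idx == (num_stacked_ffn - 1))
        dropout = float(hidden_dropout_prob * (not use_bottleneck) * is_last_ffn)
        # Before: zero for every FFN whenever use_bottleneck is True.
        old_activation_dropout = float(activation_dropout_prob
                                       * (not use_bottleneck) * is_last_ffn)
        # After: the configured probability, applied to every stacked FFN.
        new_activation_dropout = activation_dropout_prob
        print(ffn_idx, dropout, old_activation_dropout, new_activation_dropout)
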
@@ -540,7 +537,7 @@ def hybrid_forward(self, F, inputs, token_types, valid_length):
         pooled_output :
             This is optional. Shape (batch_size, units)
         """
-        embedding = self.get_initial_embedding(F, inputs, token_types)
+        embedding = self.get_initial_embedding(F, inputs, token_types, self.trigram_embed)
 
         contextual_embeddings, additional_outputs = self.encoder(embedding, valid_length)
         outputs = []
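
The one-line fix above threads self.trigram_embed through to get_initial_embedding, which previously ignored the setting. For intuition, MobileBERT's trigram embedding widens each token embedding by concatenating it with its left and right neighbors before projecting back to the model width; the NumPy sketch below shows shapes only, with zero padding at the sequence boundaries, and its concatenation order is an assumption rather than a transcription of the library code:

    import numpy as np

    batch, seq_len, embed = 2, 8, 128
    tok = np.random.randn(batch, seq_len, embed)

    # Left/right neighbors, zero-padded at the sequence boundaries.
    left = np.concatenate([np.zeros((batch, 1, embed)), tok[:, :-1]], axis=1)
    right = np.concatenate([tok[:, 1:], np.zeros((batch, 1, embed))], axis=1)

    trigram = np.concatenate([left, tok, right], axis=-1)
    print(trigram.shape)  # (2, 8, 384); a dense layer then maps back to the model width
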
@@ -961,8 +958,6 @@ def get_pretrained_mobilebert(model_name: str = 'google_uncased_mobilebert',
                                              sha1_hash=FILE_STATS[mlm_params_path])
     else:
         local_mlm_params_path = None
-    do_lower = True if 'lowercase' in PRETRAINED_URL[model_name]\
-        and PRETRAINED_URL[model_name]['lowercase'] else False
     # TODO(sxjscience) Move do_lower to assets.
     tokenizer = HuggingFaceWordPieceTokenizer(
         vocab_file=local_paths['vocab'],
@@ -971,7 +966,7 @@ def get_pretrained_mobilebert(model_name: str = 'google_uncased_mobilebert',
         cls_token='[CLS]',
         sep_token='[SEP]',
         mask_token='[MASK]',
-        lowercase=do_lower)
+        lowercase=True)
     cfg = MobileBertModel.get_cfg().clone_merge(local_paths['cfg'])
     return cfg, tokenizer, local_params_path, local_mlm_params_path
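
For context, a minimal loading sketch for the function touched above. The import path and default model name come from this file; the snippet itself is not part of the commit, assumes the pretrained assets can be downloaded, and uses the encode method of GluonNLP's tokenizer interface:

    from gluonnlp.models.mobilebert import get_pretrained_mobilebert

    # After this commit the returned WordPiece tokenizer is always built
    # with lowercase=True instead of consulting PRETRAINED_URL metadata.
    cfg, tokenizer, params_path, mlm_params_path = get_pretrained_mobilebert(
        'google_uncased_mobilebert')
    print(tokenizer.encode('Hello MobileBERT!'))
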

