Skip to content

Commit

Permalink
reduce the size of the LSTM to try avoiding OOM
Browse files · Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Jul 13, 2023
1 parent 09af3d1 commit 536e0de
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions delft/sequenceLabelling/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, loc
embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
embedding_layer = Dropout(0.1)(embedding_layer)

bid_lstm = Bidirectional(LSTM(units=embedding_layer.shape[-1],
bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units, #embedding_layer.shape[-1],
return_sequences=True,
recurrent_dropout=config.recurrent_dropout))(embedding_layer)
bid_lstm = Dropout(config.dropout)(bid_lstm)
Expand Down Expand Up @@ -1239,7 +1239,7 @@ def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, loc
embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
embedding_layer = Dropout(0.1)(embedding_layer)

bid_lstm = Bidirectional(LSTM(units=embedding_layer.shape[-1],
bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units, #embedding_layer.shape[-1],
return_sequences=True,
recurrent_dropout=config.recurrent_dropout))(embedding_layer)
bid_lstm = Dropout(config.dropout)(bid_lstm)
Expand Down Expand Up @@ -1276,7 +1276,7 @@ def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, loc
embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
embedding_layer = Dropout(0.1)(embedding_layer)

bid_lstm = Bidirectional(LSTM(units=embedding_layer.shape[-1],
bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units, #embedding_layer.shape[-1],
return_sequences=True,
recurrent_dropout=config.recurrent_dropout))(embedding_layer)
bid_lstm = Dropout(config.dropout)(bid_lstm)
Expand Down

0 comments on commit 536e0de

Please sign in to comment.