diff --git a/delft/sequenceLabelling/models.py b/delft/sequenceLabelling/models.py
index fd3e23af..a631d783 100644
--- a/delft/sequenceLabelling/models.py
+++ b/delft/sequenceLabelling/models.py
@@ -1205,7 +1205,7 @@ def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, loc
         embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
         embedding_layer = Dropout(0.1)(embedding_layer)
 
-        bid_lstm = Bidirectional(LSTM(units=embedding_layer.shape[-1],
+        bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units, #embedding_layer.shape[-1],
                                       return_sequences=True,
                                       recurrent_dropout=config.recurrent_dropout))(embedding_layer)
         bid_lstm = Dropout(config.dropout)(bid_lstm)
@@ -1239,7 +1239,7 @@ def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, loc
         embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
         embedding_layer = Dropout(0.1)(embedding_layer)
 
-        bid_lstm = Bidirectional(LSTM(units=embedding_layer.shape[-1],
+        bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units, #embedding_layer.shape[-1],
                                       return_sequences=True,
                                       recurrent_dropout=config.recurrent_dropout))(embedding_layer)
         bid_lstm = Dropout(config.dropout)(bid_lstm)
@@ -1276,7 +1276,7 @@ def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, loc
         embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
         embedding_layer = Dropout(0.1)(embedding_layer)
 
-        bid_lstm = Bidirectional(LSTM(units=embedding_layer.shape[-1],
+        bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units, #embedding_layer.shape[-1],
                                       return_sequences=True,
                                       recurrent_dropout=config.recurrent_dropout))(embedding_layer)
         bid_lstm = Dropout(config.dropout)(bid_lstm)
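
Note for reviewers: all three hunks apply the same change. The BiLSTM width is now taken from config.num_word_lstm_units rather than being tied to the transformer hidden size (embedding_layer.shape[-1], e.g. 768 for BERT base), so the recurrent layer's capacity can be tuned independently of the encoder. Below is a minimal standalone sketch of the resulting pattern, not the DeLFT build itself; EMBED_DIM, LSTM_UNITS, MAX_LEN, and NTAGS are illustrative stand-ins for the config values, and a plain Input tensor stands in for the transformer output.

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, LSTM, Bidirectional
from tensorflow.keras.models import Model

EMBED_DIM = 768    # transformer hidden size (assumed, e.g. BERT base)
LSTM_UNITS = 100   # stand-in for config.num_word_lstm_units
MAX_LEN = 128      # stand-in for the padded sequence length
NTAGS = 9          # stand-in for the number of sequence labels

# Stand-in for the transformer output: (batch, seq_len, hidden).
embeddings_in = Input(shape=(MAX_LEN, EMBED_DIM), name="transformer_embeddings")
x = Dropout(0.1)(embeddings_in)

# The BiLSTM width comes from the config value, not the embedding width;
# output shape is (batch, seq_len, 2 * LSTM_UNITS) instead of
# (batch, seq_len, 2 * EMBED_DIM) as before the change.
x = Bidirectional(LSTM(units=LSTM_UNITS,
                       return_sequences=True,
                       recurrent_dropout=0.5))(x)
x = Dropout(0.5)(x)
logits = Dense(NTAGS, activation="softmax")(x)

model = Model(embeddings_in, logits)
model.summary()

With the old code, a 768-dim encoder forced a 1536-dim BiLSTM output; with the change, a modest value such as 100 units cuts the recurrent parameter count substantially while keeping the layer's width an explicit, tunable hyperparameter.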