diff --git a/delft/applications/grobidTagger.py b/delft/applications/grobidTagger.py
index 3c29651b..cfd7f355 100644
--- a/delft/applications/grobidTagger.py
+++ b/delft/applications/grobidTagger.py
@@ -349,7 +349,9 @@ class Tasks:
 word_embeddings_examples = ['glove-840B', 'fasttext-crawl', 'word2vec']
 architectures_transformers_based = [
-                    'BERT', 'BERT_FEATURES', 'BERT_CRF', 'BERT_ChainCRF', 'BERT_CRF_FEATURES', 'BERT_ChainCRF_FEATURES', 'BERT_CRF_CHAR', 'BERT_CRF_CHAR_FEATURES'
+                    'BERT', 'BERT_FEATURES', 'BERT_CRF', 'BERT_ChainCRF', 'BERT_CRF_FEATURES', 'BERT_ChainCRF_FEATURES',
+                    'BERT_CRF_CHAR', 'BERT_CRF_CHAR_FEATURES',
+                    'BERT_BidLSTM', 'BERT_BidLSTM_CRF', 'BERT_BidLSTM_ChainCRF'
                 ]
 
 architectures = architectures_word_embeddings + architectures_transformers_based
 
diff --git a/delft/sequenceLabelling/models.py b/delft/sequenceLabelling/models.py
index e54d34e7..f30f7f51 100644
--- a/delft/sequenceLabelling/models.py
+++ b/delft/sequenceLabelling/models.py
@@ -182,6 +182,30 @@ def get_model(config: ModelConfig, preprocessor, ntags=None, load_pretrained_wei
                                       load_pretrained_weights=load_pretrained_weights,
                                       local_path=local_path,
                                       preprocessor=preprocessor)
+    elif config.architecture == BERT_BidLSTM.name:
+        preprocessor.return_bert_embeddings = True
+        config.labels = preprocessor.vocab_tag
+        return BERT_BidLSTM(config,
+                            ntags,
+                            load_pretrained_weights=load_pretrained_weights,
+                            local_path=local_path,
+                            preprocessor=preprocessor)
+    elif config.architecture == BERT_BidLSTM_CRF.name:
+        preprocessor.return_bert_embeddings = True
+        config.labels = preprocessor.vocab_tag
+        return BERT_BidLSTM_CRF(config,
+                                ntags,
+                                load_pretrained_weights=load_pretrained_weights,
+                                local_path=local_path,
+                                preprocessor=preprocessor)
+    elif config.architecture == BERT_BidLSTM_ChainCRF.name:
+        preprocessor.return_bert_embeddings = True
+        config.labels = preprocessor.vocab_tag
+        return BERT_BidLSTM_ChainCRF(config,
+                                     ntags,
+                                     load_pretrained_weights=load_pretrained_weights,
+                                     local_path=local_path,
+                                     preprocessor=preprocessor)
     else:
         raise (OSError('Model name does exist: ' + config.architecture))
 
@@ -1026,7 +1050,7 @@ def __init__(self, config, ntags=None, load_pretrained_weights=True, local_path:
         self.crf = ChainCRF()
         pred = self.crf(x)
-        self.model = Model(inputs=[input_ids_in, features_input, token_type_ids, attention_mask], outputs=[x])
+        self.model = Model(inputs=[input_ids_in, features_input, token_type_ids, attention_mask], outputs=[pred])
         self.config = config
 
     def get_generator(self):
         return DataGeneratorTransformers
@@ -1158,3 +1182,106 @@ def __init__(self, config, ntags=None, load_pretrained_weights=True, local_path:
 
     def get_generator(self):
         return DataGeneratorTransformers
+
+class BERT_BidLSTM(BaseModel):
+    """
+    """
+
+    name = 'BERT_BidLSTM'
+
+    def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, local_path: str = None, preprocessor=None):
+        super().__init__(config, ntags, load_pretrained_weights, local_path)
+
+        transformer_layers = self.init_transformer(config, load_pretrained_weights, local_path, preprocessor)
+
+        input_ids_in = Input(shape=(None,), name='input_token', dtype='int32')
+        token_type_ids = Input(shape=(None,), name='input_token_type', dtype='int32')
+        attention_mask = Input(shape=(None,), name='input_attention_mask', dtype='int32')
+
+        #embedding_layer = transformer_model(input_ids_in, token_type_ids=token_type_ids)[0]
+        embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
+        embedding_layer = Dropout(0.1)(embedding_layer)
+
+        bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units,
+                                      return_sequences=True,
+                                      recurrent_dropout=config.recurrent_dropout))(embedding_layer)
+        bid_lstm = Dropout(config.dropout)(bid_lstm)
+
+        label_logits = Dense(ntags, activation='softmax')(bid_lstm)
+
+        self.model = Model(inputs=[input_ids_in, token_type_ids, attention_mask], outputs=[label_logits])
+        self.config = config
+
+    def get_generator(self):
+        return DataGeneratorTransformers
+
+
+class BERT_BidLSTM_CRF(BaseModel):
+    """
+
+    """
+
+    name = 'BERT_BidLSTM_CRF'
+
+    def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, local_path: str = None, preprocessor=None):
+        super().__init__(config, ntags, load_pretrained_weights, local_path)
+
+        transformer_layers = self.init_transformer(config, load_pretrained_weights, local_path, preprocessor)
+
+        input_ids_in = Input(shape=(None,), name='input_token', dtype='int32')
+        token_type_ids = Input(shape=(None,), name='input_token_type', dtype='int32')
+        attention_mask = Input(shape=(None,), name='input_attention_mask', dtype='int32')
+
+        #embedding_layer = transformer_model(input_ids_in, token_type_ids=token_type_ids)[0]
+        embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
+        embedding_layer = Dropout(0.1)(embedding_layer)
+
+        bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units,
+                                      return_sequences=True,
+                                      recurrent_dropout=config.recurrent_dropout))(embedding_layer)
+        bid_lstm = Dropout(config.dropout)(bid_lstm)
+
+        base_model = Model(inputs=[input_ids_in, token_type_ids, attention_mask], outputs=[bid_lstm])
+
+        self.model = CRFModelWrapperForBERT(base_model, ntags)
+        self.model.build(input_shape=[(None, None, ), (None, None, ), (None, None, )])
+        self.config = config
+
+    def get_generator(self):
+        return DataGeneratorTransformers
+
+
+class BERT_BidLSTM_ChainCRF(BaseModel):
+    """
+
+    """
+
+    name = 'BERT_BidLSTM_ChainCRF'
+
+    def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, local_path: str = None, preprocessor=None):
+        super().__init__(config, ntags, load_pretrained_weights, local_path)
+
+        transformer_layers = self.init_transformer(config, load_pretrained_weights, local_path, preprocessor)
+
+        input_ids_in = Input(shape=(None,), name='input_token', dtype='int32')
+        token_type_ids = Input(shape=(None,), name='input_token_type', dtype='int32')
+        attention_mask = Input(shape=(None,), name='input_attention_mask', dtype='int32')
+
+        #embedding_layer = transformer_model(input_ids_in, token_type_ids=token_type_ids)[0]
+        embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
+        embedding_layer = Dropout(0.1)(embedding_layer)
+
+        bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units,
+                                      return_sequences=True,
+                                      recurrent_dropout=config.recurrent_dropout))(embedding_layer)
+        bid_lstm = Dropout(config.dropout)(bid_lstm)
+
+
+        self.crf = ChainCRF()
+        pred = self.crf(bid_lstm)
+
+        self.model = Model(inputs=[input_ids_in, token_type_ids, attention_mask], outputs=[pred])
+        self.config = config
+
+    def get_generator(self):
+        return DataGeneratorTransformers
\ No newline at end of file
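
Reviewer note: the three new architectures plug into the existing transformer selection path in get_model(), so they can be exercised with the usual DeLFT entry points. Below is a minimal training sketch (not part of the patch) assuming the DeLFT Sequence wrapper and the CRF-format reader; the training-file path, the 'citation' model name and the Hugging Face checkpoint are placeholders, and the keyword names (architecture, transformer_name) follow current DeLFT and may differ in older versions.

    from delft.sequenceLabelling import Sequence
    from delft.sequenceLabelling.reader import load_data_and_labels_crf_file

    # Placeholder CRF-format training file; any GROBID-style training data works here.
    x_all, y_all, _ = load_data_and_labels_crf_file('data/sequenceLabelling/grobid/citation/citation.train')

    # Select one of the architectures added by this patch; the transformer
    # checkpoint is a placeholder and can be any compatible Hugging Face model.
    model = Sequence('citation',
                     architecture='BERT_BidLSTM_CRF',
                     transformer_name='bert-base-cased')
    model.train(x_all, y_all)
    model.save()

The same selection should also work from delft/applications/grobidTagger.py through its --architecture argument, since the patch registers the new names in architectures_transformers_based.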