Commit d238ac5

add BERT+BidLSTM and BERT+BidLSTM+CRF base models
lfoppiano committed Jul 12, 2023
1 parent 30e87a6 commit d238ac5
Showing 2 changed files with 131 additions and 2 deletions.
4 changes: 3 additions & 1 deletion delft/applications/grobidTagger.py
@@ -349,7 +349,9 @@ class Tasks:
 word_embeddings_examples = ['glove-840B', 'fasttext-crawl', 'word2vec']
 
 architectures_transformers_based = [
-    'BERT', 'BERT_FEATURES', 'BERT_CRF', 'BERT_ChainCRF', 'BERT_CRF_FEATURES', 'BERT_ChainCRF_FEATURES', 'BERT_CRF_CHAR', 'BERT_CRF_CHAR_FEATURES'
+    'BERT', 'BERT_FEATURES', 'BERT_CRF', 'BERT_ChainCRF', 'BERT_CRF_FEATURES', 'BERT_ChainCRF_FEATURES',
+    'BERT_CRF_CHAR', 'BERT_CRF_CHAR_FEATURES',
+    'BERT_BidLSTM', 'BERT_BidLSTM_CRF', 'BERT_BidLSTM_ChainCRF'
 ]
 
 architectures = architectures_word_embeddings + architectures_transformers_based
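
With the registration above, the three new names become valid values for grobidTagger's architecture selection. A minimal invocation sketch, assuming grobidTagger's usual positional arguments and flags (the model name 'citation', the action 'train_eval', and the '--architecture'/'--transformer' flags are illustrative assumptions, not shown in this diff):

import subprocess

# train and evaluate a grobid 'citation' model with the new BERT+BidLSTM+CRF head
subprocess.run([
    "python3", "delft/applications/grobidTagger.py",
    "citation", "train_eval",
    "--architecture", "BERT_BidLSTM_CRF",
    "--transformer", "bert-base-cased",
], check=True)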
129 changes: 128 additions & 1 deletion delft/sequenceLabelling/models.py
@@ -182,6 +182,30 @@ def get_model(config: ModelConfig, preprocessor, ntags=None, load_pretrained_wei
                                      load_pretrained_weights=load_pretrained_weights,
                                      local_path=local_path,
                                      preprocessor=preprocessor)
+    elif config.architecture == BERT_BidLSTM.name:
+        preprocessor.return_bert_embeddings = True
+        config.labels = preprocessor.vocab_tag
+        return BERT_BidLSTM(config,
+                            ntags,
+                            load_pretrained_weights=load_pretrained_weights,
+                            local_path=local_path,
+                            preprocessor=preprocessor)
+    elif config.architecture == BERT_BidLSTM_CRF.name:
+        preprocessor.return_bert_embeddings = True
+        config.labels = preprocessor.vocab_tag
+        return BERT_BidLSTM_CRF(config,
+                                ntags,
+                                load_pretrained_weights=load_pretrained_weights,
+                                local_path=local_path,
+                                preprocessor=preprocessor)
+    elif config.architecture == BERT_BidLSTM_ChainCRF.name:
+        preprocessor.return_bert_embeddings = True
+        config.labels = preprocessor.vocab_tag
+        return BERT_BidLSTM_ChainCRF(config,
+                                     ntags,
+                                     load_pretrained_weights=load_pretrained_weights,
+                                     local_path=local_path,
+                                     preprocessor=preprocessor)
     else:
         raise OSError('Model name does not exist: ' + config.architecture)
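
The dispatch above means the new models are reachable purely by architecture name. A minimal sketch of a direct call, assuming a ModelConfig and a fitted preprocessor already prepared as elsewhere in DeLFT (that setup is not part of this commit):

from delft.sequenceLabelling.models import get_model

# config (a ModelConfig) and preprocessor are assumed to exist already;
# get_model switches on config.architecture, so selecting a new model is just:
config.architecture = 'BERT_BidLSTM'  # or 'BERT_BidLSTM_CRF' / 'BERT_BidLSTM_ChainCRF'
model = get_model(config, preprocessor, ntags=len(preprocessor.vocab_tag))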

@@ -1026,7 +1050,7 @@ def __init__(self, config, ntags=None, load_pretrained_weights=True, local_path:
         self.crf = ChainCRF()
         pred = self.crf(x)
 
-        self.model = Model(inputs=[input_ids_in, features_input, token_type_ids, attention_mask], outputs=[x])
+        self.model = Model(inputs=[input_ids_in, features_input, token_type_ids, attention_mask], outputs=[pred])
         self.config = config
 
     def get_generator(self):
@@ -1158,3 +1182,106 @@ def __init__(self, config, ntags=None, load_pretrained_weights=True, local_path:

    def get_generator(self):
        return DataGeneratorTransformers


class BERT_BidLSTM(BaseModel):
    """
    Transformer layers (BERT-like) followed by a bidirectional LSTM and a dense softmax output layer.
    """

    name = 'BERT_BidLSTM'

    def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, local_path: str = None, preprocessor=None):
        super().__init__(config, ntags, load_pretrained_weights, local_path)

        transformer_layers = self.init_transformer(config, load_pretrained_weights, local_path, preprocessor)

        input_ids_in = Input(shape=(None,), name='input_token', dtype='int32')
        token_type_ids = Input(shape=(None,), name='input_token_type', dtype='int32')
        attention_mask = Input(shape=(None,), name='input_attention_mask', dtype='int32')

        # last hidden state of the transformer, one vector per (sub)token
        embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
        embedding_layer = Dropout(0.1)(embedding_layer)

        bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units,
                                      return_sequences=True,
                                      recurrent_dropout=config.recurrent_dropout))(embedding_layer)
        bid_lstm = Dropout(config.dropout)(bid_lstm)

        label_logits = Dense(ntags, activation='softmax')(bid_lstm)

        self.model = Model(inputs=[input_ids_in, token_type_ids, attention_mask], outputs=[label_logits])
        self.config = config

    def get_generator(self):
        return DataGeneratorTransformers


class BERT_BidLSTM_CRF(BaseModel):
    """
    Transformer layers (BERT-like) followed by a bidirectional LSTM, with CRF decoding provided by CRFModelWrapperForBERT.
    """

    name = 'BERT_BidLSTM_CRF'

    def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, local_path: str = None, preprocessor=None):
        super().__init__(config, ntags, load_pretrained_weights, local_path)

        transformer_layers = self.init_transformer(config, load_pretrained_weights, local_path, preprocessor)

        input_ids_in = Input(shape=(None,), name='input_token', dtype='int32')
        token_type_ids = Input(shape=(None,), name='input_token_type', dtype='int32')
        attention_mask = Input(shape=(None,), name='input_attention_mask', dtype='int32')

        # last hidden state of the transformer, one vector per (sub)token
        embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
        embedding_layer = Dropout(0.1)(embedding_layer)

        bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units,
                                      return_sequences=True,
                                      recurrent_dropout=config.recurrent_dropout))(embedding_layer)
        bid_lstm = Dropout(config.dropout)(bid_lstm)

        base_model = Model(inputs=[input_ids_in, token_type_ids, attention_mask], outputs=[bid_lstm])

        # the wrapper adds the CRF decoding layer on top of the base model
        self.model = CRFModelWrapperForBERT(base_model, ntags)
        # one (batch, sequence) shape per input: token ids, token type ids, attention mask
        self.model.build(input_shape=[(None, None, ), (None, None, ), (None, None, )])
        self.config = config

    def get_generator(self):
        return DataGeneratorTransformers


class BERT_BidLSTM_ChainCRF(BaseModel):
    """
    Transformer layers (BERT-like) followed by a bidirectional LSTM and a ChainCRF decoding layer.
    """

    name = 'BERT_BidLSTM_ChainCRF'

    def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, local_path: str = None, preprocessor=None):
        super().__init__(config, ntags, load_pretrained_weights, local_path)

        transformer_layers = self.init_transformer(config, load_pretrained_weights, local_path, preprocessor)

        input_ids_in = Input(shape=(None,), name='input_token', dtype='int32')
        token_type_ids = Input(shape=(None,), name='input_token_type', dtype='int32')
        attention_mask = Input(shape=(None,), name='input_attention_mask', dtype='int32')

        # last hidden state of the transformer, one vector per (sub)token
        embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
        embedding_layer = Dropout(0.1)(embedding_layer)

        bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units,
                                      return_sequences=True,
                                      recurrent_dropout=config.recurrent_dropout))(embedding_layer)
        bid_lstm = Dropout(config.dropout)(bid_lstm)

        # project the LSTM output onto the label space: ChainCRF expects one score per tag
        label_scores = Dense(ntags)(bid_lstm)

        # keep a reference to the CRF layer: its loss function is needed at compile time
        self.crf = ChainCRF()
        pred = self.crf(label_scores)

        self.model = Model(inputs=[input_ids_in, token_type_ids, attention_mask], outputs=[pred])
        self.config = config

    def get_generator(self):
        return DataGeneratorTransformers
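
The three heads train differently: BERT_BidLSTM is an ordinary softmax classifier, BERT_BidLSTM_CRF delegates CRF decoding and loss to CRFModelWrapperForBERT, and BERT_BidLSTM_ChainCRF keeps its CRF layer on self.crf so the model can be compiled with the layer's own loss. A minimal compile sketch, assuming the keras-contrib-style ChainCRF interface (crf.loss / crf.sparse_loss) used by the other ChainCRF architectures in DeLFT:

# config and preprocessor assumed prepared as above; model.crf.loss is an
# assumption based on the keras-contrib ChainCRF interface
model = BERT_BidLSTM_ChainCRF(config, ntags=len(preprocessor.vocab_tag),
                              preprocessor=preprocessor)
model.model.compile(optimizer='adam', loss=model.crf.loss)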
