Commit
freeze bert and concatenate embeddings
lfoppiano committed Aug 9, 2023
1 parent 24dcfed commit 78c8054
Showing 2 changed files with 100 additions and 16 deletions.
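The change in models.py calls the transformer with training=False and concatenates its last four hidden layers before the BiLSTM. The following is a minimal standalone sketch of that pattern outside DeLFT, not the committed code: the Hugging Face checkpoint name ("bert-base-cased"), the LSTM size, dropout rate and number of labels are illustrative assumptions.

import tensorflow as tf
from tensorflow.keras.layers import Input, Concatenate, Bidirectional, LSTM, Dropout, Dense
from tensorflow.keras.models import Model
from transformers import TFAutoModel

# illustrative checkpoint; any TF-compatible encoder with hidden states exposed would do
transformer = TFAutoModel.from_pretrained("bert-base-cased", output_hidden_states=True)
transformer.trainable = False  # freeze the transformer weights

input_ids = Input(shape=(None,), dtype="int32", name="input_ids")
attention_mask = Input(shape=(None,), dtype="int32", name="attention_mask")

# hidden_states is a tuple of (num_layers + 1) tensors of shape (batch, seq_len, hidden)
hidden_states = transformer(input_ids, attention_mask=attention_mask, training=False).hidden_states
concatenated = Concatenate()(list(hidden_states[-4:]))  # (batch, seq_len, 4 * hidden)

x = Bidirectional(LSTM(100, return_sequences=True))(concatenated)  # 100 units: illustrative
x = Dropout(0.5)(x)
outputs = Dense(10, activation="softmax")(x)  # 10: illustrative number of labels

model = Model(inputs=[input_ids, attention_mask], outputs=outputs)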
50 changes: 37 additions & 13 deletions delft/sequenceLabelling/models.py
@@ -1187,11 +1187,41 @@ def get_generator(self):
return DataGeneratorTransformers

class BERT_BidLSTM(BaseModel):
"""
"""

name = 'BERT_BidLSTM'

def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, local_path: str = None, preprocessor=None):
super().__init__(config, ntags, load_pretrained_weights, local_path)

transformer_layers = self.init_transformer(config, load_pretrained_weights, local_path, preprocessor)

input_ids_in = Input(shape=(None,), name='input_token', dtype='int32')
token_type_ids = Input(shape=(None,), name='input_token_type', dtype='int32')
attention_mask = Input(shape=(None,), name='input_attention_mask', dtype='int32')

embedding_layers = transformer_layers(input_ids_in,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
training=False)[-4:]
concatenated_embeddings = Concatenate()([layer for layer in embedding_layers])

bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units,
return_sequences=True,
recurrent_dropout=config.recurrent_dropout))(concatenated_embeddings)
bid_lstm = Dropout(config.dropout)(bid_lstm)

label_logits = Dense(ntags, activation='softmax')(bid_lstm)

self.model = Model(inputs=[input_ids_in, token_type_ids, attention_mask], outputs=[label_logits])
self.config = config

def get_generator(self):
return DataGeneratorTransformers


class BidLSTM_BERT(BaseModel):
name = 'BidLSTM_BERT'

def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, local_path: str = None, preprocessor=None):
super().__init__(config, ntags, load_pretrained_weights, local_path)

@@ -1220,10 +1250,6 @@ def get_generator(self):


class BERT_BidLSTM_CRF(BaseModel):
"""
"""

name = 'BERT_BidLSTM_CRF'

def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, local_path: str = None, preprocessor=None):
@@ -1235,15 +1261,13 @@ def __init__(self, config, ntags=None, load_pretrained_weights: bool = True, loc
token_type_ids = Input(shape=(None,), name='input_token_type', dtype='int32')
attention_mask = Input(shape=(None,), name='input_attention_mask', dtype='int32')

#embedding_layer = transformer_model(input_ids_in, token_type_ids=token_type_ids)[0]
embedding_layer = transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
embedding_layer = Dropout(0.1)(embedding_layer)

bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units, #embedding_layer.shape[-1],
concatenated_embeddings = Concatenate()([layer for layer in transformer_layers(input_ids_in, token_type_ids=token_type_ids, attention_mask=attention_mask,
training=False)[-4:]])
bid_lstm = Bidirectional(LSTM(units=config.num_word_lstm_units,
return_sequences=True,
recurrent_dropout=config.recurrent_dropout))(embedding_layer)
recurrent_dropout=config.recurrent_dropout))(concatenated_embeddings)
bid_lstm = Dropout(config.dropout)(bid_lstm)
bid_lstm = Dense(embedding_layer.shape[-1], activation='tanh')(bid_lstm)
bid_lstm = Dense(concatenated_embeddings.shape[-1], activation='tanh')(bid_lstm)

base_model = Model(inputs=[input_ids_in, token_type_ids, attention_mask], outputs=[bid_lstm])

66 changes: 63 additions & 3 deletions delft/utilities/Embeddings.py
@@ -18,6 +18,9 @@
from tqdm import tqdm
from pathlib import Path

from delft.sequenceLabelling.config import ModelConfig
from delft.sequenceLabelling.preprocess import BERTPreprocessor, Preprocessor
from delft.utilities.Transformer import Transformer
from delft.utilities.simple_elmo import ElmoModel, elmo

logging.basicConfig()
@@ -81,7 +84,7 @@ def __init__(self, name,
# below init for using ELMo embeddings
self.use_ELMo = use_ELMo
self.elmo_model_name = elmo_model_name
if elmo_model_name == None:
if elmo_model_name is None:
self.elmo_model_name = 'elmo-'+self.lang
if use_ELMo:
#tf.compat.v1.disable_eager_execution()
@@ -479,12 +482,12 @@ def get_elmo_embedding_path(self, description):
destination_dir = os.path.join("data/models/ELMo", self.elmo_model_name)
if not os.path.exists(destination_dir):
os.makedirs(destination_dir)
try:
try:
shutil.move(embeddings_path, destination_file)
weights_file = destination_file
except OSError:
print ("Copy of ELMo weights file to ELMo directory path", destination_file, "failed")

if "url_weights" not in description or description["url_weights"] == None or len(description["url_weights"]) == 0:
print("no download url available for this ELMo model weights embeddings resource, please review the embedding registry for", name)
print("ELMo weights used:", weights_file)
@@ -761,3 +764,60 @@ def load_resource_registry(path='delft/resources-registry.json'):
"""
registry_json = open(path).read()
return json.loads(registry_json)

class ContextualizedEmbeddings(Embeddings):

def __init__(self, transformer_name: str, registry: dict, max_sequence_length: int):

super().__init__(transformer_name, use_cache=False, use_ELMo=False, resource_registry=registry, load=False)
self.embed_size = 768
self.transformer = Transformer(transformer_name, registry)
self.model = self.transformer.instantiate_layer(load_pretrained_weights=True, output_hidden_states=True)
self.transformer_config = self.transformer.transformer_config
self.transformer.init_preprocessor(max_sequence_length=max_sequence_length)
self.preprocessor = BERTPreprocessor(self.transformer.tokenizer)
# self.transformer.tokenizer.empty_features_vector())
# preprocessor.empty_char_vector())


# def get_sentence_vectors(self, token_list):
# token_vecs = hidden_states[-2][0]
#
# # Calculate the average of all 22 token vectors.
# sentence_embedding = torch.mean(token_vecs, dim=0)
#

def get_sentence_vector(self, text_tokens):
(target_ids, target_type_ids, target_attention_mask, target_chars,
target_features, target_labels, input_tokens) = self.preprocessor.tokenize_and_align_features_and_labels(text_tokens)

# segments_ids = [1] * len(target_ids)
# token_type_ids is passed by keyword: the second positional argument of the transformer call is attention_mask
outputs = self.model(target_ids, token_type_ids=target_type_ids)
hidden_states = outputs[2]

# stack the hidden states: (layers, batch, tokens, vector size)
token_embeddings = tf.stack(hidden_states, axis=0)

# drop the batch dimension (batch size 1): (layers, tokens, vector size)
token_embeddings = tf.squeeze(token_embeddings, axis=1)

# group by token: (tokens, layers, vector size)
token_embeddings = tf.transpose(token_embeddings, perm=[1, 0, 2])

# concatenate the last four layers for each token
token_vecs_cat = []

for token in token_embeddings:
cat_vec = tf.concat([token[-1], token[-2], token[-3], token[-4]], axis=0)
token_vecs_cat.append(cat_vec)

# Sum
# sum_vec = torch.sum(token[-4:], dim=0)
# token_vecs_sum.append(sum_vec)

print('Shape is: %d x %d' % (len(token_vecs_cat), len(token_vecs_cat[0])))

return token_vecs_cat
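A possible way to exercise the new ContextualizedEmbeddings class, based on the constructor and load_resource_registry shown above; the transformer name, its presence in the resource registry, and the token list are illustrative assumptions.

from delft.utilities.Embeddings import ContextualizedEmbeddings, load_resource_registry

registry = load_resource_registry("delft/resources-registry.json")
# "bert-base-cased" is assumed to be resolvable by the registry / transformers hub
embeddings = ContextualizedEmbeddings("bert-base-cased", registry, max_sequence_length=512)

tokens = ["Grobid", "extracts", "bibliographical", "references", "."]
vectors = embeddings.get_sentence_vector(tokens)
# one concatenated vector (last four hidden layers) per sub-token
print(len(vectors), vectors[0].shape)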

