From f47ac99eaebfd16e45281d094884bdecc566f507 Mon Sep 17 00:00:00 2001 From: Kohulan Date: Mon, 31 Jul 2023 10:54:17 +0200 Subject: [PATCH] fix: use checkpoints properly for predictions #65 and #66 --- DECIMER/Efficient_Net_encoder.py | 30 +- DECIMER/Predictor_EfficientNet2.py | 160 ----------- DECIMER/Predictor_usingCheckpoints.py | 197 +++++++++++++ DECIMER/Repack_model.py | 15 +- DECIMER/Transformer_decoder.py | 390 ++++++-------------------- DECIMER/config.py | 24 +- DECIMER/efficientnetv2/utils.py | 103 +++++-- 7 files changed, 375 insertions(+), 544 deletions(-) delete mode 100644 DECIMER/Predictor_EfficientNet2.py create mode 100644 DECIMER/Predictor_usingCheckpoints.py diff --git a/DECIMER/Efficient_Net_encoder.py b/DECIMER/Efficient_Net_encoder.py index 7f7d658..0fc6684 100644 --- a/DECIMER/Efficient_Net_encoder.py +++ b/DECIMER/Efficient_Net_encoder.py @@ -1,30 +1,17 @@ # EfficientNet-V2 config import tensorflow as tf -import DECIMER.efficientnetv2 as efficientnetv2 +import efficientnetv2 +from efficientnetv2 import effnetv2_model +from efficientnetv2 import effnetv2_configs BATCH_SIZE_DEBUG = 2 -MODEL = "efficientnetv2-b3" # @param +MODEL = "efficientnetv2-m" # @param # Define encoder def get_efficientnetv2_backbone( - model_name, include_top=False, input_shape=(299, 299, 3), pooling=None, weights=None + model_name, include_top=False, input_shape=(512, 512, 3), pooling=None, weights=None ): - """Initiate and get the desired Efficient-Net V2 backbone as encoder - - Args: - model_name (str): Name of the Efficient-Net V2 model - include_top (bool, optional): Defaults to False. - input_shape (tuple, optional): Image shape. Defaults to (299, 299, 3). - pooling (int, optional): Max pooling values. Defaults to None. - weights ( optional): Pretrained weights. Defaults to None. 
- - Raises: - NotImplementedError: At this time we only want to use the raw - - Returns: - Efficient Net V2 backbone - """ # Catch unsupported arguments if pooling or weights or include_top: raise NotImplementedError( @@ -39,12 +26,6 @@ def get_efficientnetv2_backbone( class Encoder(tf.keras.Model): - """Encoder class - - Args: - tf (_type_): tensorflow model module - """ - def __init__( self, image_embedding_dim, @@ -56,6 +37,7 @@ def __init__( pretrained_weights=None, scale_factor=0, ): + super(Encoder, self).__init__() self.image_embedding_dim = image_embedding_dim diff --git a/DECIMER/Predictor_EfficientNet2.py b/DECIMER/Predictor_EfficientNet2.py deleted file mode 100644 index 6265c6c..0000000 --- a/DECIMER/Predictor_EfficientNet2.py +++ /dev/null @@ -1,160 +0,0 @@ -import os -import sys -import tensorflow as tf - -import pickle -from selfies import decoder -import Transformer_decoder -import Efficient_Net_encoder -import config - -# Set GPU -os.environ["CUDA_VISIBLE_DEVICES"] = "2" -gpus = tf.config.experimental.list_physical_devices("GPU") -for gpu in gpus: - tf.config.experimental.set_memory_growth(gpu, True) - -# load assets -HERE = os.path.dirname(os.path.abspath(__file__)) -tokenizer = pickle.load( - open(os.path.join(HERE, "tokenizer_Isomeric_SELFIES.pkl"), "rb") -) -max_length = pickle.load( - open(os.path.join(HERE, "max_length_Isomeric_SELFIES.pkl"), "rb") -) - -# Image partameters -IMG_EMB_DIM = (10, 10, 232) -IMG_EMB_DIM = (IMG_EMB_DIM[0] * IMG_EMB_DIM[1], IMG_EMB_DIM[2]) -IMG_SHAPE = (299, 299, 3) -PE_INPUT = IMG_EMB_DIM[0] -IMG_SEQ_LEN, IMG_EMB_DEPTH = IMG_EMB_DIM -D_MODEL = IMG_EMB_DEPTH - -# Network parameters -N_LAYERS = 4 -D_MODEL = 512 -D_FF = 2048 -N_HEADS = 8 -DROPOUT_RATE = 0.1 - -# Misc -MAX_LEN = max_length -VOCAB_LEN = len(tokenizer.word_index) -PE_OUTPUT = MAX_LEN -TARGET_V_SIZE = VOCAB_LEN -REPLICA_BATCH_SIZE = 1 - -# Config Encoder -PREPROCESSING_FN = tf.keras.applications.efficientnet.preprocess_input -BB_FN = Efficient_Net_encoder.get_efficientnetv2_backbone - -# Config Model -testing_config = config.Config() - -testing_config.initialize_encoder_config( - image_embedding_dim=IMG_EMB_DIM, - preprocessing_fn=PREPROCESSING_FN, - backbone_fn=BB_FN, - image_shape=IMG_SHAPE, - do_permute=IMG_EMB_DIM[1] < IMG_EMB_DIM[0], -) - -testing_config.initialize_transformer_config( - vocab_len=VOCAB_LEN, - max_len=MAX_LEN, - n_transformer_layers=N_LAYERS, - transformer_d_dff=D_FF, - transformer_n_heads=N_HEADS, - image_embedding_dim=IMG_EMB_DIM, -) - -# print(f"Encoder config:\n\t -> {testing_config.encoder_config}\n") -# print(f"Transformer config:\n\t -> {testing_config.transformer_config}\n") - -# Prepare model -optimizer, encoder, transformer = config.prepare_models( - encoder_config=testing_config.encoder_config, - transformer_config=testing_config.transformer_config, - replica_batch_size=REPLICA_BATCH_SIZE, - verbose=0, -) - -# Load trained model checkpoint -checkpoint_path = os.path.join(HERE, "checkpoints_") -ckpt = tf.train.Checkpoint( - encoder=encoder, transformer=transformer, optimizer=optimizer -) -ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=50) - -start_epoch = 0 -if ckpt_manager.latest_checkpoint: - ckpt.restore(tf.train.latest_checkpoint(checkpoint_path)) - start_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1]) - - -def main(): - if len(sys.argv) != 2: - print("Usage: {} $image_path".format(sys.argv[0])) - else: - SMILES = predict_SMILES(sys.argv[1]) - print(SMILES) - - -def evaluate(image_path: str): - """ - This 
function takes an image path (str) and returns the SELFIES - representation of the depicted molecule (str). - - Args: - image_path (str): Path of chemical structure depiction image - - Returns: - (str): SELFIES representation of the molecule in the input image - """ - sample = config.decode_image(image_path) - _image_batch = tf.expand_dims(sample, 0) - _image_embedding = encoder(_image_batch, training=False) - output = tf.expand_dims([tokenizer.word_index[""]], 0) - result = [] - end_token = tokenizer.word_index[""] - - for i in range(MAX_LEN): - combined_mask = Transformer_decoder.create_mask(None, output) - prediction_batch, _ = transformer( - _image_embedding, output, training=False, look_ahead_mask=combined_mask - ) - - predictions = prediction_batch[:, -1:, :] - predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32) - - if predicted_id == end_token: - return result - - result.append(tokenizer.index_word[int(predicted_id)]) - output = tf.concat([output, predicted_id], axis=-1) - - return result - - -def predict_SMILES(image_path: str): - """ - This function takes an image path (str) and returns the SMILES - representation of the depicted molecule (str). - - Args: - image_path (str): Path of chemical structure depiction image - - Returns: - (str): SMILES representation of the molecule in the input image - """ - predicted_SELFIES = evaluate(image_path) - predicted_SMILES = decoder( - "".join(predicted_SELFIES).replace("", "").replace("", "") - ) - - return predicted_SMILES - - -if __name__ == "__main__": - main() diff --git a/DECIMER/Predictor_usingCheckpoints.py b/DECIMER/Predictor_usingCheckpoints.py new file mode 100644 index 0000000..e2ecd5e --- /dev/null +++ b/DECIMER/Predictor_usingCheckpoints.py @@ -0,0 +1,197 @@ +import os +import sys +import tensorflow as tf + +import pickle +import pystow +from selfies import decoder +import Transformer_decoder +import Efficient_Net_encoder +import config +import utils + +print(tf.__version__) + +# Set GPU +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +gpus = tf.config.experimental.list_physical_devices("GPU") +for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + +# Set path +default_path = pystow.join("DECIMER-V2") + +# model download location +checkpoint_url = "https://zenodo.org/record/8093783/files/DECIMER_512_checkpoints.zip" +checkpoint_path = str(default_path) + "/DECIMER_checkpoints/" + +if not os.path.exists(checkpoint_path): + config.download_trained_weights(checkpoint_url, default_path) + +# load assets + +tokenizer = pickle.load( + open( + default_path.as_posix() + "/DECIMER_model/assets/tokenizer_SMILES.pkl", + "rb", + ) +) + +max_length = 302 + +# Image partameters +IMG_EMB_DIM = (16, 16, 512) +IMG_EMB_DIM = (IMG_EMB_DIM[0] * IMG_EMB_DIM[1], IMG_EMB_DIM[2]) +IMG_SHAPE = (512, 512, 3) +PE_INPUT = IMG_EMB_DIM[0] +IMG_SEQ_LEN, IMG_EMB_DEPTH = IMG_EMB_DIM +D_MODEL = 512 + +# Network parameters +N_LAYERS = 6 +D_MODEL = 512 +D_FF = 2048 +N_HEADS = 8 +DROPOUT_RATE = 0.1 + +# Misc +MAX_LEN = max_length +VOCAB_LEN = len(tokenizer.word_index) +PE_OUTPUT = MAX_LEN +TARGET_V_SIZE = VOCAB_LEN +REPLICA_BATCH_SIZE = 1 + +# Config Encoder +PREPROCESSING_FN = tf.keras.applications.efficientnet.preprocess_input +BB_FN = Efficient_Net_encoder.get_efficientnetv2_backbone + +# Config Model +testing_config = config.Config() + +testing_config.initialize_encoder_config( + image_embedding_dim=IMG_EMB_DIM, + preprocessing_fn=PREPROCESSING_FN, + backbone_fn=BB_FN, + image_shape=IMG_SHAPE, + do_permute=IMG_EMB_DIM[1] < 
IMG_EMB_DIM[0], +) +testing_config.initialize_transformer_config( + vocab_len=VOCAB_LEN, + max_len=MAX_LEN, + n_transformer_layers=N_LAYERS, + transformer_d_dff=D_FF, + transformer_n_heads=N_HEADS, + image_embedding_dim=D_MODEL, +) + +# print(f"Encoder config:\n\t -> {testing_config.encoder_config}\n") +# print(f"Transformer config:\n\t -> {testing_config.transformer_config}\n") + +# Prepare model +optimizer, encoder, transformer = config.prepare_models( + encoder_config=testing_config.encoder_config, + transformer_config=testing_config.transformer_config, + replica_batch_size=REPLICA_BATCH_SIZE, + verbose=0, +) + +# Load trained model checkpoint +ckpt = tf.train.Checkpoint( + encoder=encoder, transformer=transformer, optimizer=optimizer +) +ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=50) +if ckpt_manager.latest_checkpoint: + ckpt.restore(tf.train.latest_checkpoint(checkpoint_path)) + start_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1]) + + +def main(): + if len(sys.argv) != 2: + print("Usage: {} $image_path".format(sys.argv[0])) + else: + SMILES = predict_SMILES(sys.argv[1]) + print(SMILES) + + +class DECIMER_Predictor(tf.Module): + def __init__(self, encoder, tokenizer, transformer, max_length): + self.encoder = encoder + self.tokenizer = tokenizer + self.transformer = transformer + self.max_length = max_length + + def __call__(self, Decoded_image): + assert isinstance(Decoded_image, tf.Tensor) + if len(Decoded_image.shape) == 0: + sentence = Decoded_image[tf.newaxis] + + _image_batch = tf.expand_dims(Decoded_image, 0) + _image_embedding = encoder(_image_batch, training=False) + + start_token = tf.cast( + tf.convert_to_tensor([tokenizer.word_index["<start>"]]), tf.int32 + ) + end_token = tf.cast( + tf.convert_to_tensor([tokenizer.word_index["<end>"]]), tf.int32 + ) + + output_array = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True) + output_array = output_array.write(0, start_token) + + for t in tf.range(max_length): + output = tf.transpose(output_array.stack()) + combined_mask = Transformer_decoder.create_masks_decoder(output) + + # predictions.shape == (batch_size, seq_len, vocab_size) + prediction_batch = transformer( + output, _image_embedding, training=False, look_ahead_mask=combined_mask + ) + + # select the last word from the seq_len dimension + predictions = prediction_batch[:, -1:, :] # (batch_size, 1, vocab_size) + + predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32) + + output_array = output_array.write(t + 1, predicted_id[0]) + + if predicted_id == end_token: + break + output = tf.transpose(output_array.stack()) + return output + + +def detokenize_output(predicted_array): + outputs = [tokenizer.index_word[i] for i in predicted_array[0].numpy()] + prediction = ( + "".join([str(elem) for elem in outputs]) + .replace("<start>", "") + .replace("<end>", "") + ) + + return prediction + + +# Initiate the DECIMER class +DECIMER = DECIMER_Predictor(encoder, tokenizer, transformer, MAX_LEN) + + +def predict_SMILES(image_path: str): + """ + This function takes an image path (str) and returns the SMILES + representation of the depicted molecule (str).
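A shape-checkable toy version of the greedy decoding loop used in DECIMER_Predictor.__call__ above; the real transformer is replaced by a random-logits stub, and the tiny vocabulary and token ids are made up purely for this sketch:

import tensorflow as tf

VOCAB_SIZE = 5   # toy vocabulary: 0 = pad, 1 = <start>, 2 = <end>, 3-4 = tokens
MAX_STEPS = 8

def toy_transformer(tokens, image_embedding, training=False, look_ahead_mask=None):
    # Stand-in for the real decoder: random logits of shape (batch, seq_len, vocab).
    return tf.random.uniform((tf.shape(tokens)[0], tf.shape(tokens)[1], VOCAB_SIZE))

def greedy_decode(image_embedding):
    output_array = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    output_array = output_array.write(0, tf.constant([1], tf.int32))  # <start>
    for t in tf.range(MAX_STEPS):
        output = tf.transpose(output_array.stack())                   # (1, t + 1)
        logits = toy_transformer(output, image_embedding)
        predicted_id = tf.cast(tf.argmax(logits[:, -1:, :], axis=-1), tf.int32)
        output_array = output_array.write(t + 1, predicted_id[0])
        if predicted_id == 2:                                          # <end>
            break
    return tf.transpose(output_array.stack())

print(greedy_decode(tf.zeros((1, 256, 512))))  # (1, n_decoded_tokens)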
+ + Args: + image_path (str): Path of chemical structure depiction image + + Returns: + (str): SMILES representation of the molecule in the input image + """ + decodedImage = config.decode_image(image_path) + predicted_tokens = DECIMER(tf.constant(decodedImage)) + predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens)) + + return predicted_SMILES + + +if __name__ == "__main__": + main() diff --git a/DECIMER/Repack_model.py b/DECIMER/Repack_model.py index 1206aa7..f6554c0 100644 --- a/DECIMER/Repack_model.py +++ b/DECIMER/Repack_model.py @@ -20,15 +20,15 @@ max_length = pickle.load(open("max_length.pkl", "rb")) # Image parameters -IMG_EMB_DIM = (16, 16, 232) +IMG_EMB_DIM = (16, 16, 512) IMG_EMB_DIM = (IMG_EMB_DIM[0] * IMG_EMB_DIM[1], IMG_EMB_DIM[2]) IMG_SHAPE = (512, 512, 3) PE_INPUT = IMG_EMB_DIM[0] IMG_SEQ_LEN, IMG_EMB_DEPTH = IMG_EMB_DIM -D_MODEL = IMG_EMB_DEPTH +D_MODEL = 512 # Network parameters -N_LAYERS = 4 +N_LAYERS = 6 D_MODEL = 512 D_FF = 2048 N_HEADS = 8 @@ -56,14 +56,13 @@ image_shape=IMG_SHAPE, do_permute=IMG_EMB_DIM[1] < IMG_EMB_DIM[0], ) - testing_config.initialize_transformer_config( vocab_len=VOCAB_LEN, max_len=MAX_LEN, n_transformer_layers=N_LAYERS, transformer_d_dff=D_FF, transformer_n_heads=N_HEADS, - image_embedding_dim=IMG_EMB_DIM, + image_embedding_dim=D_MODEL, ) # Prepare model @@ -147,11 +146,11 @@ def __call__(self, Decoded_image): for t in tf.range(max_length): output = tf.transpose(output_array.stack()) - combined_mask = Transformer_decoder.create_mask(None, output) + combined_mask = Transformer_decoder.create_masks_decoder(output) # predictions.shape == (batch_size, seq_len, vocab_size) - prediction_batch, _ = transformer( - _image_embedding, output, training=False, look_ahead_mask=combined_mask + prediction_batch = transformer( + output, _image_embedding, training=False, look_ahead_mask=combined_mask ) # select the last word from the seq_len dimension diff --git a/DECIMER/Transformer_decoder.py b/DECIMER/Transformer_decoder.py index a3d0b8e..7e91e29 100644 --- a/DECIMER/Transformer_decoder.py +++ b/DECIMER/Transformer_decoder.py @@ -6,50 +6,32 @@ def get_angles(pos, i, d_model): - angle_rates = tf.constant(1, TARGET_DTYPE) / tf.math.pow( - tf.constant(10000, TARGET_DTYPE), - (tf.constant(2, dtype=TARGET_DTYPE) * tf.cast((i // 2), TARGET_DTYPE)) - / d_model, - ) + angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model)) return pos * angle_rates -def do_interleave(arr_a, arr_b): - a_arr_tf_column = tf.range(arr_a.shape[1]) * 2 # [0 2 4 ...] - b_arr_tf_column = tf.range(arr_b.shape[1]) * 2 + 1 # [1 3 5 ...] - column_indices = tf.argsort(tf.concat([a_arr_tf_column, b_arr_tf_column], axis=-1)) - column, row = tf.meshgrid(column_indices, tf.range(arr_a.shape[0])) - combine_indices = tf.stack([row, column], axis=-1) - combine_value = tf.concat([arr_a, arr_b], axis=1) - return tf.gather_nd(combine_value, combine_indices) - - def positional_encoding_1d(position, d_model): angle_rads = get_angles( - tf.cast(tf.range(position)[:, tf.newaxis], TARGET_DTYPE), - tf.cast(tf.range(d_model)[tf.newaxis, :], TARGET_DTYPE), - d_model, + np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model ) - # apply sin to even indices in the array; 2i - sin_angle_rads = tf.math.sin(angle_rads[:, ::2]) - cos_angle_rads = tf.math.cos(angle_rads[:, 1::2]) - angle_rads = do_interleave(sin_angle_rads, cos_angle_rads) - pos_encoding = angle_rads[tf.newaxis, ...] 
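The numpy-based positional-encoding helpers the patch switches to can be run stand-alone to sanity-check shapes; with the predictor's max_len of 302 and d_model of 512 they produce a (1, 302, 512) table (copied from the patched functions, with float32 standing in for TARGET_DTYPE here):

import numpy as np
import tensorflow as tf

def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding_1d(position, d_model):
    angle_rads = get_angles(
        np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model
    )
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])  # sine on even indices
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])  # cosine on odd indices
    return tf.cast(angle_rads[np.newaxis, ...], tf.float32)

print(positional_encoding_1d(302, 512).shape)  # (1, 302, 512)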
- return pos_encoding + angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) + angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) + pos_encoding = angle_rads[np.newaxis, ...] + return tf.cast(pos_encoding, dtype=TARGET_DTYPE) -def np_positional_encoding_2d(row, col, d_model): +def positional_encoding_2d(row, col, d_model): assert d_model % 2 == 0 row_pos = np.repeat(np.arange(row), col)[:, np.newaxis] col_pos = np.repeat(np.expand_dims(np.arange(col), 0), row, axis=0).reshape(-1, 1) angle_rads_row = get_angles( row_pos, np.arange(d_model // 2)[np.newaxis, :], d_model // 2 - ).numpy() + ) angle_rads_col = get_angles( col_pos, np.arange(d_model // 2)[np.newaxis, :], d_model // 2 - ).numpy() + ) angle_rads_row[:, 0::2] = np.sin(angle_rads_row[:, 0::2]) angle_rads_row[:, 1::2] = np.cos(angle_rads_row[:, 1::2]) @@ -61,378 +43,166 @@ def np_positional_encoding_2d(row, col, d_model): return tf.cast(pos_encoding, dtype=TARGET_DTYPE) -def positional_encoding_2d(row, col, d_model): - row_pos = tf.repeat(tf.range(row), col)[:, tf.newaxis] - col_pos = tf.reshape( - tf.repeat(tf.expand_dims(tf.range(col), 0), row, axis=0), (-1, 1) - ) - - angle_rads_row = get_angles( - tf.cast(row_pos, tf.float32), - tf.range(d_model // 2)[tf.newaxis, :], - d_model // 2, - ) - angle_rads_col = get_angles( - tf.cast(col_pos, tf.float32), - tf.range(d_model // 2)[tf.newaxis, :], - d_model // 2, - ) - - sin_angle_rads_row = tf.math.sin(angle_rads_row[:, ::2]) - cos_angle_rads_row = tf.math.cos(angle_rads_row[:, 1::2]) - angle_rads_row = do_interleave(sin_angle_rads_row, cos_angle_rads_row) - - sin_angle_rads_col = tf.math.sin(angle_rads_col[:, ::2]) - cos_angle_rads_col = tf.math.cos(angle_rads_col[:, 1::2]) - angle_rads_col = do_interleave(sin_angle_rads_col, cos_angle_rads_col) - - pos_encoding = tf.concat([angle_rads_row, angle_rads_col], axis=1)[tf.newaxis, ...] - return pos_encoding - - def create_padding_mask(seq): seq = tf.cast(tf.math.equal(seq, 0), TARGET_DTYPE) - - # add extra dimensions to add the padding to the attention logits. - # - (batch_size, 1, 1, seq_len) - return seq[:, tf.newaxis, tf.newaxis, :] + return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len) def create_look_ahead_mask(size): mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0) - # (seq_len, seq_len) - return tf.cast(mask, TARGET_DTYPE) - - -def create_mask(inp, tar): - # Used in the 1st attention block in the decoder. - # It is used to pad and mask future tokens in the input received by - # the decoder. - look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1]) - dec_target_padding_mask = create_padding_mask(tar) - combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask) - return tf.cast(combined_mask, TARGET_DTYPE) + mask = tf.cast(mask, TARGET_DTYPE) + return mask # (seq_len, seq_len) def scaled_dot_product_attention(q, k, v, mask): - matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) - - # scale matmul_qk + # (..., seq_len_q, seq_len_k) + matmul_qk = tf.matmul(q, k, transpose_b=True) dk = tf.cast(tf.shape(k)[-1], TARGET_DTYPE) - - # Calculate scaled attention logits scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) - # add the mask to the scaled tensor. if mask is not None: scaled_attention_logits += mask * -1e9 - # softmax is normalized on the last axis (seq_len_k) - # so that the scores add up to 1. 
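The two mask helpers above feed create_masks_decoder further down, which merges them with an element-wise maximum; a small self-contained check of how they broadcast together (float32 again stands in for TARGET_DTYPE):

import tensorflow as tf

def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return seq[:, tf.newaxis, tf.newaxis, :]   # (batch, 1, 1, seq_len)

def create_look_ahead_mask(size):
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

tar = tf.constant([[1, 5, 7, 0]])              # last position is padding
combined = tf.maximum(create_padding_mask(tar),
                      create_look_ahead_mask(tf.shape(tar)[1]))
print(combined.shape)                          # (1, 1, 4, 4)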
- # - shape --> (..., seq_len_q, seq_len_k) attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) - - # - shape --> (..., seq_len_q, depth_v) - output = tf.matmul(attention_weights, v) + output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) return output, attention_weights +def create_masks_decoder(tar): + look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1]) + dec_target_padding_mask = create_padding_mask(tar) + combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask) + return combined_mask + + +def point_wise_feed_forward_network(d_model, dff): + return tf.keras.Sequential( + [tf.keras.layers.Dense(dff, activation="relu"), tf.keras.layers.Dense(d_model)] + ) + + class MultiHeadAttention(tf.keras.layers.Layer): def __init__(self, d_model, num_heads): super(MultiHeadAttention, self).__init__() self.num_heads = num_heads self.d_model = d_model - assert d_model % self.num_heads == 0 - self.depth = d_model // self.num_heads - self.wq = tf.keras.layers.Dense(d_model) self.wk = tf.keras.layers.Dense(d_model) self.wv = tf.keras.layers.Dense(d_model) - self.dense = tf.keras.layers.Dense(d_model) def split_heads(self, x, batch_size): x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) return tf.transpose(x, perm=[0, 2, 1, 3]) - def call(self, v, k, q, mask): + def call(self, q, k, v, q_pos=None, k_pos=None, mask=None): batch_size = tf.shape(q)[0] - - # (batch_size, seq_len, d_model) q = self.wq(q) - # (batch_size, seq_len, d_model) k = self.wk(k) - # (batch_size, seq_len, d_model) v = self.wv(v) - # (batch_size, num_heads, seq_len_q, depth) + if q_pos is not None: + q = q + q_pos + if k_pos is not None: + k = k + k_pos + q = self.split_heads(q, batch_size) - # (batch_size, num_heads, seq_len_k, depth) k = self.split_heads(k, batch_size) - # (batch_size, num_heads, seq_len_v, depth) v = self.split_heads(v, batch_size) - # scaled_attention.shape – (batch_size, num_heads, seq_len_q, depth) - # attention_weights.shape – (batch_size, num_heads, seq_len_q, seq_len_k) scaled_attention, attention_weights = scaled_dot_product_attention( q, k, v, mask ) - - # (batch_size, seq_len_q, num_heads, depth) scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, d_model) - concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model)) - - # (batch_size, seq_len_q, d_model) - output = self.dense(concat_attention) - + scaled_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model)) + output = self.dense(scaled_attention) return output, attention_weights -def point_wise_feed_forward_network(d_model, dff): - return tf.keras.Sequential( - [ - # INNER LAYER - # – (batch_size, seq_len, dff) - tf.keras.layers.Dense(dff, activation="relu"), - # OUTPUT - # – (batch_size, seq_len, d_model) - tf.keras.layers.Dense(d_model), - ] - ) - - -class TransformerEncoderLayer(tf.keras.layers.Layer): - def __init__(self, d_model, num_heads, dff, dropout_rate=0.1): - super(TransformerEncoderLayer, self).__init__() - - self.mha = tf.keras.layers.MultiHeadAttention( - num_heads, - key_dim=d_model, - ) - self.ffn = point_wise_feed_forward_network(d_model, dff) - - self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) - self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) - - self.dropout1 = tf.keras.layers.Dropout(dropout_rate) - self.dropout2 = tf.keras.layers.Dropout(dropout_rate) - - def call(self, x, training, mask=None): - # returns (batch_size, input_seq_len, d_model) - 
attn_output, _ = self.mha(x, x, x, mask, return_attention_scores=True) - - # Potentially unncessary by passing dropout1 to tf.keras.layers.MultiHeadAttention (if using tf MHA) - attn_output = self.dropout1(attn_output, training=training) - - # Residual connection followed by layer normalization - # – returns (batch_size, input_seq_len, d_model) - out1 = self.layernorm1(x + attn_output, training=training) - - # Point-wise Feed Forward Step - # – returns (batch_size, input_seq_len, d_model) - ffn_output = self.ffn(out1, training=training) - ffn_output = self.dropout2(ffn_output, training=training) - - # Residual connection followed by layer normalization - # – returns (batch_size, input_seq_len, d_model) - out2 = self.layernorm2(out1 + ffn_output, training=training) - - return out2 - - -class TransformerDecoderLayer(tf.keras.layers.Layer): - def __init__(self, d_model, num_heads, dff, dropout_rate=0.1): - super(TransformerDecoderLayer, self).__init__() - - # WE COULD USE A CUSTOM DEFINED MHA MODEL BUT WE WILL USE TFA INSTEAD +class DecoderLayer(tf.keras.layers.Layer): + def __init__(self, d_model, num_heads, dff, max_len, rate=0.1): + super(DecoderLayer, self).__init__() self.mha1 = MultiHeadAttention(d_model, num_heads) self.mha2 = MultiHeadAttention(d_model, num_heads) - # - # # Multi Head Attention Layers - # self.mha1 = tf.keras.layers.MultiHeadAttention(num_heads, key_dim=d_model,) - # self.mha2 = tf.keras.layers.MultiHeadAttention(num_heads, key_dim=d_model,) - # Feed Forward NN self.ffn = point_wise_feed_forward_network(d_model, dff) - # Layer Normalization Layers self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6) - # Dropout Layers - self.dropout1 = tf.keras.layers.Dropout(dropout_rate) - self.dropout2 = tf.keras.layers.Dropout(dropout_rate) - self.dropout3 = tf.keras.layers.Dropout(dropout_rate) + self.dropout1 = tf.keras.layers.Dropout(rate) + self.dropout2 = tf.keras.layers.Dropout(rate) + self.dropout3 = tf.keras.layers.Dropout(rate) - # enc_output.shape == (batch_size, input_seq_len, d_model) - def call(self, x, enc_output, training, look_ahead_mask=None, padding_mask=None): - attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask) + def call( + self, + x, + enc_output, + enc_pos, + dec_pos, + training, + look_ahead_mask=None, + padding_mask=None, + ): + # (batch_size, target_seq_len, d_model) + attn1, attn_weights_block1 = self.mha1( + x, x, x, q_pos=dec_pos, k_pos=dec_pos, mask=look_ahead_mask + ) attn1 = self.dropout1(attn1, training=training) + out1 = self.layernorm1(attn1 + x) - # Residual connection followed by layer normalization - # – (batch_size, target_seq_len, d_model) - out1 = self.layernorm1(attn1 + x, training=training) - - # Merging connection between encoder and decoder (MHA) - # – (batch_size, target_seq_len, d_model) attn2, attn_weights_block2 = self.mha2( - enc_output, enc_output, out1, padding_mask + out1, enc_output, enc_output, q_pos=dec_pos, k_pos=enc_pos ) attn2 = self.dropout2(attn2, training=training) + out2 = self.layernorm2(attn2 + out1) - # Residual connection followed by layer normalization - # – (batch_size, target_seq_len, d_model) - out2 = self.layernorm2(attn2 + out1, training=training) - - # (batch_size, target_seq_len, d_model) - ffn_output = self.ffn(out2, training=training) + ffn_output = self.ffn(out2) ffn_output = self.dropout3(ffn_output, training=training) - - # Residual connection 
followed by layer normalization - # – (batch_size, target_seq_len, d_model) - out3 = self.layernorm3(ffn_output + out2, training=training) + out3 = self.layernorm3(ffn_output + out2) return out3, attn_weights_block1, attn_weights_block2 -class TransformerEncoder(tf.keras.layers.Layer): +class Decoder(tf.keras.Model): def __init__( - self, - num_layers, - d_model, - num_heads, - dff, - maximum_position_encoding, - dropout_rate=0.1, + self, num_layers, d_model, num_heads, dff, target_vocab_size, max_len, rate=0.1 ): - super(TransformerEncoder, self).__init__() - - self.d_model = d_model - self.num_layers = num_layers - self.embedding = tf.keras.layers.Dense(self.d_model, activation="relu") - self.pos_encoding = positional_encoding_1d( - maximum_position_encoding, self.d_model - ) - self.enc_layers = [ - TransformerEncoderLayer(d_model, num_heads, dff, dropout_rate) - for _ in range(num_layers) - ] - self.dropout = tf.keras.layers.Dropout(dropout_rate) - - def call(self, x, training, mask=None): - # adding embedding and position encoding. - # – (batch_size, input_seq_len, d_model) - x = self.embedding(x) - x += self.pos_encoding - x = self.dropout(x, training=training) - - for i in range(self.num_layers): - x = self.enc_layers[i](x, training, mask) - - # – (batch_size, input_seq_len, d_model) - return x - - -class TransformerDecoder(tf.keras.layers.Layer): - def __init__( - self, - num_layers, - d_model, - num_heads, - dff, - target_vocab_size, - maximum_position_encoding, - dropout_rate=0.1, - ): - super(TransformerDecoder, self).__init__() - + super(Decoder, self).__init__() self.d_model = d_model self.num_layers = num_layers self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model) - self.pos_encoding = positional_encoding_1d(maximum_position_encoding, d_model) + self.pos_encoding_1d = positional_encoding_1d(max_len, d_model) + self.pos_encoding_2d = positional_encoding_2d(16, 16, self.d_model) self.dec_layers = [ - TransformerDecoderLayer(d_model, num_heads, dff, dropout_rate) + DecoderLayer(d_model, num_heads, dff, max_len, rate) for _ in range(num_layers) ] - self.dropout = tf.keras.layers.Dropout(dropout_rate) + self.dropout = tf.keras.layers.Dropout(rate) + self.final_layer = tf.keras.layers.Dense(target_vocab_size) def call(self, x, enc_output, training, look_ahead_mask=None, padding_mask=None): seq_len = tf.shape(x)[1] - attention_weights = {} - - # adding embedding and position encoding. 
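The key architectural change in DecoderLayer/MultiHeadAttention is that positional encodings are now added to queries and keys (q_pos/k_pos) inside attention, while the value path stays position-free. A single-head sketch of that pattern; the head split and output projection are omitted for brevity and the tensors are random placeholders:

import tensorflow as tf

d_model = 8
wq, wk, wv = (tf.keras.layers.Dense(d_model) for _ in range(3))

x = tf.random.normal((1, 4, d_model))    # decoder token features
pos = tf.random.normal((1, 4, d_model))  # positional encoding for those tokens

q = wq(x) + pos                          # position-aware queries
k = wk(x) + pos                          # position-aware keys
v = wv(x)                                # values carry no positional signal

scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(float(d_model))
out = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
print(out.shape)                         # (1, 4, 8)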
- # – (batch_size, target_seq_len, d_model) - x = self.embedding(x) + dec_pos = self.pos_encoding_1d[:, :seq_len, :] + x = self.embedding(x) # (batch_size, target_seq_len, d_model) x *= tf.math.sqrt(tf.cast(self.d_model, TARGET_DTYPE)) - x += self.pos_encoding[:, :seq_len, :] - - x = self.dropout(x, training=training) for i in range(self.num_layers): x, block1, block2 = self.dec_layers[i]( - x, enc_output, training, look_ahead_mask, padding_mask + x, + enc_output, + self.pos_encoding_2d, + dec_pos, + training, + look_ahead_mask, + padding_mask, ) - attention_weights["decoder_layer{}_block1".format(i + 1)] = block1 - attention_weights["decoder_layer{}_block2".format(i + 1)] = block2 - - # x.shape == (batch_size, target_seq_len, d_model) - return x, attention_weights - - -class Transformer(tf.keras.Model): - def __init__( - self, - num_layers, - d_model, - num_heads, - dff, - target_vocab_size, - pe_input, - pe_target, - dropout_rate=0.1, - ): - super(Transformer, self).__init__() - - self.t_encoder = TransformerEncoder( - num_layers, d_model, num_heads, dff, pe_input, dropout_rate - ) - self.t_decoder = TransformerDecoder( - num_layers, - d_model, - num_heads, - dff, - target_vocab_size, - pe_target, - dropout_rate, - ) - self.t_final_layer = tf.keras.layers.Dense(target_vocab_size) - - def call( - self, - t_inp, - t_tar, - training, - enc_padding_mask=None, - look_ahead_mask=None, - dec_padding_mask=None, - ): - # (batch_size, inp_seq_len, d_model) - enc_output = self.t_encoder(t_inp, training, enc_padding_mask) - - # dec_output.shape == (batch_size, tar_seq_len, d_model) - dec_output, attention_weights = self.t_decoder( - t_tar, enc_output, training, look_ahead_mask, dec_padding_mask - ) - - # (batch_size, tar_seq_len, target_vocab_size) - final_output = self.t_final_layer(dec_output) - return final_output, attention_weights + predictions = self.final_layer(x) + return predictions diff --git a/DECIMER/config.py b/DECIMER/config.py index 6e40553..dc27089 100644 --- a/DECIMER/config.py +++ b/DECIMER/config.py @@ -1,8 +1,8 @@ # Network configuration file import tensorflow as tf import efficientnet.tfkeras as efn -import DECIMER.Efficient_Net_encoder -import DECIMER.Transformer_decoder +import DECIMER.Efficient_Net_encoder as Efficient_Net_encoder +import DECIMER.Transformer_decoder as Transformer_decoder from PIL import Image, ImageEnhance from pillow_heif import register_heif_opener from pathlib import Path @@ -252,7 +252,7 @@ def initialize_transformer_config( transformer_d_dff, transformer_n_heads, image_embedding_dim, - dropout_rate=0.1, + rate=0.1, ): """This functions initializes the Transformer model as decoder with user defined configurations. 
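With this patch the decoder is configured with a scalar d_model instead of the image-embedding tuple, and pe_input/pe_target collapse into a single max_len. Roughly what the resulting transformer_config looks like for the 512-pixel models, with a placeholder vocabulary size (the scripts use len(tokenizer.word_index)):

transformer_config = dict(
    num_layers=6,            # N_LAYERS
    d_model=512,             # D_MODEL, no longer image_embedding_dim[-1]
    num_heads=8,             # N_HEADS
    dff=2048,                # D_FF
    target_vocab_size=100,   # placeholder for len(tokenizer.word_index)
    max_len=302,             # replaces pe_input / pe_target
    rate=0.1,
)
# decoder = Transformer_decoder.Decoder(**transformer_config)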
@@ -269,13 +269,12 @@ def initialize_transformer_config( """ self.transformer_config = dict( num_layers=n_transformer_layers, - d_model=image_embedding_dim[-1], + d_model=image_embedding_dim, num_heads=transformer_n_heads, dff=transformer_d_dff, target_vocab_size=vocab_len, - pe_input=image_embedding_dim[0], - pe_target=max_len, - dropout_rate=0.1, + max_len=max_len, + rate=0.1, ) def initialize_lr_config(self, warm_steps, n_epochs): @@ -334,18 +333,9 @@ def prepare_models(encoder_config, transformer_config, replica_batch_size, verbo # Instantiate the encoder model encoder = Efficient_Net_encoder.Encoder(**encoder_config) - initialization_batch = encoder( - tf.ones( - ((replica_batch_size,) + encoder_config["image_shape"]), dtype=TARGET_DTYPE - ), - training=False, - ) # Instantiate the decoder model - transformer = Transformer_decoder.Transformer(**transformer_config) - transformer( - initialization_batch, tf.random.uniform((replica_batch_size, 1)), training=False - ) + transformer = Transformer_decoder.Decoder(**transformer_config) # Show the model architectures and plot the learning rate if verbose: diff --git a/DECIMER/efficientnetv2/utils.py b/DECIMER/efficientnetv2/utils.py index 9710c0e..2446d46 100644 --- a/DECIMER/efficientnetv2/utils.py +++ b/DECIMER/efficientnetv2/utils.py @@ -131,7 +131,6 @@ def __call__(self, step): lr = tf.math.maximum(lr, self.minimal_lr) if self.warmup_epochs: - logging.info("Learning rate warmup_epochs: %s", str(self.warmup_epochs)) warmup_steps = int(self.warmup_epochs * self.steps_per_epoch) warmup_lr = ( self.initial_lr @@ -220,9 +219,6 @@ def _moments(self, inputs, reduction_axes, keep_dims): def call(self, inputs, training=None): outputs = super().call(inputs, training) - # A temporary hack for tf1 compatibility with keras batch norm. - for u in self.updates: - tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, u) return outputs @@ -234,14 +230,6 @@ def __init__(self, **kwargs): kwargs["name"] = "tpu_batch_normalization" super().__init__(**kwargs) - def call(self, inputs, training=None): - outputs = super().call(inputs, training) - if training and not tf.executing_eagerly(): - # A temporary hack for tf1 compatibility with keras batch norm. - for u in self.updates: - tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, u) - return outputs - def normalization( norm_type: str, axis=-1, epsilon=0.001, momentum=0.99, groups=8, name=None @@ -360,7 +348,7 @@ def __init__(self, name, _): # pylint: disable=super-init-not-called def scalar(name, tensor, is_tpu=True): """Stores a (name, Tensor) tuple in a custom collection.""" - logging.info("Adding scale summary %s", Pair(name, tensor)) + logging.info("Adding scalar summary %s", Pair(name, tensor)) if is_tpu: tf.compat.v1.add_to_collection( "scalar_summaries", Pair(name, tf.reduce_mean(tensor)) @@ -421,7 +409,7 @@ def _custom_getter(getter, *args, **kwargs): yield varscope -def set_precision_policy(policy_name=None, loss_scale=False): +def set_precision_policy(policy_name=None): """Set precision policy according to the name. Args: @@ -435,15 +423,8 @@ def set_precision_policy(policy_name=None, loss_scale=False): assert policy_name in ("mixed_float16", "mixed_bfloat16", "float32") logging.info("use mixed precision policy name %s", policy_name) tf.compat.v1.keras.layers.enable_v2_dtype_behavior() - # mixed_float16 training is not supported for now, so disable loss_scale. - # float32 and mixed_bfloat16 do not need loss scale for training. 
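As the replacement lines just below show, set_precision_policy now uses the stable mixed-precision API instead of the removed experimental one, and the loss_scale argument disappears. The new call pattern in isolation:

import tensorflow as tf

def set_precision_policy(policy_name="float32"):
    assert policy_name in ("mixed_float16", "mixed_bfloat16", "float32")
    tf.keras.mixed_precision.set_global_policy(
        tf.keras.mixed_precision.Policy(policy_name)
    )

set_precision_policy("mixed_float16")
print(tf.keras.mixed_precision.global_policy().name)  # mixed_float16
set_precision_policy("float32")                       # back to the default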
- if loss_scale: - policy = tf.keras.mixed_precision.experimental.Policy(policy_name) - else: - policy = tf.keras.mixed_precision.experimental.Policy( - policy_name, loss_scale=None - ) - tf.keras.mixed_precision.experimental.set_policy(policy) + policy = tf.keras.mixed_precision.Policy(policy_name) + tf.keras.mixed_precision.set_global_policy(policy) def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs): @@ -464,6 +445,7 @@ def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs): Returns: the output of mm model. """ + del tt if pp == "mixed_bfloat16": set_precision_policy(pp) inputs = tf.cast(ii, tf.bfloat16) @@ -471,7 +453,7 @@ def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs): outputs = mm(inputs, *args, **kwargs) set_precision_policy("float32") elif pp == "mixed_float16": - set_precision_policy(pp, loss_scale=tt) + set_precision_policy(pp) inputs = tf.cast(ii, tf.float16) with float16_scope(): outputs = mm(inputs, *args, **kwargs) @@ -543,7 +525,78 @@ def get_ckpt_var_map( if not var_map or len(var_map) < 5: raise ValueError(f"var_map={var_map} is almost empty, please check logs.") - for k, v in var_map.items(): + for (k, v) in var_map.items(): logging.log_first_n(logging.INFO, f"Init {v.op.name} from ckpt var {k}", 10) return var_map + + +def restore_tf2_ckpt(model, ckpt_path_or_file, skip_mismatch=True, exclude_layers=None): + """Restore variables from a given checkpoint. + + Args: + model: the keras model to be restored. + ckpt_path_or_file: the path or file for checkpoint. + skip_mismatch: whether to skip variables if shape mismatch, + only works with tf1 checkpoint. + exclude_layers: string list exclude layer's variables, + only works with tf2 checkpoint. + + Raises: + KeyError: if access unexpected variables. + """ + ckpt_file = ckpt_path_or_file + if tf.io.gfile.isdir(ckpt_file): + ckpt_file = tf.train.latest_checkpoint(ckpt_file) + + # Try to load object-based checkpoint (by model.save_weights). 
+ var_list = tf.train.list_variables(ckpt_file) + if var_list[0][0] == "_CHECKPOINTABLE_OBJECT_GRAPH": + print(f"Load checkpointable from {ckpt_file}, excluding {exclude_layers}") + keys = {var[0].split("/")[0] for var in var_list} + keys.discard("_CHECKPOINTABLE_OBJECT_GRAPH") + if exclude_layers: + exclude_layers = set(exclude_layers) + keys = keys.difference(exclude_layers) + ckpt = tf.train.Checkpoint( + **{ + key: getattr(model, key, None) + for key in keys + if getattr(model, key, None) + } + ) + status = ckpt.restore(ckpt_file) + status.assert_nontrivial_match() + return + + print(f"Load TF1 graph based checkpoint from {ckpt_file}.") + var_dict = {v.name.split(":")[0]: v for v in model.weights} + reader = tf.train.load_checkpoint(ckpt_file) + var_shape_map = reader.get_variable_to_shape_map() + for key, var in var_dict.items(): + if key in var_shape_map: + if var_shape_map[key] != var.shape: + msg = "Shape mismatch: %s" % key + if skip_mismatch: + logging.warning(msg) + else: + raise ValueError(msg) + else: + var.assign(reader.get_tensor(key), read_value=False) + logging.log_first_n( + logging.INFO, f"Init {var.name} from {key} ({ckpt_file})", 10 + ) + else: + msg = "Not found %s in %s" % (key, ckpt_file) + if skip_mismatch: + logging.warning(msg) + else: + raise KeyError(msg) + + +class ReuableBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore): + """A BackupAndRestore callback that can be used across multiple model.fit()s.""" + + def on_train_end(self, logs=None): + # don't delete the backup, so it can be used for future model.fit()s + pass
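Taken together, the new predictor can be driven from the command line (python Predictor_usingCheckpoints.py image.png) or as a library call. A usage sketch, assuming the script is imported from inside the DECIMER package directory; the image filename is hypothetical, and the import itself triggers the checkpoint download and restore shown above:

from Predictor_usingCheckpoints import predict_SMILES

# Hypothetical depiction of a chemical structure.
smiles = predict_SMILES("caffeine_depiction.png")
print(smiles)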