Skip to content

Commit

Permalink
add token-level confidence scores
Browse files Browse the repository at this point in the history
  • Loading branch information
OBrink committed Aug 25, 2023
1 parent d572ded commit 491c7bc
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 25 deletions.
88 changes: 78 additions & 10 deletions DECIMER/Predictor_usingCheckpoints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
import sys
import tensorflow as tf

from typing import List, Tuple
import pickle
import pystow
from selfies import decoder
import Transformer_decoder

if int(tf.__version__.split(".")[1]) <= 10:
Expand Down Expand Up @@ -116,6 +115,9 @@ def main():
else:
SMILES = predict_SMILES(sys.argv[1])
print(SMILES)
SMILES_with_confidence = predict_SMILES_with_confidence(sys.argv[1])
for tup in SMILES_with_confidence:
print(tup)


class DECIMER_Predictor(tf.Module):
Expand All @@ -126,6 +128,16 @@ def __init__(self, encoder, tokenizer, transformer, max_length):
self.max_length = max_length

def __call__(self, Decoded_image):
"""
Run the DECIMER predictor model when called.
Usage of predict_SMILES or predict_SMILES_with_confidence is recommended instead
Args:
Decoded_image (_type_): output of config.decode_image
Returns:
Tuple[tf.Tensor, tf.Tensor]: predicted tokens, confidence values
"""
assert isinstance(Decoded_image, tf.Tensor)

_image_batch = tf.expand_dims(Decoded_image, 0)
Expand All @@ -140,6 +152,7 @@ def __call__(self, Decoded_image):

output_array = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
output_array = output_array.write(0, start_token)
confidence_array = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

for t in tf.range(max_length):
output = tf.transpose(output_array.stack())
Expand All @@ -154,31 +167,87 @@ def __call__(self, Decoded_image):
predictions = prediction_batch[:, -1:, :] # (batch_size, 1, vocab_size)

predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

confidence = predictions[-1][-1][int(predicted_id)]
output_array = output_array.write(t + 1, predicted_id[0])

confidence_array = confidence_array.write(t + 1, confidence)
if predicted_id == end_token:
break
output = tf.transpose(output_array.stack())
return output

return output, confidence_array.stack()


def detokenize_output(
    predicted_array: tf.Tensor
) -> str:
    """
    Turn the predicted token ids into one concatenated token string.

    Every id in the first row of the array is looked up in the module-level
    tokenizer, the resulting words are joined without a separator, and the
    "<start>" / "<end>" markers are removed.

    Args:
        predicted_array (tf.Tensor): Transformer Decoder output array (predicted tokens)

    Returns:
        str: concatenated tokens with the start/end markers stripped
    """
    token_ids = predicted_array[0].numpy()
    words = (str(tokenizer.index_word[token_id]) for token_id in token_ids)
    joined = "".join(words)
    return joined.replace("<start>", "").replace("<end>", "")

def detokenize_output_add_confidence(
    predicted_array: tf.Tensor,
    confidence_array: tf.Tensor,
) -> List[Tuple[str, float]]:
    """
    Pair every predicted token with its confidence value.

    The token ids in the first row of ``predicted_array`` are mapped to words
    with the module-level tokenizer and matched position by position with the
    entries of ``confidence_array``.

    Args:
        predicted_array (tf.Tensor): Transformer Decoder output array
            (predicted tokens); only the first row is read.
        confidence_array (tf.Tensor): confidence values returned by the
            predictor, one entry per token position.

    Returns:
        List[Tuple[str, float]]: (token, confidence) tuples. The first
        position is dropped, every token except the last is passed through
        ``utils.decoder``, and the last pair is appended unchanged. An empty
        input yields an empty list.
    """
    # Convert to NumPy once instead of once per token.
    token_ids = predicted_array[0].numpy()
    confidences = confidence_array.numpy()
    prediction_with_confidence = [
        (tokenizer.index_word[token_id], confidence)
        for token_id, confidence in zip(token_ids, confidences)
    ]
    if not prediction_with_confidence:
        # Nothing was predicted; avoid an IndexError on the [-1] lookup below.
        return []

    # Position 0 is skipped: the predictor writes the start token there and
    # never writes a confidence value for it.
    decoded_prediction_with_confidence = [
        (utils.decoder(token), confidence)
        for token, confidence in prediction_with_confidence[1:-1]
    ]
    # The final pair is kept as-is (not decoded).
    # NOTE(review): this is presumably the "<end>" token; if generation stopped
    # at max_length instead, the last regular token stays undecoded - confirm.
    decoded_prediction_with_confidence.append(prediction_with_confidence[-1])
    return decoded_prediction_with_confidence


# Initiate the DECIMER class
DECIMER = DECIMER_Predictor(encoder, tokenizer, transformer, MAX_LEN)

def predict_SMILES_with_confidence(image_path: str) -> List[Tuple[str, float]]:
    """
    Predict the SMILES tokens for an image together with their confidences.

    The image is decoded with ``config.decode_image``, run through the
    DECIMER predictor, and the resulting tokens and confidence values are
    combined by ``detokenize_output_add_confidence``.

    Args:
        image_path (str): Path of chemical structure depiction image

    Returns:
        (List[Tuple[str, float]]): Tuples that contain the tokens and the
        confidence values of the predicted SMILES
    """
    image_tensor = tf.constant(config.decode_image(image_path))
    tokens, confidences = DECIMER(image_tensor)
    return detokenize_output_add_confidence(tokens, confidences)


def predict_SMILES(image_path: str) -> str:
"""
This function takes an image path (str) and returns the SMILES
representation of the depicted molecule (str).
Expand All @@ -190,9 +259,8 @@ def predict_SMILES(image_path: str):
(str): SMILES representation of the molecule in the input image
"""
decodedImage = config.decode_image(image_path)
predicted_tokens = DECIMER(tf.constant(decodedImage))
predicted_tokens, _ = DECIMER(tf.constant(decodedImage))
predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens))

return predicted_SMILES


Expand Down
53 changes: 43 additions & 10 deletions DECIMER/Repack_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
import tensorflow as tf

from typing import List, Tuple
import pickle
import DECIMER.Efficient_Net_encoder as Efficient_Net_encoder
import DECIMER.Transformer_decoder as Transformer_decoder
import DECIMER.config as config
import DECIMER.utils as utils

print(tf.__version__)

Expand Down Expand Up @@ -84,16 +83,51 @@
start_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1])


def detokenize_output(predicted_array):
def detokenize_output(
    predicted_array: tf.Tensor
) -> str:
    """
    Turn the predicted token ids into one concatenated token string.

    Every id in the first row of the array is looked up in the module-level
    tokenizer, the resulting words are joined without a separator, and the
    "<start>" / "<end>" markers are removed.

    Args:
        predicted_array (tf.Tensor): Transformer Decoder output array (predicted tokens)

    Returns:
        str: concatenated tokens with the start/end markers stripped
    """
    token_ids = predicted_array[0].numpy()
    words = (str(tokenizer.index_word[token_id]) for token_id in token_ids)
    joined = "".join(words)
    return joined.replace("<start>", "").replace("<end>", "")

def detokenize_output_add_confidence(
    predicted_array: tf.Tensor,
    confidence_array: tf.Tensor,
) -> List[Tuple[str, float]]:
    """
    Pair every predicted token with its confidence value.

    The token ids in the first row of ``predicted_array`` are mapped to words
    with the module-level tokenizer and matched position by position with the
    entries of ``confidence_array``.

    Args:
        predicted_array (tf.Tensor): Transformer Decoder output array
            (predicted tokens); only the first row is read.
        confidence_array (tf.Tensor): confidence values returned by the
            predictor, one entry per token position.

    Returns:
        List[Tuple[str, float]]: (token, confidence) tuples. The first
        position is dropped, every token except the last is passed through
        ``utils.decoder``, and the last pair is appended unchanged. An empty
        input yields an empty list.
    """
    # Convert to NumPy once instead of once per token.
    token_ids = predicted_array[0].numpy()
    confidences = confidence_array.numpy()
    prediction_with_confidence = [
        (tokenizer.index_word[token_id], confidence)
        for token_id, confidence in zip(token_ids, confidences)
    ]
    if not prediction_with_confidence:
        # Nothing was predicted; avoid an IndexError on the [-1] lookup below.
        return []

    # Position 0 is skipped: the predictor writes the start token there and
    # never writes a confidence value for it.
    decoded_prediction_with_confidence = [
        (utils.decoder(token), confidence)
        for token, confidence in prediction_with_confidence[1:-1]
    ]
    # The final pair is kept as-is (not decoded).
    # NOTE(review): this is presumably the "<end>" token; if generation stopped
    # at max_length instead, the last regular token stays undecoded - confirm.
    decoded_prediction_with_confidence.append(prediction_with_confidence[-1])
    return decoded_prediction_with_confidence


class DECIMER_Predictor(tf.Module):
"""This is a class which takes care of inference. It loads the saved checkpoint and the necessary
Expand Down Expand Up @@ -128,8 +162,6 @@ def __call__(self, Decoded_image):
output (tf.Tensor[tf.int64]): predicted output as an array.
"""
assert isinstance(Decoded_image, tf.Tensor)
if len(Decoded_image.shape) == 0:
sentence = Decoded_image[tf.newaxis]

_image_batch = tf.expand_dims(Decoded_image, 0)
_image_embedding = encoder(_image_batch, training=False)
Expand All @@ -143,6 +175,7 @@ def __call__(self, Decoded_image):

output_array = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
output_array = output_array.write(0, start_token)
confidence_array = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

for t in tf.range(max_length):
output = tf.transpose(output_array.stack())
Expand All @@ -155,16 +188,16 @@ def __call__(self, Decoded_image):

# select the last word from the seq_len dimension
predictions = prediction_batch[:, -1:, :] # (batch_size, 1, vocab_size)

predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

confidence = predictions[-1][-1][int(predicted_id)]
output_array = output_array.write(t + 1, predicted_id[0])

confidence_array = confidence_array.write(t + 1, confidence)
if predicted_id == end_token:
break
output = tf.transpose(output_array.stack())

return output
return output, confidence_array.stack()


DECIMER = DECIMER_Predictor(encoder, tokenizer, transformer, MAX_LEN)
Expand Down
2 changes: 1 addition & 1 deletion DECIMER/Transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __init__(
for _ in range(num_layers)
]
self.dropout = tf.keras.layers.Dropout(rate)
self.final_layer = tf.keras.layers.Dense(target_vocab_size)
self.final_layer = tf.keras.layers.Dense(target_vocab_size, activation="softmax")

def call(self, x, enc_output, training, look_ahead_mask=None, padding_mask=None):
seq_len = tf.shape(x)[1]
Expand Down
53 changes: 49 additions & 4 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import pystow
import shutil
import tensorflow as tf
from typing import List, Tuple
import DECIMER.config as config
import DECIMER.utils as utils

# Silence tensorflow model loading warnings.
logging.getLogger("absl").setLevel("ERROR")

# Silence tensorflow errors. optional not recommened if your model is not working properly.
# Silence tensorflow errors - not recommended if your model is not working properly.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Set the absolute path
Expand Down Expand Up @@ -87,10 +88,34 @@ def detokenize_output(predicted_array: int) -> str:
.replace("<start>", "")
.replace("<end>", "")
)

return prediction


def detokenize_output_add_confidence(
    predicted_array: tf.Tensor,
    confidence_array: tf.Tensor,
) -> List[Tuple[str, float]]:
    """
    Pair every predicted token with its confidence value.

    The token ids in the first row of ``predicted_array`` are mapped to words
    with the module-level tokenizer and matched position by position with the
    entries of ``confidence_array``.

    Args:
        predicted_array (tf.Tensor): Transformer Decoder output array
            (predicted tokens); only the first row is read.
        confidence_array (tf.Tensor): confidence values returned by the
            model, one entry per token position.

    Returns:
        List[Tuple[str, float]]: (token, confidence) tuples. The first
        position is dropped, every token except the last is passed through
        ``utils.decoder``, and the last pair is appended unchanged. An empty
        input yields an empty list.
    """
    # Convert to NumPy once instead of once per token.
    token_ids = predicted_array[0].numpy()
    confidences = confidence_array.numpy()
    prediction_with_confidence = [
        (tokenizer.index_word[token_id], confidence)
        for token_id, confidence in zip(token_ids, confidences)
    ]
    if not prediction_with_confidence:
        # Nothing was predicted; avoid an IndexError on the [-1] lookup below.
        return []

    # Position 0 is skipped.
    # NOTE(review): presumably the "<start>" position, which carries no real
    # confidence value - confirm against the packed model's __call__.
    decoded_prediction_with_confidence = [
        (utils.decoder(token), confidence)
        for token, confidence in prediction_with_confidence[1:-1]
    ]
    # The final pair is kept as-is (not decoded).
    # NOTE(review): this is presumably the "<end>" token; if generation stopped
    # at the length limit instead, the last regular token stays undecoded.
    decoded_prediction_with_confidence.append(prediction_with_confidence[-1])
    return decoded_prediction_with_confidence


# Load DECIMER model_packed
DECIMER_V2 = tf.saved_model.load(default_path.as_posix() + "/DECIMER_model/")

Expand All @@ -107,11 +132,31 @@ def predict_SMILES(image_path: str) -> str:
(str): SMILES representation of the molecule in the input image
"""
chemical_structure = config.decode_image(image_path)
predicted_tokens = DECIMER_V2(chemical_structure)
predicted_tokens, _ = DECIMER_V2(chemical_structure)
predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens))

return predicted_SMILES


def predict_SMILES_with_confidence(image_path: str) -> List[Tuple[str, float]]:
    """
    Predict the SMILES tokens for an image together with their confidences.

    The image is decoded with ``config.decode_image``, run through the
    loaded DECIMER_V2 model, and the resulting tokens and confidence values
    are combined by ``detokenize_output_add_confidence``.

    Args:
        image_path (str): Path of chemical structure depiction image

    Returns:
        (List[Tuple[str, float]]): Tuples that contain the tokens and the
        confidence values of the predicted SMILES
    """
    image_tensor = tf.constant(config.decode_image(image_path))
    tokens, confidences = DECIMER_V2(image_tensor)
    return detokenize_output_add_confidence(tokens, confidences)


if __name__ == "__main__":
main()

0 comments on commit 491c7bc

Please sign in to comment.