Skip to content

Commit

Permalink
feat: improved DECIMER hand-drawn model
Browse files Browse the repository at this point in the history
  • Loading branch information
Kohulan committed Mar 5, 2024
1 parent 8d3bcaf commit 95d6049
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 68 deletions.
101 changes: 52 additions & 49 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,40 @@
# Set path
default_path = pystow.join("DECIMER-V2")

# download models to a default location
utils.ensure_model(default_path=default_path)

# Load important pickle files which consists the tokenizers and the maxlength setting
tokenizer = pickle.load(
open(
os.path.join(
default_path.as_posix(), "DECIMER_model", "assets", "tokenizer_SMILES.pkl"
),
"rb",
)
)
model_urls = {
"DECIMER": "https://zenodo.org/record/8300489/files/models.zip",
"DECIMER_HandDrawn": "https://zenodo.org/records/10781330/files/DECIMER_HandDrawn_model.zip"
}

def get_models(model_urls: dict) -> Tuple[object, tf.saved_model, tf.saved_model]:
"""Download and load models from the provided URLs.
def main():
"""This function take the path of the image as user input and returns the
predicted SMILES as output in CLI.
This function downloads models from the provided URLs to a default location,
then loads tokenizers and TensorFlow saved models.
Agrs:
str: image_path
Args:
model_urls (dict): A dictionary containing model names as keys and their corresponding URLs as values.
Returns:
str: predicted SMILES
tuple: A tuple containing loaded tokenizer and TensorFlow saved models.
- tokenizer (object): Tokenizer for DECIMER model.
- DECIMER_V2 (tf.saved_model): TensorFlow saved model for DECIMER.
- DECIMER_Hand_drawn (tf.saved_model): TensorFlow saved model for DECIMER HandDrawn.
"""
if len(sys.argv) != 2:
print("Usage: {} $image_path".format(sys.argv[0]))
else:
SMILES = predict_SMILES(sys.argv[1])
print(SMILES)
# Download models to a default location
model_paths = utils.ensure_models(default_path=default_path, model_urls=model_urls)

# Load tokenizers
tokenizer_path = os.path.join(model_paths["DECIMER"], "assets", "tokenizer_SMILES.pkl")
tokenizer = pickle.load(open(tokenizer_path, "rb"))

# Load DECIMER models
DECIMER_V2 = tf.saved_model.load(model_paths["DECIMER"])
DECIMER_Hand_drawn = tf.saved_model.load(model_paths["DECIMER_HandDrawn"])

return tokenizer, DECIMER_V2, DECIMER_Hand_drawn

tokenizer, DECIMER_V2, DECIMER_Hand_drawn = get_models(model_urls)

def detokenize_output(predicted_array: int) -> str:
"""This function takes the predited tokens from the DECIMER model and
Expand Down Expand Up @@ -111,46 +115,45 @@ def detokenize_output_add_confidence(
decoded_prediction_with_confidence.append(prediction_with_confidence_[-1])
return decoded_prediction_with_confidence


# Load DECIMER model_packed
DECIMER_V2 = tf.saved_model.load(default_path.as_posix() + "/DECIMER_model/")


def predict_SMILES(image_path: str) -> str:
"""This function takes an image path (str) and returns the SMILES
representation of the depicted molecule (str).
def predict_SMILES(image_path: str, confidence: bool = False, hand_drawn: bool = False) -> str:
"""Predicts SMILES representation of a molecule depicted in the given image.
Args:
image_path (str): Path of chemical structure depiction image
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn
Returns:
(str): SMILES representation of the molecule in the input image
str: SMILES representation of the molecule in the input image, optionally with confidence values
"""
chemical_structure = config.decode_image(image_path)
predicted_tokens, _ = DECIMER_V2(chemical_structure)

model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2
predicted_tokens, confidence_values = model(tf.constant(chemical_structure))

predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens))

if confidence:
predicted_SMILES_with_confidence = detokenize_output_add_confidence(predicted_tokens, confidence_values)
return predicted_SMILES, predicted_SMILES_with_confidence

return predicted_SMILES

def main():
"""This function take the path of the image as user input and returns the
predicted SMILES as output in CLI.
def predict_SMILES_with_confidence(image_path: str) -> List[Tuple[str, float]]:
"""This function takes an image path (str) and returns a list of tuples
that contain each token of the predicted SMILES string and the confidence
level from the last layer of the Transformer decoder.
Args:
image_path (str): Path of chemical structure depiction image
Agrs:
str: image_path
Returns:
(List[Tuple[str, float]]): Tuples that contain the tokens and the confidence
values of the predicted SMILES
str: predicted SMILES
"""
decodedImage = config.decode_image(image_path)
predicted_tokens, confidence_values = DECIMER_V2(tf.constant(decodedImage))
predicted_SMILES_with_confidence = detokenize_output_add_confidence(
predicted_tokens, confidence_values
)
return predicted_SMILES_with_confidence

if len(sys.argv) != 2:
print("Usage: {} $image_path".format(sys.argv[0]))
else:
SMILES = predict_SMILES(sys.argv[1])
print(SMILES)

if __name__ == "__main__":
main()
46 changes: 27 additions & 19 deletions DECIMER/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,36 @@ def decoder(predictions):
)
return modified


def ensure_model(
def ensure_models(
default_path: str,
model_url: str = "https://zenodo.org/record/8300489/files/models.zip",
):
"""Function to ensure model is present locally.
model_urls: dict
) -> dict:
"""Function to ensure models are present locally.
Convenient function to ensure model download before usage
Convenient function to ensure model downloads before usage
Args:
default path (str): default path for DECIMER data
model_url (str): trained model url for downloading
default_path (str): Default path for model data
model_urls (dict): Dictionary containing model names as keys and their corresponding URLs as values
Returns:
dict: A dictionary containing model names as keys and their local paths as values
"""
model_paths = {}

for model_name, model_url in model_urls.items():
model_path = os.path.join(default_path, f"{model_name}_model")
if (
os.path.exists(model_path)
and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)

# Store the model path
model_paths[model_name] = model_path

return model_paths

model_path = os.path.join(default_path.as_posix(), "DECIMER_model")
print(model_path)

if (
os.path.exists(model_path)
and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)
File renamed without changes
File renamed without changes.

0 comments on commit 95d6049

Please sign in to comment.