Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improved DECIMER hand-drawn model #90

Merged
merged 5 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions DECIMER/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
"""DECIMER V2.4.0 Python Package. ============================
"""DECIMER V2.6.0 Python Package. ============================

This repository contains DECIMER-V2. The DECIMER (Deep lEarning for Chemical ImagE Recognition) project
was launched to address the OCSR problem with the latest computational intelligence methods
Expand All @@ -19,11 +19,11 @@
please raise an issue on the GitHub repository.
"""

__version__ = "2.4.0"
__version__ = "2.6.0"

__all__ = [
"DECIMER",
]


from .decimer import predict_SMILES, predict_SMILES_with_confidence
from .decimer import predict_SMILES
107 changes: 61 additions & 46 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,44 @@
# Set path
default_path = pystow.join("DECIMER-V2")

# download models to a default location
utils.ensure_model(default_path=default_path)

# Load important pickle files which consists the tokenizers and the maxlength setting
tokenizer = pickle.load(
open(
os.path.join(
default_path.as_posix(), "DECIMER_model", "assets", "tokenizer_SMILES.pkl"
),
"rb",
)
)
model_urls = {
"DECIMER": "https://zenodo.org/record/8300489/files/models.zip",
"DECIMER_HandDrawn": "https://zenodo.org/records/10781330/files/DECIMER_HandDrawn_model.zip",
}


def main():
"""This function take the path of the image as user input and returns the
predicted SMILES as output in CLI.
def get_models(model_urls: dict):
"""Download and load models from the provided URLs.

Agrs:
str: image_path
This function downloads models from the provided URLs to a default location,
then loads tokenizers and TensorFlow saved models.

Args:
model_urls (dict): A dictionary containing model names as keys and their corresponding URLs as values.

Returns:
str: predicted SMILES
tuple: A tuple containing loaded tokenizer and TensorFlow saved models.
- tokenizer (object): Tokenizer for DECIMER model.
- DECIMER_V2 (tf.saved_model): TensorFlow saved model for DECIMER.
- DECIMER_Hand_drawn (tf.saved_model): TensorFlow saved model for DECIMER HandDrawn.
"""
if len(sys.argv) != 2:
print("Usage: {} $image_path".format(sys.argv[0]))
else:
SMILES = predict_SMILES(sys.argv[1])
print(SMILES)
# Download models to a default location
model_paths = utils.ensure_models(default_path=default_path, model_urls=model_urls)

# Load tokenizers
tokenizer_path = os.path.join(
model_paths["DECIMER"], "assets", "tokenizer_SMILES.pkl"
)
tokenizer = pickle.load(open(tokenizer_path, "rb"))

# Load DECIMER models
DECIMER_V2 = tf.saved_model.load(model_paths["DECIMER"])
DECIMER_Hand_drawn = tf.saved_model.load(model_paths["DECIMER_HandDrawn"])

return tokenizer, DECIMER_V2, DECIMER_Hand_drawn


tokenizer, DECIMER_V2, DECIMER_Hand_drawn = get_models(model_urls)


def detokenize_output(predicted_array: int) -> str:
Expand Down Expand Up @@ -112,44 +121,50 @@ def detokenize_output_add_confidence(
return decoded_prediction_with_confidence


# Load DECIMER model_packed
DECIMER_V2 = tf.saved_model.load(default_path.as_posix() + "/DECIMER_model/")


def predict_SMILES(image_path: str) -> str:
"""This function takes an image path (str) and returns the SMILES
representation of the depicted molecule (str).
def predict_SMILES(
image_path: str, confidence: bool = False, hand_drawn: bool = False
) -> str:
"""Predicts SMILES representation of a molecule depicted in the given image.

Args:
image_path (str): Path of chemical structure depiction image
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn

Returns:
(str): SMILES representation of the molecule in the input image
str: SMILES representation of the molecule in the input image, optionally with confidence values
"""
chemical_structure = config.decode_image(image_path)
predicted_tokens, _ = DECIMER_V2(chemical_structure)

model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2
predicted_tokens, confidence_values = model(tf.constant(chemical_structure))

predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens))

if confidence:
predicted_SMILES_with_confidence = detokenize_output_add_confidence(
predicted_tokens, confidence_values
)
return predicted_SMILES, predicted_SMILES_with_confidence

return predicted_SMILES


def predict_SMILES_with_confidence(image_path: str) -> List[Tuple[str, float]]:
"""This function takes an image path (str) and returns a list of tuples
that contain each token of the predicted SMILES string and the confidence
level from the last layer of the Transformer decoder.
def main():
"""This function takes the path of the image as user input and returns the
predicted SMILES as output in CLI.

Args:
image_path (str): Path of chemical structure depiction image
Args:
str: image_path

Returns:
(List[Tuple[str, float]]): Tuples that contain the tokens and the confidence
values of the predicted SMILES
str: predicted SMILES
"""
decodedImage = config.decode_image(image_path)
predicted_tokens, confidence_values = DECIMER_V2(tf.constant(decodedImage))
predicted_SMILES_with_confidence = detokenize_output_add_confidence(
predicted_tokens, confidence_values
)
return predicted_SMILES_with_confidence
if len(sys.argv) != 2:
print("Usage: {} $image_path".format(sys.argv[0]))
else:
SMILES = predict_SMILES(sys.argv[1])
print(SMILES)


if __name__ == "__main__":
Expand Down
45 changes: 25 additions & 20 deletions DECIMER/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,32 @@ def decoder(predictions):
return modified


def ensure_model(
default_path: str,
model_url: str = "https://zenodo.org/record/8300489/files/models.zip",
):
"""Function to ensure model is present locally.
def ensure_models(default_path: str, model_urls: dict) -> dict:
"""Function to ensure models are present locally.

Convenient function to ensure model download before usage
Convenience function that ensures the models are downloaded before use

Args:
default path (str): default path for DECIMER data
model_url (str): trained model url for downloading
"""
default_path (str): Default path for model data
model_urls (dict): Dictionary containing model names as keys and their corresponding URLs as values

model_path = os.path.join(default_path.as_posix(), "DECIMER_model")
print(model_path)

if (
os.path.exists(model_path)
and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)
Returns:
dict: A dictionary containing model names as keys and their local paths as values
"""
model_paths = {}

for model_name, model_url in model_urls.items():
model_path = os.path.join(default_path, f"{model_name}_model")
if (
os.path.exists(model_path)
and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)

# Store the model path
model_paths[model_name] = model_path

return model_paths
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
) and platform.system() == "Darwin":
tensorflow_os = "tensorflow-macos>=2.10.0"
else:
tensorflow_os = "tensorflow==2.12.0"
tensorflow_os = "tensorflow>=2.12.0"

with open("README.md", "r") as fh:
long_description = fh.read()
Expand Down
File renamed without changes
17 changes: 12 additions & 5 deletions Tests/test_functions.py → tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
import pytest

from DECIMER import predict_SMILES
from DECIMER import predict_SMILES_with_confidence


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_imagetosmiles():
img_path = "Tests/caffeine.png"
img_path = "tests/caffeine.png"
expected_result = "CN1C=NC2=C1C(=O)N(C)C(=O)N2C"
actual_result = predict_SMILES(img_path)
assert expected_result == actual_result


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_imagetosmilesWithConfidence():
img_path = "Tests/caffeine.png"
actual_result = predict_SMILES_with_confidence(img_path)
img_path = "tests/caffeine.png"
actual_result = predict_SMILES(img_path, confidence=True)

for element, confidence in actual_result:
for element, confidence in actual_result[1]:
assert (
confidence >= 0.9
), f"Confidence for element '{element}' is below 0.9 (confidence: {confidence})"


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_imagetosmileshanddrawn():
img_path = "tests/caffeine.png"
expected_result = "CN1C=NC2=C1C(=O)N(C)C(=O)N2C"
actual_result = predict_SMILES(img_path, hand_drawn=True)
assert expected_result == actual_result
Loading