Skip to content

Commit

Permalink
Merge pull request #90 from Kohulan/development
Browse files Browse the repository at this point in the history
feat: improved DECIMER hand-drawn model
  • Loading branch information
Kohulan authored Mar 8, 2024
2 parents 8d3bcaf + c145044 commit b20a8ab
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 75 deletions.
6 changes: 3 additions & 3 deletions DECIMER/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
"""DECIMER V2.4.0 Python Package. ============================
"""DECIMER V2.6.0 Python Package. ============================
This repository contains DECIMER-V2,Deep lEarning for Chemical ImagE Recognition) project
was launched to address the OCSR problem with the latest computational intelligence methods
Expand All @@ -19,11 +19,11 @@
please raise a issue on the Github repository.
"""

__version__ = "2.4.0"
__version__ = "2.6.0"

__all__ = [
"DECIMER",
]


from .decimer import predict_SMILES, predict_SMILES_with_confidence
from .decimer import predict_SMILES
107 changes: 61 additions & 46 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,44 @@
# Set path
default_path = pystow.join("DECIMER-V2")

# download models to a default location
utils.ensure_model(default_path=default_path)

# Load important pickle files which consists the tokenizers and the maxlength setting
tokenizer = pickle.load(
open(
os.path.join(
default_path.as_posix(), "DECIMER_model", "assets", "tokenizer_SMILES.pkl"
),
"rb",
)
)
model_urls = {
"DECIMER": "https://zenodo.org/record/8300489/files/models.zip",
"DECIMER_HandDrawn": "https://zenodo.org/records/10781330/files/DECIMER_HandDrawn_model.zip",
}


def main():
"""This function take the path of the image as user input and returns the
predicted SMILES as output in CLI.
def get_models(model_urls: dict):
"""Download and load models from the provided URLs.
Agrs:
str: image_path
This function downloads models from the provided URLs to a default location,
then loads tokenizers and TensorFlow saved models.
Args:
model_urls (dict): A dictionary containing model names as keys and their corresponding URLs as values.
Returns:
str: predicted SMILES
tuple: A tuple containing loaded tokenizer and TensorFlow saved models.
- tokenizer (object): Tokenizer for DECIMER model.
- DECIMER_V2 (tf.saved_model): TensorFlow saved model for DECIMER.
- DECIMER_Hand_drawn (tf.saved_model): TensorFlow saved model for DECIMER HandDrawn.
"""
if len(sys.argv) != 2:
print("Usage: {} $image_path".format(sys.argv[0]))
else:
SMILES = predict_SMILES(sys.argv[1])
print(SMILES)
# Download models to a default location
model_paths = utils.ensure_models(default_path=default_path, model_urls=model_urls)

# Load tokenizers
tokenizer_path = os.path.join(
model_paths["DECIMER"], "assets", "tokenizer_SMILES.pkl"
)
tokenizer = pickle.load(open(tokenizer_path, "rb"))

# Load DECIMER models
DECIMER_V2 = tf.saved_model.load(model_paths["DECIMER"])
DECIMER_Hand_drawn = tf.saved_model.load(model_paths["DECIMER_HandDrawn"])

return tokenizer, DECIMER_V2, DECIMER_Hand_drawn


tokenizer, DECIMER_V2, DECIMER_Hand_drawn = get_models(model_urls)


def detokenize_output(predicted_array: int) -> str:
Expand Down Expand Up @@ -112,44 +121,50 @@ def detokenize_output_add_confidence(
return decoded_prediction_with_confidence


# Load DECIMER model_packed
DECIMER_V2 = tf.saved_model.load(default_path.as_posix() + "/DECIMER_model/")


def predict_SMILES(image_path: str) -> str:
"""This function takes an image path (str) and returns the SMILES
representation of the depicted molecule (str).
def predict_SMILES(
image_path: str, confidence: bool = False, hand_drawn: bool = False
) -> str:
"""Predicts SMILES representation of a molecule depicted in the given image.
Args:
image_path (str): Path of chemical structure depiction image
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn
Returns:
(str): SMILES representation of the molecule in the input image
str: SMILES representation of the molecule in the input image, optionally with confidence values
"""
chemical_structure = config.decode_image(image_path)
predicted_tokens, _ = DECIMER_V2(chemical_structure)

model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2
predicted_tokens, confidence_values = model(tf.constant(chemical_structure))

predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens))

if confidence:
predicted_SMILES_with_confidence = detokenize_output_add_confidence(
predicted_tokens, confidence_values
)
return predicted_SMILES, predicted_SMILES_with_confidence

return predicted_SMILES


def predict_SMILES_with_confidence(image_path: str) -> List[Tuple[str, float]]:
"""This function takes an image path (str) and returns a list of tuples
that contain each token of the predicted SMILES string and the confidence
level from the last layer of the Transformer decoder.
def main():
"""This function take the path of the image as user input and returns the
predicted SMILES as output in CLI.
Args:
image_path (str): Path of chemical structure depiction image
Agrs:
str: image_path
Returns:
(List[Tuple[str, float]]): Tuples that contain the tokens and the confidence
values of the predicted SMILES
str: predicted SMILES
"""
decodedImage = config.decode_image(image_path)
predicted_tokens, confidence_values = DECIMER_V2(tf.constant(decodedImage))
predicted_SMILES_with_confidence = detokenize_output_add_confidence(
predicted_tokens, confidence_values
)
return predicted_SMILES_with_confidence
if len(sys.argv) != 2:
print("Usage: {} $image_path".format(sys.argv[0]))
else:
SMILES = predict_SMILES(sys.argv[1])
print(SMILES)


if __name__ == "__main__":
Expand Down
45 changes: 25 additions & 20 deletions DECIMER/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,32 @@ def decoder(predictions):
return modified


def ensure_model(
default_path: str,
model_url: str = "https://zenodo.org/record/8300489/files/models.zip",
):
"""Function to ensure model is present locally.
def ensure_models(default_path: str, model_urls: dict) -> dict:
"""Function to ensure models are present locally.
Convenient function to ensure model download before usage
Convenient function to ensure model downloads before usage
Args:
default path (str): default path for DECIMER data
model_url (str): trained model url for downloading
"""
default_path (str): Default path for model data
model_urls (dict): Dictionary containing model names as keys and their corresponding URLs as values
model_path = os.path.join(default_path.as_posix(), "DECIMER_model")
print(model_path)

if (
os.path.exists(model_path)
and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)
Returns:
dict: A dictionary containing model names as keys and their local paths as values
"""
model_paths = {}

for model_name, model_url in model_urls.items():
model_path = os.path.join(default_path, f"{model_name}_model")
if (
os.path.exists(model_path)
and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)

# Store the model path
model_paths[model_name] = model_path

return model_paths
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
) and platform.system() == "Darwin":
tensorflow_os = "tensorflow-macos>=2.10.0"
else:
tensorflow_os = "tensorflow==2.12.0"
tensorflow_os = "tensorflow>=2.12.0"

with open("README.md", "r") as fh:
long_description = fh.read()
Expand Down
File renamed without changes
17 changes: 12 additions & 5 deletions Tests/test_functions.py → tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
import pytest

from DECIMER import predict_SMILES
from DECIMER import predict_SMILES_with_confidence


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_imagetosmiles():
img_path = "Tests/caffeine.png"
img_path = "tests/caffeine.png"
expected_result = "CN1C=NC2=C1C(=O)N(C)C(=O)N2C"
actual_result = predict_SMILES(img_path)
assert expected_result == actual_result


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_imagetosmilesWithConfidence():
img_path = "Tests/caffeine.png"
actual_result = predict_SMILES_with_confidence(img_path)
img_path = "tests/caffeine.png"
actual_result = predict_SMILES(img_path, confidence=True)

for element, confidence in actual_result:
for element, confidence in actual_result[1]:
assert (
confidence >= 0.9
), f"Confidence for element '{element}' is below 0.9 (confidence: {confidence})"


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_imagetosmileshanddrawn():
img_path = "tests/caffeine.png"
expected_result = "CN1C=NC2=C1C(=O)N(C)C(=O)N2C"
actual_result = predict_SMILES(img_path, hand_drawn=True)
assert expected_result == actual_result

0 comments on commit b20a8ab

Please sign in to comment.