diff --git a/rmgpy/ml/estimator.py b/rmgpy/ml/estimator.py index 547cbe25fc1..8e120dd8e8a 100644 --- a/rmgpy/ml/estimator.py +++ b/rmgpy/ml/estimator.py @@ -30,11 +30,11 @@ from argparse import Namespace from typing import Callable, Union +try: + import chemprop +except ImportError as chemprop_exception: + chemprop = None import numpy as np -from chemprop.data import MoleculeDatapoint, MoleculeDataset -from chemprop.parsing import update_checkpoint_args -from chemprop.train import predict -from chemprop.utils import load_args, load_checkpoint, load_scalers from rmgpy.molecule import Molecule from rmgpy.species import Species @@ -112,16 +112,21 @@ def load_estimator(model_dir: str) -> Callable[[str], np.ndarray]: """ Load chemprop model and return function for evaluating it. """ + if chemprop is None: + # Delay chemprop ImportError until we actually try to use it + # so that RMG can load successfully without chemprop. + raise chemprop_exception + args = Namespace() # Simple class to hold attributes # Set up chemprop predict arguments args.checkpoint_dir = model_dir args.checkpoint_path = None - update_checkpoint_args(args) + chemprop.parsing.update_checkpoint_args(args) args.cuda = False - scaler, features_scaler = load_scalers(args.checkpoint_paths[0]) - train_args = load_args(args.checkpoint_paths[0]) + scaler, features_scaler = chemprop.utils.load_scalers(args.checkpoint_paths[0]) + train_args = chemprop.utils.load_args(args.checkpoint_paths[0]) # Update args with training arguments for key, value in vars(train_args).items(): @@ -131,12 +136,14 @@ def load_estimator(model_dir: str) -> Callable[[str], np.ndarray]: # Load models in ensemble models = [] for checkpoint_path in args.checkpoint_paths: - models.append(load_checkpoint(checkpoint_path, cuda=args.cuda)) + models.append(chemprop.utils.load_checkpoint(checkpoint_path, cuda=args.cuda)) # Set up estimator def estimator(smi: str): # Make dataset - data = MoleculeDataset([MoleculeDatapoint(line=[smi], args=args)]) + data = chemprop.data.MoleculeDataset( + [chemprop.data.MoleculeDatapoint(line=[smi], args=args)] + ) # Normalize features if train_args.features_scaling: @@ -148,7 +155,7 @@ def estimator(smi: str): # Predict with each model individually and sum predictions sum_preds = np.zeros((len(data), args.num_tasks)) for model in models: - model_preds = predict( + model_preds = chemprop.train.predict( model=model, data=data, batch_size=1, # We'll only predict one molecule at a time