Skip to content

Commit

Permalink
fixup: enable importing estimator.py even when chemprop is not available
Browse files (browse the repository at this point in the history)
  • Loading branch information
Colin Grambow committed Aug 19, 2019
1 parent c29696e commit 1470290
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions rmgpy/ml/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from argparse import Namespace
from typing import Callable, Union

# chemprop is an optional dependency: keep this module importable without it
# and defer the failure until chemprop is actually needed.
try:
    import chemprop
except ImportError as import_error:
    chemprop = None
    # Python 3 unbinds the `as` target when the except clause ends, so the
    # exception has to be copied to a module-level name to be re-raised later.
    chemprop_exception = import_error
else:
    # Keep the name defined on the success path as well.
    chemprop_exception = None
import numpy as np
from chemprop.data import MoleculeDatapoint, MoleculeDataset
from chemprop.parsing import update_checkpoint_args
from chemprop.train import predict
from chemprop.utils import load_args, load_checkpoint, load_scalers

from rmgpy.molecule import Molecule
from rmgpy.species import Species
Expand Down Expand Up @@ -112,16 +112,21 @@ def load_estimator(model_dir: str) -> Callable[[str], np.ndarray]:
"""
Load chemprop model and return function for evaluating it.
"""
if chemprop is None:
# Delay chemprop ImportError until we actually try to use it
# so that RMG can load successfully without chemprop.
raise chemprop_exception

args = Namespace() # Simple class to hold attributes

# Set up chemprop predict arguments
args.checkpoint_dir = model_dir
args.checkpoint_path = None
update_checkpoint_args(args)
chemprop.parsing.update_checkpoint_args(args)
args.cuda = False

scaler, features_scaler = load_scalers(args.checkpoint_paths[0])
train_args = load_args(args.checkpoint_paths[0])
scaler, features_scaler = chemprop.utils.load_scalers(args.checkpoint_paths[0])
train_args = chemprop.utils.load_args(args.checkpoint_paths[0])

# Update args with training arguments
for key, value in vars(train_args).items():
Expand All @@ -131,12 +136,14 @@ def load_estimator(model_dir: str) -> Callable[[str], np.ndarray]:
# Load models in ensemble
models = []
for checkpoint_path in args.checkpoint_paths:
models.append(load_checkpoint(checkpoint_path, cuda=args.cuda))
models.append(chemprop.utils.load_checkpoint(checkpoint_path, cuda=args.cuda))

# Set up estimator
def estimator(smi: str):
# Make dataset
data = MoleculeDataset([MoleculeDatapoint(line=[smi], args=args)])
data = chemprop.data.MoleculeDataset(
[chemprop.data.MoleculeDatapoint(line=[smi], args=args)]
)

# Normalize features
if train_args.features_scaling:
Expand All @@ -148,7 +155,7 @@ def estimator(smi: str):
# Predict with each model individually and sum predictions
sum_preds = np.zeros((len(data), args.num_tasks))
for model in models:
model_preds = predict(
model_preds = chemprop.train.predict(
model=model,
data=data,
batch_size=1, # We'll only predict one molecule at a time
Expand Down

0 comments on commit 1470290

Please sign in to comment.