diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index e720e029..2ea80a10 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -143,8 +143,7 @@ def _predict( test_data = build_dataloader(X, num_workers=self.n_jobs, shuffle=False) predictions = self.lightning_trainer.predict(self.model, test_data) prediction_array = np.vstack(predictions) # type: ignore - prediction_array = prediction_array.squeeze() - + prediction_array = np.atleast_1d(prediction_array.squeeze()) # Check if the predictions have the same length as the input dataset if prediction_array.shape[0] != len(X): raise AssertionError(