Skip to content

Commit

Permalink
[FIX] pipeline ensemble for multi-task classification
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapels committed Jan 15, 2025
1 parent 8bd0d43 commit 4c0c245
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 3 deletions.
2 changes: 0 additions & 2 deletions src/deepmol/pipeline/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,6 @@ def predict_proba(self, dataset: Dataset, return_invalid: bool = False) -> np.nd
predictions_i = [prediction[:, i] for prediction in predictions]
final_predictions[:, i] = np.average(predictions_i, axis=0, weights=self.weights)

final_predictions = np.average(final_predictions, axis=0, weights=self.weights)

final_predictions = self._deal_with_nan(return_invalid, final_predictions)

return final_predictions
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/dataset/test_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def test_multilabel(self):

predictions = pipeline.predict_proba(self.multilabel_classification, return_invalid=True)

self.assertEqual(predictions.shape[0], 99)
self.assertEqual(predictions.shape[0], 100)



Expand Down

0 comments on commit 4c0c245

Please sign in to comment.