diff --git a/bluecast/evaluation/eval_metrics.py b/bluecast/evaluation/eval_metrics.py index cac9bb3f..37ef83d4 100644 --- a/bluecast/evaluation/eval_metrics.py +++ b/bluecast/evaluation/eval_metrics.py @@ -120,9 +120,11 @@ def plot_probability_distribution( # Ensure probs is a 2D array if probs.ndim == 1: - probs = np.column_stack( - (probs, 1 - probs) - ) # Create a 2D array with (probs, 1 - probs) + # Convert 1D binary probabilities to 2D + probs = np.column_stack((probs, 1 - probs)) + elif probs.ndim == 2 and probs.shape[1] == 1: + # Handle the case where probs is (n_samples, 1) by converting it to (n_samples, 2) + probs = np.column_stack((probs[:, 0], 1 - probs[:, 0])) unique_classes = np.unique(y_classes) colors = plt.get_cmap("tab10") # Get a colormap