From 7cdffa2380f6ad48b4f2fe24f8227c96de79481f Mon Sep 17 00:00:00 2001 From: Mariya Lysenkova Wiklander Date: Fri, 22 Nov 2024 21:17:26 +0100 Subject: [PATCH] Add stats for sets of size 2 or more --- src/conformist/alpha_selector.py | 41 +++++++++++++++++++--------- src/conformist/performance_report.py | 10 +++++-- src/conformist/prediction_dataset.py | 1 - src/conformist/validation_run.py | 4 +++ src/conformist/validation_trial.py | 6 ++++ 5 files changed, 46 insertions(+), 16 deletions(-) diff --git a/src/conformist/alpha_selector.py b/src/conformist/alpha_selector.py index 2b49a1a..45a5779 100644 --- a/src/conformist/alpha_selector.py +++ b/src/conformist/alpha_selector.py @@ -8,6 +8,11 @@ class AlphaSelector(OutputDir): + FIGURE_FONTSIZE = 12 + FIGURE_WIDTH = 12 + FIGURE_HEIGHT = 8 + plt.rcParams.update({'font.size': FIGURE_FONTSIZE}) + def __init__(self, prediction_dataset: PredictionDataset, cop_class, @@ -32,6 +37,7 @@ def __init__(self, self.pcts_empty_sets = [] self.pcts_singleton_sets = [] self.pcts_singleton_or_duo_sets = [] + self.pcts_duo_plus_sets = [] self.pcts_trio_plus_sets = [] self.mean_false_negative_rates = [] self.mean_softmax_threshold = [] @@ -54,6 +60,7 @@ def run(self): self.pcts_singleton_sets.append(trial.pct_singleton_sets()) self.pcts_singleton_or_duo_sets.append( trial.pct_singleton_or_duo_sets()) + self.pcts_duo_plus_sets.append(trial.pct_duo_plus_sets()) self.pcts_trio_plus_sets.append(trial.pct_trio_plus_sets()) self.mean_false_negative_rates.append( trial.mean_false_negative_rate()) @@ -69,7 +76,9 @@ def run_reports(self): def visualize(self): # MEAN SET SIZES GRAPH - plt.figure() + plt.figure(figsize=(self.FIGURE_WIDTH, + self.FIGURE_HEIGHT)) + plt.tight_layout() data = pd.DataFrame({ 'Alpha': self.alphas, @@ -80,7 +89,10 @@ def visualize(self): plt.savefig(f'{self.output_dir}/alpha_to_mean_set_size.png') # PERCENT EMPTY/SINGLETON SETS GRAPH - plt.figure() + # MEAN SET SIZES GRAPH + plt.figure(figsize=(self.FIGURE_WIDTH, + self.FIGURE_HEIGHT)) + plt.tight_layout() # Labels x_label = 'Alpha' @@ -90,9 +102,9 @@ def visualize(self): # Create a DataFrame for the pct_empty_sets and pct_singleton_sets data = pd.DataFrame({ x_label: self.alphas, - 'n = 0': self.pcts_empty_sets, - 'n ∈ {1, 2}': self.pcts_singleton_or_duo_sets, - 'n ≥ 3': self.pcts_trio_plus_sets + 'empty (n = 0)': self.pcts_empty_sets, + 'certain (n=1)': self.pcts_singleton_or_duo_sets, + 'uncertain (n ≥ 2)': self.pcts_duo_plus_sets }) # Melt the DataFrame to have the set types as a separate column @@ -110,14 +122,16 @@ def visualize(self): # Get the current x-tick labels labels = [item.get_text() for item in plt.gca().get_xticklabels()] + target = 'certain (n=1)' + # Draw a horizontal line across the top of the highest orange bar - max_singleton_or_duo_sets = data['n ∈ {1, 2}'].max() - plt.axhline(y=max_singleton_or_duo_sets, + optimal_value = data[target].max() + plt.axhline(y=optimal_value, color='#cccccc', linestyle='--') # Get the index of the label with the highest value - idx = data['n ∈ {1, 2}'].idxmax() + idx = data[target].idxmax() # Make this label bold labels[idx] = f'$\\bf{{{labels[idx]}}}$' @@ -131,7 +145,9 @@ def visualize(self): plt.savefig(f'{self.output_dir}/alpha_to_set_sizes.png') def visualize_lambdas(self): - plt.figure() + plt.figure(figsize=(self.FIGURE_WIDTH, + self.FIGURE_HEIGHT)) + plt.tight_layout() # Only use reasonable alphas alphas = [0.05, 0.1, 0.15, 0.2, 0.3, 0.4] @@ -156,7 +172,7 @@ def visualize_lambdas(self): plt.text(self.lamhats[a], 0 + padding, f'{self.lamhats[a]:.2f}', ha='center', va='bottom', - fontsize=8, color='black', + color='black', weight='bold') i += 1 @@ -179,9 +195,8 @@ def save_summary(self): 'alpha': self.alphas, 'Mean set size': self.mean_set_sizes, '% sets n=0': self.pcts_empty_sets, - '% sets n={1}': self.pcts_singleton_sets, - '% sets n={1|2}': self.pcts_singleton_or_duo_sets, - '% sets n>=3': self.pcts_trio_plus_sets, + '% sets n=1': self.pcts_singleton_sets, + '% sets n>=2': self.pcts_duo_plus_sets, 'Mean FNR': self.mean_false_negative_rates, 'Mean softmax threshold': self.mean_softmax_threshold } diff --git a/src/conformist/performance_report.py b/src/conformist/performance_report.py index e751211..18ebc87 100644 --- a/src/conformist/performance_report.py +++ b/src/conformist/performance_report.py @@ -28,11 +28,17 @@ def pct_singleton_or_duo_sets(prediction_sets): prediction_set in prediction_sets) / \ len(prediction_sets) - def pct_trio_plus_sets(prediction_sets): - return sum(sum(prediction_set) >= 3 for + def _pct_sets_of_min_size(prediction_sets, min_size): + return sum(sum(prediction_set) >= min_size for prediction_set in prediction_sets) / \ len(prediction_sets) + def pct_duo_plus_sets(prediction_sets): + return PerformanceReport._pct_sets_of_min_size(prediction_sets, 2) + + def pct_trio_plus_sets(prediction_sets): + return PerformanceReport._pct_sets_of_min_size(prediction_sets, 3) + def _class_report(self, items_by_class, output_file_prefix, diff --git a/src/conformist/prediction_dataset.py b/src/conformist/prediction_dataset.py index 791c1e8..dedc8ac 100644 --- a/src/conformist/prediction_dataset.py +++ b/src/conformist/prediction_dataset.py @@ -21,7 +21,6 @@ class PredictionDataset(OutputDir): FIGURE_WIDTH = 12 plt.rcParams.update({'font.size': FIGURE_FONTSIZE}) - def __init__(self, df=None, predictions_csv=None, diff --git a/src/conformist/validation_run.py b/src/conformist/validation_run.py index 8bb2d38..60e9664 100644 --- a/src/conformist/validation_run.py +++ b/src/conformist/validation_run.py @@ -53,6 +53,9 @@ def pct_singleton_sets(self): def pct_singleton_or_duo_sets(self): return PerformanceReport.pct_singleton_or_duo_sets(self.prediction_sets) + def pct_duo_plus_sets(self): + return PerformanceReport.pct_duo_plus_sets(self.prediction_sets) + def pct_trio_plus_sets(self): return PerformanceReport.pct_trio_plus_sets(self.prediction_sets) @@ -130,6 +133,7 @@ def run_reports(self, base_output_dir): 'pct_empty_sets': self.pct_empty_sets(), 'pct_singleton_sets': self.pct_singleton_sets(), 'pct_singleton_or_duo_sets': self.pct_singleton_or_duo_sets(), + 'pct_duo_plus_sets': self.pct_duo_plus_sets(), 'pct_trio_plus_sets': self.pct_trio_plus_sets() }, index=[0]) diff --git a/src/conformist/validation_trial.py b/src/conformist/validation_trial.py index 11f6538..0e4531e 100644 --- a/src/conformist/validation_trial.py +++ b/src/conformist/validation_trial.py @@ -36,6 +36,12 @@ def pct_singleton_or_duo_sets(self): singleton_or_duo.append(run.pct_singleton_or_duo_sets()) return statistics.mean(singleton_or_duo) + def pct_duo_plus_sets(self): + duo_plus = [] + for run in self.runs: + duo_plus.append(run.pct_duo_plus_sets()) + return statistics.mean(duo_plus) + def pct_trio_plus_sets(self): trio_plus = [] for run in self.runs: