Skip to content

Commit

Permalink
Add stats for sets of size 2 or more
Browse files Browse the repository at this point in the history
  • Loading branch information
mariya committed Nov 22, 2024
1 parent f6bdd88 commit 7cdffa2
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 16 deletions.
41 changes: 28 additions & 13 deletions src/conformist/alpha_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@


class AlphaSelector(OutputDir):
FIGURE_FONTSIZE = 12
FIGURE_WIDTH = 12
FIGURE_HEIGHT = 8
plt.rcParams.update({'font.size': FIGURE_FONTSIZE})

def __init__(self,
prediction_dataset: PredictionDataset,
cop_class,
Expand All @@ -32,6 +37,7 @@ def __init__(self,
self.pcts_empty_sets = []
self.pcts_singleton_sets = []
self.pcts_singleton_or_duo_sets = []
self.pcts_duo_plus_sets = []
self.pcts_trio_plus_sets = []
self.mean_false_negative_rates = []
self.mean_softmax_threshold = []
Expand All @@ -54,6 +60,7 @@ def run(self):
self.pcts_singleton_sets.append(trial.pct_singleton_sets())
self.pcts_singleton_or_duo_sets.append(
trial.pct_singleton_or_duo_sets())
self.pcts_duo_plus_sets.append(trial.pct_duo_plus_sets())
self.pcts_trio_plus_sets.append(trial.pct_trio_plus_sets())
self.mean_false_negative_rates.append(
trial.mean_false_negative_rate())
Expand All @@ -69,7 +76,9 @@ def run_reports(self):

def visualize(self):
# MEAN SET SIZES GRAPH
plt.figure()
plt.figure(figsize=(self.FIGURE_WIDTH,
self.FIGURE_HEIGHT))
plt.tight_layout()

data = pd.DataFrame({
'Alpha': self.alphas,
Expand All @@ -80,7 +89,10 @@ def visualize(self):
plt.savefig(f'{self.output_dir}/alpha_to_mean_set_size.png')

# PERCENT EMPTY/SINGLETON SETS GRAPH
plt.figure()
# Use the same figure dimensions as the mean set sizes graph
plt.figure(figsize=(self.FIGURE_WIDTH,
self.FIGURE_HEIGHT))
plt.tight_layout()

# Labels
x_label = 'Alpha'
Expand All @@ -90,9 +102,9 @@ def visualize(self):
# Create a DataFrame for the pct_empty_sets and pct_singleton_sets
data = pd.DataFrame({
x_label: self.alphas,
'n = 0': self.pcts_empty_sets,
'n ∈ {1, 2}': self.pcts_singleton_or_duo_sets,
'n ≥ 3': self.pcts_trio_plus_sets
'empty (n = 0)': self.pcts_empty_sets,
'certain (n=1)': self.pcts_singleton_or_duo_sets,
'uncertain (n ≥ 2)': self.pcts_duo_plus_sets
})

# Melt the DataFrame to have the set types as a separate column
Expand All @@ -110,14 +122,16 @@ def visualize(self):
# Get the current x-tick labels
labels = [item.get_text() for item in plt.gca().get_xticklabels()]

target = 'certain (n=1)'

# Draw a horizontal line across the top of the highest orange bar
max_singleton_or_duo_sets = data['n ∈ {1, 2}'].max()
plt.axhline(y=max_singleton_or_duo_sets,
optimal_value = data[target].max()
plt.axhline(y=optimal_value,
color='#cccccc',
linestyle='--')

# Get the index of the label with the highest value
idx = data['n ∈ {1, 2}'].idxmax()
idx = data[target].idxmax()

# Make this label bold
labels[idx] = f'$\\bf{{{labels[idx]}}}$'
Expand All @@ -131,7 +145,9 @@ def visualize(self):
plt.savefig(f'{self.output_dir}/alpha_to_set_sizes.png')

def visualize_lambdas(self):
plt.figure()
plt.figure(figsize=(self.FIGURE_WIDTH,
self.FIGURE_HEIGHT))
plt.tight_layout()

# Only use reasonable alphas
alphas = [0.05, 0.1, 0.15, 0.2, 0.3, 0.4]
Expand All @@ -156,7 +172,7 @@ def visualize_lambdas(self):
plt.text(self.lamhats[a], 0 + padding,
f'{self.lamhats[a]:.2f}',
ha='center', va='bottom',
fontsize=8, color='black',
color='black',
weight='bold')
i += 1

Expand All @@ -179,9 +195,8 @@ def save_summary(self):
'alpha': self.alphas,
'Mean set size': self.mean_set_sizes,
'% sets n=0': self.pcts_empty_sets,
'% sets n={1}': self.pcts_singleton_sets,
'% sets n={1|2}': self.pcts_singleton_or_duo_sets,
'% sets n>=3': self.pcts_trio_plus_sets,
'% sets n=1': self.pcts_singleton_sets,
'% sets n>=2': self.pcts_duo_plus_sets,
'Mean FNR': self.mean_false_negative_rates,
'Mean softmax threshold': self.mean_softmax_threshold
}
Expand Down
10 changes: 8 additions & 2 deletions src/conformist/performance_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,17 @@ def pct_singleton_or_duo_sets(prediction_sets):
prediction_set in prediction_sets) / \
len(prediction_sets)

def pct_trio_plus_sets(prediction_sets):
return sum(sum(prediction_set) >= 3 for
def _pct_sets_of_min_size(prediction_sets, min_size):
    """Return the share of prediction sets whose size is at least ``min_size``.

    The size of a prediction set is taken to be ``sum(prediction_set)``.
    The result is the number of qualifying sets divided by
    ``len(prediction_sets)``, i.e. a fraction rather than a value scaled
    to 100.

    NOTE(review): presumably each prediction set is a sequence of 0/1 (or
    boolean) membership flags, one per class -- confirm against callers.
    """
    # One flag per prediction set: does it have at least `min_size` members?
    meets_minimum = [sum(members) >= min_size for members in prediction_sets]
    return sum(meets_minimum) / len(prediction_sets)

def pct_duo_plus_sets(prediction_sets):
    """Return the share of prediction sets that have two or more members.

    Delegates to ``PerformanceReport._pct_sets_of_min_size`` with a
    minimum size of 2.
    """
    minimum_members = 2
    return PerformanceReport._pct_sets_of_min_size(prediction_sets,
                                                   minimum_members)

def pct_trio_plus_sets(prediction_sets):
    """Return the share of prediction sets that have three or more members.

    Delegates to ``PerformanceReport._pct_sets_of_min_size`` with a
    minimum size of 3.
    """
    minimum_members = 3
    return PerformanceReport._pct_sets_of_min_size(prediction_sets,
                                                   minimum_members)

def _class_report(self,
items_by_class,
output_file_prefix,
Expand Down
1 change: 0 additions & 1 deletion src/conformist/prediction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class PredictionDataset(OutputDir):
FIGURE_WIDTH = 12
plt.rcParams.update({'font.size': FIGURE_FONTSIZE})


def __init__(self,
df=None,
predictions_csv=None,
Expand Down
4 changes: 4 additions & 0 deletions src/conformist/validation_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def pct_singleton_sets(self):
def pct_singleton_or_duo_sets(self):
    """Return the singleton-or-duo statistic for this run's prediction sets.

    Passes ``self.prediction_sets`` to
    ``PerformanceReport.pct_singleton_or_duo_sets`` and returns its result.
    """
    sets_for_run = self.prediction_sets
    return PerformanceReport.pct_singleton_or_duo_sets(sets_for_run)

def pct_duo_plus_sets(self):
    """Return the duo-plus statistic for this run's prediction sets.

    Passes ``self.prediction_sets`` to
    ``PerformanceReport.pct_duo_plus_sets`` and returns its result.
    """
    sets_for_run = self.prediction_sets
    return PerformanceReport.pct_duo_plus_sets(sets_for_run)

def pct_trio_plus_sets(self):
    """Return the trio-plus statistic for this run's prediction sets.

    Passes ``self.prediction_sets`` to
    ``PerformanceReport.pct_trio_plus_sets`` and returns its result.
    """
    sets_for_run = self.prediction_sets
    return PerformanceReport.pct_trio_plus_sets(sets_for_run)

Expand Down Expand Up @@ -130,6 +133,7 @@ def run_reports(self, base_output_dir):
'pct_empty_sets': self.pct_empty_sets(),
'pct_singleton_sets': self.pct_singleton_sets(),
'pct_singleton_or_duo_sets': self.pct_singleton_or_duo_sets(),
'pct_duo_plus_sets': self.pct_duo_plus_sets(),
'pct_trio_plus_sets': self.pct_trio_plus_sets()
}, index=[0])

Expand Down
6 changes: 6 additions & 0 deletions src/conformist/validation_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def pct_singleton_or_duo_sets(self):
singleton_or_duo.append(run.pct_singleton_or_duo_sets())
return statistics.mean(singleton_or_duo)

def pct_duo_plus_sets(self):
    """Return the mean of ``pct_duo_plus_sets()`` across all runs in this trial.

    Each element of ``self.runs`` is asked for its own duo-plus value and
    the arithmetic mean of those values is returned.
    """
    per_run_values = [run.pct_duo_plus_sets() for run in self.runs]
    return statistics.mean(per_run_values)

def pct_trio_plus_sets(self):
trio_plus = []
for run in self.runs:
Expand Down

0 comments on commit 7cdffa2

Please sign in to comment.