From 7ed7cc4f3febca688a93e2d57f99303543335374 Mon Sep 17 00:00:00 2001 From: valmik-patel <133670152+valmik-patel@users.noreply.github.com> Date: Wed, 7 Feb 2024 23:25:07 +0530 Subject: [PATCH] Disparity error fix (#136) --- src/aequitas/bias.py | 4 ---- src/aequitas/plotting.py | 5 +++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/aequitas/bias.py b/src/aequitas/bias.py index d1c26511..b612715a 100644 --- a/src/aequitas/bias.py +++ b/src/aequitas/bias.py @@ -266,10 +266,6 @@ def get_disparity_major_group(self, df, original_df, key_columns=None, # always includes label and score significance selected_significance = selected_significance.union({'label_value', 'score'}) - ref_groups_dict = assemble_ref_groups(df, ref_group_flag='_ref_group_value', - specific_measures=selected_significance, - label_score_ref=None) - ref_groups_dict = assemble_ref_groups(df, ref_group_flag='_ref_group_value', specific_measures=selected_significance, label_score_ref=None) diff --git a/src/aequitas/plotting.py b/src/aequitas/plotting.py index 529f9c85..bd26033a 100644 --- a/src/aequitas/plotting.py +++ b/src/aequitas/plotting.py @@ -44,8 +44,9 @@ def assemble_ref_groups(disparities_table, ref_group_flag='_ref_group_value', if len(specific_measures) < 1: raise ValueError("At least one metric must be passed for which to " "find refrence group.") - - specific_measures = specific_measures.union({label_score_ref}) + if label_score_ref: + specific_measures = specific_measures.union({label_score_ref}) + ref_group_cols = {measure + ref_group_flag for measure in specific_measures if measure + ref_group_flag in ref_group_cols}