From 0dc693620d6ddbeb44c4b4daa62cdcf49bbf29eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?In=C3=AAs=20Silva?= Date: Wed, 31 Jan 2024 16:11:01 +0000 Subject: [PATCH 1/7] Changed metric names from acronym to actual name --- src/aequitas/flow/plots/pareto/plot.py | 2 +- src/aequitas/group.py | 12 +-- .../plot/bubble_concatenation_chart.py | 37 ++++++-- src/aequitas/plot/bubble_disparity_chart.py | 79 +++++++++++----- src/aequitas/plot/bubble_metric_chart.py | 71 +++++++++++---- src/aequitas/plot/summary_chart.py | 90 ++++++++++++------- 6 files changed, 205 insertions(+), 86 deletions(-) diff --git a/src/aequitas/flow/plots/pareto/plot.py b/src/aequitas/flow/plots/pareto/plot.py index 4d9370ed..40c5bbe0 100644 --- a/src/aequitas/flow/plots/pareto/plot.py +++ b/src/aequitas/flow/plots/pareto/plot.py @@ -227,7 +227,7 @@ def bias_audit( model_id: int, dataset: Any, sensitive_attribute: Union[str, list[str]], - metrics: list[str] = ["tpr", "fpr"], + metrics: list[str] = ["TPR", "FPR"], fairness_threshold: float = 1.2, results_path: Union[Path, str] = "examples/experiment_results", reference_groups: Optional[dict[str, str]] = None, diff --git a/src/aequitas/group.py b/src/aequitas/group.py index e9ca7b55..1fca823f 100644 --- a/src/aequitas/group.py +++ b/src/aequitas/group.py @@ -464,7 +464,7 @@ def calculate_disparities( def plot_summary( disparities: pd.DataFrame, - metrics: list[str] = ["fpr", "tpr"], + metrics: list[str] = ["FPR", "TPR"], fairness_threshold: float = 1.25, ): """ @@ -475,7 +475,7 @@ def plot_summary( disparities : pandas.DataFrame Disparities for each group. metrics : list[str], optional - List of metrics to plot. Defaults to ["fpr", "tpr"]. + List of metrics to plot. Defaults to ["FPR", "TPR"]. fairness_threshold : float, optional Threshold to use to determine fairness. Defaults to 1.2. """ @@ -484,7 +484,7 @@ def plot_summary( def plot_disparity( disparities: pd.DataFrame, attribute: str, - metrics: list[str] = ["fpr", "tpr"], + metrics: list[str] = ["FPR", "TPR"], fairness_threshold: float = 1.25, ): """ @@ -495,7 +495,7 @@ def plot_disparity( disparities : pandas.DataFrame Disparities for each group. metrics : list[str], optional - List of metrics to plot. Defaults to ["fpr", "tpr"]. + List of metrics to plot. Defaults to ["FPR", "TPR"]. fairness_threshold : float, optional Threshold to use to determine fairness. Defaults to 1.2. """ @@ -506,7 +506,7 @@ def plot_disparity( def plot_absolute( disparities: pd.DataFrame, attribute: str, - metrics: list[str] = ["fpr", "tpr"], + metrics: list[str] = ["FPR", "TPR"], fairness_threshold: float = 1.25, ): """ @@ -517,7 +517,7 @@ def plot_absolute( disparities : pandas.DataFrame Disparities for each group. metrics : list[str], optional - List of metrics to plot. Defaults to ["fpr", "tpr"]. + List of metrics to plot. Defaults to ["FPR", "TPR"]. fairness_threshold : float, optional Threshold to use to determine fairness. Defaults to 1.2. """ diff --git a/src/aequitas/plot/bubble_concatenation_chart.py b/src/aequitas/plot/bubble_concatenation_chart.py index b8dd0337..118da4c1 100644 --- a/src/aequitas/plot/bubble_concatenation_chart.py +++ b/src/aequitas/plot/bubble_concatenation_chart.py @@ -12,14 +12,28 @@ from aequitas.plot.commons.style.sizes import Concat_Chart from aequitas.plot.commons import initializers as Initializer -# Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need it -# (like most annotations), we pass the following dummy dataframe to reduce the complexity of the resulting vega spec. +# Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need +# it (like most annotations), we pass the following dummy dataframe to reduce the +# complexity of the resulting vega spec. DUMMY_DF = pd.DataFrame({"a": [1, 1], "b": [0, 0]}) +metric_names = { + "Predictive Equality": "fpr_ratio", + "Equal Opportunity": "tpr_ratio", + "Demographic Parity": "pprev_ratio", + "TPR": "tpr", + "FPR": "fpr", + "FNR": "fnr", + "Accuracy": "accuracy", + "Precision": "precision", +} + def __get_chart_sizes(chart_width): - """Calculates the widths of the disparity and metric charts that make-up the concatenated chart. - The individual widths are calculated based on the provided desired overall chart width.""" + """Calculates the widths of the disparity and metric charts that make-up the + concatenated chart. The individual widths are calculated based on the provided + desired overall chart width. + """ chart_sizes = dict( disparity_chart_width=0.5 * chart_width, metric_chart_width=0.5 * chart_width @@ -59,7 +73,8 @@ def plot_concatenated_bubble_charts( chart_width=Concat_Chart.full_width, accessibility_mode=False, ): - """Draws a concatenation of the disparity bubble chart and the metric values bubble chart, + """Draws a concatenation of the disparity bubble chart and the metric values bubble + chart, of the selected metrics for a given attribute. :param disparity_df: a dataframe generated by the Aequitas Bias class @@ -68,18 +83,21 @@ def plot_concatenated_bubble_charts( :type metrics_list: list :param attribute: an attribute to plot :type attribute: str - :param fairness_threshold: a value for the maximum allowed disparity, defaults to 1.25 + :param fairness_threshold: a value for the maximum allowed disparity, defaults to + 1.25 :type fairness_threshold: float, optional :param chart_height: a value (in pixels) for the height of the chart :type chart_height: int, optional :param chart_width: a value (in pixels) for the width of the chart :type chart_width: int, optional - :param accessibility_mode: a switch for the display of more accessible visual elements, defaults to False + :param accessibility_mode: a switch for the display of more accessible visual + elements, defaults to False :type accessibility_mode: bool, optional :return: the full disparities chart :rtype: Altair chart object """ + metrics_list = [metric_names[metric] for metric in metrics_list] ( plot_table, @@ -170,10 +188,11 @@ def plot_concatenated_bubble_charts( offset=Chart_Title.offset, ) .properties( - title=attribute.title(), + title=attribute.title(), padding={ "top": Concat_Chart.full_chart_padding, - "bottom": -FONT_SIZE_SMALL * 0.75/3 * len(metrics_list) + Concat_Chart.full_chart_padding, + "bottom": -FONT_SIZE_SMALL * 0.75 / 3 * len(metrics_list) + + Concat_Chart.full_chart_padding, "left": Concat_Chart.full_chart_padding, "right": Concat_Chart.full_chart_padding, }, diff --git a/src/aequitas/plot/bubble_disparity_chart.py b/src/aequitas/plot/bubble_disparity_chart.py index fd35bf98..5fcae4ed 100644 --- a/src/aequitas/plot/bubble_disparity_chart.py +++ b/src/aequitas/plot/bubble_disparity_chart.py @@ -29,15 +29,28 @@ from aequitas.plot.commons import initializers as Initializer from aequitas.plot.commons import labels as Label -# Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need it -# (like most annotations), we pass the following dummy dataframe to reduce the complexity of the resulting vega spec. +# Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need +# it (like most annotations), we pass the following dummy dataframe to reduce the +# complexity of the resulting vega spec. DUMMY_DF = pd.DataFrame({"a": [1, 1], "b": [0, 0]}) +metric_names = { + "Predictive Equality": "fpr_ratio", + "Equal Opportunity": "tpr_ratio", + "Demographic Parity": "pprev_ratio", + "TPR": "tpr", + "FPR": "fpr", + "FNR": "fnr", + "Accuracy": "accuracy", + "Precision": "precision", +} + def __get_position_scales( plot_table, metrics, fairness_threshold, chart_height, chart_width ): - """Computes the scales for x and y encodings to be used in the disparity bubble chart.""" + """Computes the scales for x and y encodings to be used in the disparity bubble + chart.""" position_scales = dict() @@ -55,7 +68,8 @@ def max_column(x): max_disparities = plot_table[scaled_disparities_col_names].apply(max_column, axis=1) abs_max_disparity = abs(max_column(max_disparities)) - # If fairness_threshold is defined, get max between threshold and max absolute disparity + # If fairness_threshold is defined, get max between threshold and max absolute + # disparity if fairness_threshold is not None: x_domain_limit = math.ceil(max(abs_max_disparity, fairness_threshold)) else: @@ -86,7 +100,8 @@ def __draw_metrics_rules(metrics, scales, concat_chart): orient="left", labelAngle=Metric_Axis.label_angle, # LabelPadding logic: - # Spaces the labels further from the chart if they are part of a concatenated chart + # Spaces the labels further from the chart if they are part of a concatenated + # chart labelPadding=Metric_Axis.label_padding if not concat_chart else Metric_Axis.label_padding_concat_chart, @@ -143,10 +158,12 @@ def list_axis_values(limit, step): def __draw_x_ticks_labels(scales, chart_height): """Draws the numbers in the horizontal axis.""" - # The values to be drawn, we don't want to draw 0 (which corresponds to a ratio of 1) as we later draw an annotation. + # The values to be drawn, we don't want to draw 0 (which corresponds to a ratio of + # 1) as we later draw an annotation. axis_values = __get_x_axis_values(scales["x"].domain) - # Given the semantic of the chart, (how many times smaller or larger) we draw absolute values. + # Given the semantic of the chart, (how many times smaller or larger) we draw + # absolute values. axis_values_labels = [abs(x) + 1 if x != 0 else "=" for x in axis_values] axis_df = pd.DataFrame({"value": axis_values, "label": axis_values_labels}) @@ -236,7 +253,8 @@ def __draw_text_annotations(ref_group, chart_height, x_range): def __draw_reference_rule(ref_group, chart_height, chart_width): - """Draws vertical reference rule where the ratio is the same as the reference group.""" + """Draws vertical reference rule where the ratio is the same as the reference + group.""" reference_rule = ( alt.Chart(DUMMY_DF) @@ -259,7 +277,8 @@ def __draw_reference_rule(ref_group, chart_height, chart_width): def __draw_threshold_rules( threshold_df, scales, chart_height, accessibility_mode=False ): - """Draws threshold rules: red lines that mark the defined fairness_threshold in the chart.""" + """Draws threshold rules: red lines that mark the defined fairness_threshold in the + chart.""" stroke_color = ( Threshold_Rule.stroke_accessible if accessibility_mode @@ -297,7 +316,8 @@ def __draw_threshold_bands( chart_width, accessibility_mode=False, ): - """Draws threshold bands: regions painted red where the metric value is above the defined fairness_threshold.""" + """Draws threshold bands: regions painted red where the metric value is above the + defined fairness_threshold.""" fill_color = ( Threshold_Band.color_accessible if accessibility_mode else Threshold_Band.color ) @@ -341,7 +361,9 @@ def __draw_threshold_text( n_warnings = 0 text_explanation = [] for group, metric in warnings: - y_size = chart_height * (1 - 2 / 3 * Disparity_Chart.padding_y) + Annotation.font_size * Annotation.line_spacing * (n_warnings + 1) + y_size = chart_height * ( + 1 - 2 / 3 * Disparity_Chart.padding_y + ) + Annotation.font_size * Annotation.line_spacing * (n_warnings + 1) explanation_text_warning = warn_text.encode( x=alt.value(0), y=alt.value(y_size), @@ -349,9 +371,9 @@ def __draw_threshold_text( f"Groups {group} have {metric} of 0 (zero). This " "does not allow for the calculation of relative disparities. " "The groups will be absent in respective visualizations.", - ) + ), ) - n_warnings +=1 + n_warnings += 1 text_explanation.append(explanation_text_warning) threshold_text = ( alt.Chart(DUMMY_DF) @@ -368,7 +390,9 @@ def __draw_threshold_text( x=alt.value(0), y=alt.value(chart_height * (1 - 2 / 3 * Disparity_Chart.padding_y)), text=alt.value( - f"The metric value for any group should not be {fairness_threshold} (or more) times smaller or larger than that of the reference group {ref_group}." + f"The metric value for any group should not be {fairness_threshold} (or" + f" more) times smaller or larger than that of the reference group " + f"{ref_group}." ), ) ) @@ -435,7 +459,12 @@ def __draw_bubbles( axis_values = __get_x_axis_values(scales["x"].domain, zero=False) x_axis = alt.Axis( - values=axis_values, ticks=False, domain=False, labels=False, title=None, gridColor=Axis.grid_color + values=axis_values, + ticks=False, + domain=False, + labels=False, + title=None, + gridColor=Axis.grid_color, ) # COLOR @@ -469,8 +498,12 @@ def __draw_bubbles( ) bubble_tooltip_encoding = [ - alt.Tooltip(field="attribute_value", type="nominal", title=Label.SINGLE_GROUP), - alt.Tooltip(field="tooltip_group_size", type="nominal", title=Label.GROUP_SIZE), + alt.Tooltip( + field="attribute_value", type="nominal", title=Label.SINGLE_GROUP + ), + alt.Tooltip( + field="tooltip_group_size", type="nominal", title=Label.GROUP_SIZE + ), alt.Tooltip( field=f"tooltip_disparity_explanation_{metric}", type="nominal", @@ -587,7 +620,7 @@ def get_disparity_bubble_chart_components( chart_height, chart_width, accessibility_mode, - metric_warnings + metric_warnings, ) # ASSEMBLE CHART WITH THRESHOLD @@ -631,18 +664,21 @@ def plot_disparity_bubble_chart( :type metrics_list: list :param attribute: an attribute to plot :type attribute: str - :param fairness_threshold: a value for the maximum allowed disparity, defaults to 1.25 + :param fairness_threshold: a value for the maximum allowed disparity, defaults to + 1.25 :type fairness_threshold: float, optional :param chart_height: a value (in pixels) for the height of the chart :type chart_height: int, optional :param chart_width: a value (in pixels) for the width of the chart :type chart_width: int, optional - :param accessibility_mode: a switch for the display of more accessible visual elements, defaults to False + :param accessibility_mode: a switch for the display of more accessible visual + elements, defaults to False :type accessibility_mode: bool, optional :return: the full disparities chart :rtype: Altair chart object """ + metrics_list = [metric_names[metric] for metric in metrics_list] ( plot_table, @@ -704,7 +740,8 @@ def plot_disparity_bubble_chart( title=f"Disparities on {attribute.title()}", padding={ "top": Disparity_Chart.full_chart_padding, - "bottom": -FONT_SIZE_SMALL * 2/3 * len(metrics_list) + Disparity_Chart.full_chart_padding, + "bottom": -FONT_SIZE_SMALL * 2 / 3 * len(metrics_list) + + Disparity_Chart.full_chart_padding, "left": Disparity_Chart.full_chart_padding, "right": Disparity_Chart.full_chart_padding, }, diff --git a/src/aequitas/plot/bubble_metric_chart.py b/src/aequitas/plot/bubble_metric_chart.py index 9ee6e9d4..fba4a6e4 100644 --- a/src/aequitas/plot/bubble_metric_chart.py +++ b/src/aequitas/plot/bubble_metric_chart.py @@ -27,13 +27,26 @@ from aequitas.plot.commons import initializers as Initializer from aequitas.plot.commons import labels as Label -# Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need it -# (like most annotations), we pass the following dummy dataframe to reduce the complexity of the resulting vega spec. +# Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need +# it (like most annotations), we pass the following dummy dataframe to reduce the +# complexity of the resulting vega spec. DUMMY_DF = pd.DataFrame({"a": [1, 1], "b": [0, 0]}) +metric_names = { + "Predictive Equality": "fpr_ratio", + "Equal Opportunity": "tpr_ratio", + "Demographic Parity": "pprev_ratio", + "TPR": "tpr", + "FPR": "fpr", + "FNR": "fnr", + "Accuracy": "accuracy", + "Precision": "precision", +} + def __get_position_scales(metrics, chart_height, chart_width, concat_chart): - """Computes the scales for x and y encodings to be used in the metric bubble chart.""" + """Computes the scales for x and y encodings to be used in the metric bubble + chart.""" position_scales = dict() @@ -138,7 +151,8 @@ def __draw_x_ticks_labels(scales, chart_height): def __draw_threshold_rules(threshold_df, scales, position, accessibility_mode=False): - """Draws fairness threshold rules: red lines that mark the defined fairness threshold in the chart.""" + """Draws fairness threshold rules: red lines that mark the defined fairness + threshold in the chart.""" stroke_color = ( Threshold_Rule.stroke_accessible if accessibility_mode @@ -167,7 +181,8 @@ def __draw_threshold_rules(threshold_df, scales, position, accessibility_mode=Fa def __draw_threshold_bands(threshold_df, scales, accessibility_mode=False): - """Draws fairness threshold bands: regions painted red where the metric value is above the defined fairness threshold.""" + """Draws fairness threshold bands: regions painted red where the metric value is + above the defined fairness threshold.""" fill_color = ( Threshold_Band.color_accessible if accessibility_mode else Threshold_Band.color @@ -216,17 +231,20 @@ def __draw_threshold_text( n_warnings = 0 text_explanation = [] for group, metric in warnings: - y_size = chart_height * (1 - 2 / 3 * Metric_Chart.padding_y) + Annotation.font_size * Annotation.line_spacing * (n_warnings + 1) + y_size = chart_height * ( + 1 - 2 / 3 * Metric_Chart.padding_y + ) + Annotation.font_size * Annotation.line_spacing * (n_warnings + 1) explanation_text_warning = warn_text.encode( x=alt.value(0), y=alt.value(y_size), text=alt.value( f"Groups {group} have {metric} of 0 (zero). This " "does not allow for the calculation of relative disparities. " - "The tooltip for these groups in this plot does not have relative disparities.", - ) + "The tooltip for these groups in this plot does not have relative" + " disparities.", + ), ) - n_warnings +=1 + n_warnings += 1 text_explanation.append(explanation_text_warning) threshold_text = ( alt.Chart(DUMMY_DF) @@ -243,7 +261,9 @@ def __draw_threshold_text( x=alt.value(0), y=alt.value(chart_height * (1 - 2 / 3 * Metric_Chart.padding_y)), text=alt.value( - f"The metric value for any group should not be {fairness_threshold} (or more) times smaller or larger than that of the reference group {ref_group}." + f"The metric value for any group should not be {fairness_threshold} (or" + f" more) times smaller or larger than that of the reference group " + f"{ref_group}." ), ) ) @@ -333,7 +353,12 @@ def __draw_bubbles( # X AXIS GRIDLINES axis_values = [0.25, 0.5, 0.75] x_axis = alt.Axis( - values=axis_values, ticks=False, domain=False, labels=False, title=None, gridColor=Axis.grid_color + values=axis_values, + ticks=False, + domain=False, + labels=False, + title=None, + gridColor=Axis.grid_color, ) # COLOR @@ -367,8 +392,12 @@ def __draw_bubbles( ) bubble_tooltip_encoding = [ - alt.Tooltip(field="attribute_value", type="nominal", title=Label.SINGLE_GROUP), - alt.Tooltip(field="tooltip_group_size", type="nominal", title=Label.GROUP_SIZE), + alt.Tooltip( + field="attribute_value", type="nominal", title=Label.SINGLE_GROUP + ), + alt.Tooltip( + field="tooltip_group_size", type="nominal", title=Label.GROUP_SIZE + ), alt.Tooltip( field=f"tooltip_disparity_explanation_{metric}", type="nominal", @@ -481,7 +510,7 @@ def get_metric_bubble_chart_components( chart_height, concat_chart, accessibility_mode, - metric_warnings + metric_warnings, ) # ASSEMBLE CHART WITH THRESHOLD main_chart = ( @@ -507,7 +536,8 @@ def plot_metric_bubble_chart( chart_width=Metric_Chart.full_width, accessibility_mode=False, ): - """Draws bubble chart to visualize the values of the selected metrics for a given attribute. + """Draws bubble chart to visualize the values of the selected metrics for a given + attribute. :param disparity_df: a dataframe generated by the Aequitas Bias class :type disparity_df: pandas.core.frame.DataFrame @@ -515,18 +545,22 @@ def plot_metric_bubble_chart( :type metrics_list: list :param attribute: an attribute to plot :type attribute: str - :param fairness_threshold: a value for the maximum allowed disparity, defaults to 1.25 + :param fairness_threshold: a value for the maximum allowed disparity, defaults to + 1.25 :type fairness_threshold: float, optional :param chart_height: a value (in pixels) for the height of the chart :type chart_height: int, optional :param chart_width: a value (in pixels) for the width of the chart :type chart_width: int, optional - :param accessibility_mode: a switch for the display of more accessible visual elements, defaults to False + :param accessibility_mode: a switch for the display of more accessible visual + elements, defaults to False :type accessibility_mode: bool, optional :return: the full metrics chart :rtype: Altair chart object """ + metrics_list = [metric_names[metric] for metric in metrics_list] + ( plot_table, metrics, @@ -577,7 +611,8 @@ def plot_metric_bubble_chart( # padding=Metric_Chart.full_chart_padding, padding={ "top": Metric_Chart.full_chart_padding, - "bottom": -FONT_SIZE_SMALL * 1.25/3 * len(metrics_list) + Metric_Chart.full_chart_padding, + "bottom": -FONT_SIZE_SMALL * 1.25 / 3 * len(metrics_list) + + Metric_Chart.full_chart_padding, "left": Metric_Chart.full_chart_padding, "right": Metric_Chart.full_chart_padding, }, diff --git a/src/aequitas/plot/summary_chart.py b/src/aequitas/plot/summary_chart.py index 75d2b76b..5e01d16c 100644 --- a/src/aequitas/plot/summary_chart.py +++ b/src/aequitas/plot/summary_chart.py @@ -25,10 +25,22 @@ from aequitas.plot.commons import validators as Validator from aequitas.plot.commons import labels as Label -# Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need it -# (like most annotations), we pass the following dummy dataframe to reduce the complexity of the resulting vega spec. +# Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need +# it (like most annotations), we pass the following dummy dataframe to reduce the +# complexity of the resulting vega spec. DUMMY_DF = pd.DataFrame({"a": [1, 1], "b": [0, 0]}) +metric_names = { + "Predictive Equality": "fpr_ratio", + "Equal Opportunity": "tpr_ratio", + "Demographic Parity": "pprev_ratio", + "TPR": "tpr", + "FPR": "fpr", + "FNR": "fnr", + "Accuracy": "accuracy", + "Precision": "precision", +} + def __get_scales(max_num_groups): """Creates an Altair scale for the color of the parity test result, and another @@ -52,9 +64,9 @@ def __get_scales(max_num_groups): def __get_size_constants( chart_height, chart_width, num_attributes, num_metrics, max_num_groups ): - """Calculates the heights, widths and spacings of the components of the summary chart - based on the provided desired overall chart height and width, as well as the number of - attributes (columns) and metrics (lines).""" + """Calculates the heights, widths and spacings of the components of the summary + chart based on the provided desired overall chart height and width, as well as the + number of attributes (columns) and metrics (lines).""" size_constants = dict( # Chart sizes @@ -65,11 +77,11 @@ def __get_size_constants( column_spacing=0.15 * chart_width / num_attributes, column_width=Summary_Chart.column_width_ratio * chart_width / num_attributes, # Circle size - ## Conditional definition of the size where for each additional unit in - ## max_num_groups, we subtract 25 squared pixels from the area of the - ## circle, which has the base value of 350 for 0 groups. From max_num_groups - ## equal to 13 or more, we keep the size at the minimum value of 25 to - ## make sure the circles are visible. + # Conditional definition of the size where for each additional unit in + # max_num_groups, we subtract 25 squared pixels from the area of the + # circle, which has the base value of 350 for 0 groups. From max_num_groups + # equal to 13 or more, we keep the size at the minimum value of 25 to + # make sure the circles are visible. group_circle_size=-25 * max_num_groups + 350 if max_num_groups < 13 else 25, ) return size_constants @@ -163,9 +175,10 @@ def __draw_metric_line_titles(metrics, size_constants): ) # EMPTY CORNER SPACE - # To make sure that the attribute columns align properly with the title column, we need to create a blank - # space of the same size of the attribute titles. For this purpose, we use the same function (__draw_attribute_title) - # and pass in an empty string so that nothing is actually drawn. + # To make sure that the attribute columns align properly with the title column, we + # need to create a blank space of the same size of the attribute titles. For this + # purpose, we use the same function (__draw_attribute_title) and pass in an empty + # string so that nothing is actually drawn. top_left_corner_space = __draw_attribute_title( "", size_constants["metric_titles_width"], size_constants ) @@ -182,7 +195,8 @@ def __draw_metric_line_titles(metrics, size_constants): def __get_parity_result_variable(row, metric, fairness_threshold): - """Creates parity test result variable for each provided row, separating the Reference group from the passing ones.""" + """Creates parity test result variable for each provided row, separating the + Reference group from the passing ones.""" if row["attribute_value"] == row["ref_group_value"]: return "Reference" elif abs(row[f"{metric}_disparity_scaled"]) < fairness_threshold - 1: @@ -192,7 +206,8 @@ def __get_parity_result_variable(row, metric, fairness_threshold): def __draw_parity_result_text(parity_result, color_scale): - """Draws the uppercased text result of the provided parity test (Pass, Fail or Reference), + """Draws the uppercased text result of the provided parity test (Pass, Fail or + Reference), color-coded according to the provided Altair scale.""" return ( @@ -211,6 +226,7 @@ def __draw_parity_result_text(parity_result, color_scale): "parity_result:O", scale=color_scale, legend=alt.Legend( + title=Label.TEST, title=Label.TEST, padding=Legend.margin, offset=0, @@ -222,7 +238,8 @@ def __draw_parity_result_text(parity_result, color_scale): def __draw_population_bar(population_bar_df, metric, color_scale): - """Draws a stacked bar of the sum of the percentage of population of the groups that obtained each result for the parity test.""" + """Draws a stacked bar of the sum of the percentage of population of the groups + that obtained each result for the parity test.""" population_bar_tooltips = [ alt.Tooltip(field=f"{metric}_parity_result", type="nominal", title="Parity"), alt.Tooltip( @@ -258,7 +275,8 @@ def __draw_population_bar(population_bar_df, metric, color_scale): def __draw_group_circles(plot_df, metric, scales, size_constants): """Draws a circle for each group, color-coded by the result of the parity test. - The groups are spread around the central reference group according to their disparity. + The groups are spread around the central reference group according to their + disparity. """ circle_tooltip_encoding = [ @@ -348,7 +366,8 @@ def __draw_parity_test_explanation(fairness_threshold, x_position, warnings=()): x=alt.value(x_position), y=alt.value(0), text=alt.value( - f"For a group to pass the parity test its disparity to the reference group cannot exceed the fairness threshold ({fairness_threshold})." + f"For a group to pass the parity test its disparity to the reference group " + f"cannot exceed the fairness threshold ({fairness_threshold})." ), ) @@ -356,7 +375,8 @@ def __draw_parity_test_explanation(fairness_threshold, x_position, warnings=()): x=alt.value(x_position), y=alt.value(Annotation.font_size * Annotation.line_spacing), text=alt.value( - f"An attribute passes the parity test for a given metric if all its groups pass the test." + "An attribute passes the parity test for a given metric if all its groups " + "pass the test." ), ) text_explanation.append(explanation_text_group) @@ -401,7 +421,8 @@ def __create_population_bar_df(attribute_df, metric): def __create_group_rank_variable(attribute_df, metric): - """Creates the disparity rank variable for the given metric, centered around 0 (the Reference Group's value).""" + """Creates the disparity rank variable for the given metric, centered around 0 (the + Reference Group's value).""" # RANK attribute_df[f"{metric}_disparity_rank"] = attribute_df[ @@ -454,9 +475,11 @@ def __create_tooltip_variables(attribute_df, metric, fairness_threshold): def __create_disparity_variables(attribute_df, metric, fairness_threshold): - """Creates scaled disparity, parity test result & disparity explanation tooltip variables.""" + """Creates scaled disparity, parity test result & disparity explanation tooltip + variables.""" # Check if any group has a disparity of 0. - # These values would potentially break the plots, and will raise warnings to the user. + # These values would potentially break the plots, and will raise warnings to the + # user. zero_metric_groups = attribute_df[attribute_df[f"{metric}_disparity"] == 0] zero_values = zero_metric_groups["attribute_value"].values warning = None @@ -478,10 +501,12 @@ def __create_disparity_variables(attribute_df, metric, fairness_threshold): return warning + def __get_attribute_column( attribute_df, metrics, scales, attribute, size_constants, fairness_threshold ): - """Returns a vertical concatenation of all elements of all metrics for each attribute's column.""" + """Returns a vertical concatenation of all elements of all metrics for each + attribute's column.""" metric_summary = [] metric_warnings = [] @@ -496,8 +521,8 @@ def __get_attribute_column( __create_group_rank_variable(attribute_df, metric) # PARITY RESULT TEXT - ## The parity result is equal to the "worst" of each group's results - ## If one group fails the parity test, the whole metric fails (for that attribute) + # The parity result is equal to the "worst" of each group's results + # If one group fails the parity test, the whole metric fails (for that attr.) parity_result = attribute_df.loc[ attribute_df[f"{metric}_parity_result"] != "Reference" ][f"{metric}_parity_result"].min() @@ -549,17 +574,19 @@ def plot_summary_chart( chart_height=None, chart_width=None, ): - """Draws chart that summarizes the parity results for the provided metrics across the existing attributes. - This includes an overall result, the specific results by each attribute's groups as well as the percentage - of population by result. + """Draws chart that summarizes the parity results for the provided metrics across + the existing attributes. This includes an overall result, the specific results by + each attribute's groups as well as the percentage of population by result. :param disparity_df: a dataframe generated by the Aequitas Bias class :type disparity_df: pandas.core.frame.DataFrame :param metrics_list: a list of the metrics of interest :type metrics_list: list - :param attributes_list: a list of the attributes of interest, defaults to using all in the dataframe + :param attributes_list: a list of the attributes of interest, defaults to using all + in the dataframe :type attributes_list: list, optional - :param fairness_threshold: a value for the maximum allowed disparity, defaults to 1.25 + :param fairness_threshold: a value for the maximum allowed disparity, defaults to + 1.25 :type fairness_threshold: float, optional :param chart_height: a value (in pixels) for the height of the chart :type chart_height: int, optional @@ -569,8 +596,9 @@ def plot_summary_chart( :return: the full summary chart :rtype: Altair chart object """ + metrics_list = [metric_names[metric] for metric in metrics_list] - ## If a specific list of attributes was not passed, use all from df + # If a specific list of attributes was not passed, use all from df ( metrics, attributes, From d8ab9c9a64bb75e6c0c791c72ed68e804c827d45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?In=C3=AAs=20Silva?= Date: Thu, 1 Feb 2024 11:42:36 +0000 Subject: [PATCH 2/7] Metric name cleanup in flow --- src/aequitas/flow/plots/bootstrap/plot.py | 25 ++++----------- .../flow/plots/bootstrap/visualize.py | 16 ++-------- src/aequitas/flow/plots/pareto/plot.py | 31 ++++++------------- src/aequitas/flow/plots/pareto/visualize.py | 12 +++---- src/aequitas/flow/utils/metrics.py | 20 ++++++++++++ 5 files changed, 44 insertions(+), 60 deletions(-) create mode 100644 src/aequitas/flow/utils/metrics.py diff --git a/src/aequitas/flow/plots/bootstrap/plot.py b/src/aequitas/flow/plots/bootstrap/plot.py index f1806fc6..bab91ee8 100644 --- a/src/aequitas/flow/plots/bootstrap/plot.py +++ b/src/aequitas/flow/plots/bootstrap/plot.py @@ -1,8 +1,9 @@ -from typing import Literal, Optional +from typing import Optional import numpy as np from ...utils.evaluation import bootstrap_hyperparameters +from ...utils.metrics import METRIC_NAMES, FAIRNESS_METRIC, PERFORMANCE_METRIC from ...evaluation import Result @@ -14,27 +15,13 @@ } -metrics = { - "Predictive Equality": "fpr_ratio", - "Equal Opportunity": "tpr_ratio", - "Demographic Parity": "pprev_ratio", - "TPR": "tpr", - "FPR": "fpr", - "FNR": "fnr", - "Accuracy": "accuracy", - "Precision": "precision", -} - - class Plot: def __init__( self, results: dict[str, dict[str, Result]], dataset: str, - fairness_metric: Literal[ - "Predictive Equality", "Equal Opportunity", "Demographic Parity" - ], - performance_metric: Literal["TPR", "FPR", "FNR", "Accuracy", "Precision"], + fairness_metric: FAIRNESS_METRIC, + performance_metric: PERFORMANCE_METRIC, method: Optional[str] = None, confidence_intervals: float = 0.95, **kwargs, @@ -64,8 +51,8 @@ def __init__( for key, value in DEFAULT_KWARGS.items(): if key not in self.kwargs: self.kwargs[key] = value - self.kwargs["fairness_metric"] = metrics[fairness_metric] - self.kwargs["performance_metric"] = metrics[performance_metric] + self.kwargs["fairness_metric"] = METRIC_NAMES[fairness_metric] + self.kwargs["performance_metric"] = METRIC_NAMES[performance_metric] self.bootstrap_results = {} if isinstance(self.kwargs["alpha_points"], np.ndarray): self.x = self.kwargs["alpha_points"] diff --git a/src/aequitas/flow/plots/bootstrap/visualize.py b/src/aequitas/flow/plots/bootstrap/visualize.py index fba7c73b..3764738b 100644 --- a/src/aequitas/flow/plots/bootstrap/visualize.py +++ b/src/aequitas/flow/plots/bootstrap/visualize.py @@ -4,6 +4,7 @@ from .plot import Plot +from ...utils.metrics import METRIC_NAMES sns.set() sns.set_style("whitegrid", {"grid.linestyle": "--"}) @@ -53,22 +54,11 @@ "grid_search_folktables", ] -metrics_names = { - "Predictive Equality": "Pred. Eq.", - "Equal Opportunity": "Eq. Opp.", - "Demographic Parity": "Dem. Par.", - "TPR": "TPR", - "FPR": "FPR", - "FNR": "FNR", - "Accuracy": "Acc.", - "Precision": "Prec.", -} - def visualize(plot: Plot): # define the name of the metrics for plot - perf_metric_plot = metrics_names[plot.performance_metric] - fair_metric_plot = metrics_names[plot.fairness_metric] + perf_metric_plot = METRIC_NAMES[plot.performance_metric] + fair_metric_plot = METRIC_NAMES[plot.fairness_metric] x = plot.x diff --git a/src/aequitas/flow/plots/pareto/plot.py b/src/aequitas/flow/plots/pareto/plot.py index 40c5bbe0..3aaa6df2 100644 --- a/src/aequitas/flow/plots/pareto/plot.py +++ b/src/aequitas/flow/plots/pareto/plot.py @@ -8,6 +8,7 @@ from ....bias import Bias from ....group import Group from ....plot import summary, disparity +from ...utils.metrics import FAIRNESS_METRIC, PERFORMANCE_METRIC _names = { @@ -54,10 +55,12 @@ class Plot: Name of the dataset to be used in the Pareto plot. method : union[str, list], optional Name of the method to plot. If none, all methods will be plotted. - fairness_metric : {"Predictive Equality", "Equal Opportunity", "Demographic Parity"} - The default fairness metric to use in the Pareto plot. - performance_metric : {"TPR", "FPR", "FNR", "Accuracy", "Precision"} - The default performance metric to use in the Pareto plot. + fairness_metric : str + The default fairness metric to use in the Pareto plot. Possible values + are defined in aequitas.flow.utils.metrics + performance_metric : str + The default performance metric to use in the Pareto plot. Possible values + are defined in aequitas.flow.utils.metrics alpha : float, optional The alpha value to use in the Pareto plot. direction : {"minimize", "maximize"}, optional @@ -68,10 +71,8 @@ def __init__( self, results: dict[str, dict[str, Result]], dataset: str, - fairness_metric: Literal[ - "Predictive Equality", "Equal Opportunity", "Demographic Parity" - ], - performance_metric: Literal["TPR", "FPR", "FNR", "Accuracy", "Precision"], + fairness_metric: FAIRNESS_METRIC, + performance_metric: PERFORMANCE_METRIC, method: Optional[Union[str, list]] = None, alpha: float = 0.5, direction: Literal["minimize", "maximize"] = "maximize", @@ -109,18 +110,6 @@ def __init__( self.performance_metric = performance_metric self.alpha = alpha self.direction = direction - self.available_fairness_metrics = { - "Predictive Equality", - "Equal Opportunity", - "Demographic Parity", - } # Hardcoded for now - self.available_performance_metrics = [ - "TPR", - "FPR", - "FNR", - "Accuracy", - "Precision", - ] self._best_model_idx: int = 0 @property @@ -302,7 +291,7 @@ def disparities( model_id: int, dataset: Any, sensitive_attribute: Union[str, list[str]], - metrics: list[str] = ["tpr", "fpr"], + metrics: list[str] = ["TPR", "FPR"], fairness_threshold: float = 1.2, results_path: Union[Path, str] = "examples/experiment_results", reference_groups: Optional[dict[str, str]] = None, diff --git a/src/aequitas/flow/plots/pareto/visualize.py b/src/aequitas/flow/plots/pareto/visualize.py index 9c15b988..2c821f23 100644 --- a/src/aequitas/flow/plots/pareto/visualize.py +++ b/src/aequitas/flow/plots/pareto/visualize.py @@ -8,6 +8,7 @@ import pkg_resources from .plot import Plot +from ...utils.metrics import FAIRNESS_METRICS, PERFORMANCE_METRICS # NumPy data types are not JSON serializable. This custom JSON encoder will @@ -90,12 +91,9 @@ def visualize(wrapper: Plot, mode="display", save_path=None, pareto_only=False): if pareto_only: wrapper_results_flat = wrapper_results_flat[wrapper_results_flat["is_pareto"]] - fairness_metrics = list(wrapper.available_fairness_metrics) - performance_metrics = list(wrapper.available_performance_metrics) - filtered_results = wrapper_results_flat[ - fairness_metrics - + performance_metrics + FAIRNESS_METRICS + + PERFORMANCE_METRICS + ["model_id", "hyperparams", "is_pareto"] ] @@ -108,8 +106,8 @@ def visualize(wrapper: Plot, mode="display", save_path=None, pareto_only=False): "recommended_model": wrapper.best_model_details, "optimized_fairness_metric": wrapper.fairness_metric, "optimized_performance_metric": wrapper.performance_metric, - "fairness_metrics": fairness_metrics, - "performance_metrics": performance_metrics, + "fairness_metrics": FAIRNESS_METRICS, + "performance_metrics": PERFORMANCE_METRICS, "tuner_type": "Random Search", # Hardcoded for now "alpha": wrapper.alpha, } diff --git a/src/aequitas/flow/utils/metrics.py b/src/aequitas/flow/utils/metrics.py new file mode 100644 index 00000000..1dce3953 --- /dev/null +++ b/src/aequitas/flow/utils/metrics.py @@ -0,0 +1,20 @@ +from typing import Literal + +FAIRNESS_METRICS = ["Predictive Equality", "Equal Opportunity", "Demographic Parity"] +PERFORMANCE_METRICS = ["TPR", "FPR", "FNR", "Accuracy", "Precision"] + +METRIC_NAMES = { + "Predictive Equality": "fpr_ratio", + "Equal Opportunity": "tpr_ratio", + "Demographic Parity": "pprev_ratio", + "TPR": "tpr", + "FPR": "fpr", + "FNR": "fnr", + "Accuracy": "accuracy", + "Precision": "precision", +} + +FAIRNESS_METRIC = Literal[ + "Predictive Equality", "Equal Opportunity", "Demographic Parity" +] +PERFORMANCE_METRIC = Literal["TPR", "FPR", "FNR", "Accuracy", "Precision"] From 63cb171acea8f9a5ecc700fb2746baf7947bd720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?In=C3=AAs=20Silva?= Date: Thu, 1 Feb 2024 12:17:20 +0000 Subject: [PATCH 3/7] Reverted changes in aequitas plots --- src/aequitas/flow/plots/pareto/plot.py | 2 +- src/aequitas/group.py | 12 ++++++------ src/aequitas/plot/bubble_concatenation_chart.py | 2 -- src/aequitas/plot/bubble_disparity_chart.py | 13 ------------- src/aequitas/plot/bubble_metric_chart.py | 13 ------------- src/aequitas/plot/summary_chart.py | 13 ------------- 6 files changed, 7 insertions(+), 48 deletions(-) diff --git a/src/aequitas/flow/plots/pareto/plot.py b/src/aequitas/flow/plots/pareto/plot.py index 3aaa6df2..4e32e6ac 100644 --- a/src/aequitas/flow/plots/pareto/plot.py +++ b/src/aequitas/flow/plots/pareto/plot.py @@ -216,7 +216,7 @@ def bias_audit( model_id: int, dataset: Any, sensitive_attribute: Union[str, list[str]], - metrics: list[str] = ["TPR", "FPR"], + metrics: list[str] = ["tpr", "fpr"], fairness_threshold: float = 1.2, results_path: Union[Path, str] = "examples/experiment_results", reference_groups: Optional[dict[str, str]] = None, diff --git a/src/aequitas/group.py b/src/aequitas/group.py index 1fca823f..e9ca7b55 100644 --- a/src/aequitas/group.py +++ b/src/aequitas/group.py @@ -464,7 +464,7 @@ def calculate_disparities( def plot_summary( disparities: pd.DataFrame, - metrics: list[str] = ["FPR", "TPR"], + metrics: list[str] = ["fpr", "tpr"], fairness_threshold: float = 1.25, ): """ @@ -475,7 +475,7 @@ def plot_summary( disparities : pandas.DataFrame Disparities for each group. metrics : list[str], optional - List of metrics to plot. Defaults to ["FPR", "TPR"]. + List of metrics to plot. Defaults to ["fpr", "tpr"]. fairness_threshold : float, optional Threshold to use to determine fairness. Defaults to 1.2. """ @@ -484,7 +484,7 @@ def plot_summary( def plot_disparity( disparities: pd.DataFrame, attribute: str, - metrics: list[str] = ["FPR", "TPR"], + metrics: list[str] = ["fpr", "tpr"], fairness_threshold: float = 1.25, ): """ @@ -495,7 +495,7 @@ def plot_disparity( disparities : pandas.DataFrame Disparities for each group. metrics : list[str], optional - List of metrics to plot. Defaults to ["FPR", "TPR"]. + List of metrics to plot. Defaults to ["fpr", "tpr"]. fairness_threshold : float, optional Threshold to use to determine fairness. Defaults to 1.2. """ @@ -506,7 +506,7 @@ def plot_disparity( def plot_absolute( disparities: pd.DataFrame, attribute: str, - metrics: list[str] = ["FPR", "TPR"], + metrics: list[str] = ["fpr", "tpr"], fairness_threshold: float = 1.25, ): """ @@ -517,7 +517,7 @@ def plot_absolute( disparities : pandas.DataFrame Disparities for each group. metrics : list[str], optional - List of metrics to plot. Defaults to ["FPR", "TPR"]. + List of metrics to plot. Defaults to ["fpr", "tpr"]. fairness_threshold : float, optional Threshold to use to determine fairness. Defaults to 1.2. """ diff --git a/src/aequitas/plot/bubble_concatenation_chart.py b/src/aequitas/plot/bubble_concatenation_chart.py index 118da4c1..e62b3ad3 100644 --- a/src/aequitas/plot/bubble_concatenation_chart.py +++ b/src/aequitas/plot/bubble_concatenation_chart.py @@ -97,8 +97,6 @@ def plot_concatenated_bubble_charts( :return: the full disparities chart :rtype: Altair chart object """ - metrics_list = [metric_names[metric] for metric in metrics_list] - ( plot_table, metrics, diff --git a/src/aequitas/plot/bubble_disparity_chart.py b/src/aequitas/plot/bubble_disparity_chart.py index 5fcae4ed..ea2a45a8 100644 --- a/src/aequitas/plot/bubble_disparity_chart.py +++ b/src/aequitas/plot/bubble_disparity_chart.py @@ -34,17 +34,6 @@ # complexity of the resulting vega spec. DUMMY_DF = pd.DataFrame({"a": [1, 1], "b": [0, 0]}) -metric_names = { - "Predictive Equality": "fpr_ratio", - "Equal Opportunity": "tpr_ratio", - "Demographic Parity": "pprev_ratio", - "TPR": "tpr", - "FPR": "fpr", - "FNR": "fnr", - "Accuracy": "accuracy", - "Precision": "precision", -} - def __get_position_scales( plot_table, metrics, fairness_threshold, chart_height, chart_width @@ -678,8 +667,6 @@ def plot_disparity_bubble_chart( :return: the full disparities chart :rtype: Altair chart object """ - metrics_list = [metric_names[metric] for metric in metrics_list] - ( plot_table, metrics, diff --git a/src/aequitas/plot/bubble_metric_chart.py b/src/aequitas/plot/bubble_metric_chart.py index fba4a6e4..d5d52715 100644 --- a/src/aequitas/plot/bubble_metric_chart.py +++ b/src/aequitas/plot/bubble_metric_chart.py @@ -32,17 +32,6 @@ # complexity of the resulting vega spec. DUMMY_DF = pd.DataFrame({"a": [1, 1], "b": [0, 0]}) -metric_names = { - "Predictive Equality": "fpr_ratio", - "Equal Opportunity": "tpr_ratio", - "Demographic Parity": "pprev_ratio", - "TPR": "tpr", - "FPR": "fpr", - "FNR": "fnr", - "Accuracy": "accuracy", - "Precision": "precision", -} - def __get_position_scales(metrics, chart_height, chart_width, concat_chart): """Computes the scales for x and y encodings to be used in the metric bubble @@ -559,8 +548,6 @@ def plot_metric_bubble_chart( :return: the full metrics chart :rtype: Altair chart object """ - metrics_list = [metric_names[metric] for metric in metrics_list] - ( plot_table, metrics, diff --git a/src/aequitas/plot/summary_chart.py b/src/aequitas/plot/summary_chart.py index 5e01d16c..e7617136 100644 --- a/src/aequitas/plot/summary_chart.py +++ b/src/aequitas/plot/summary_chart.py @@ -30,17 +30,6 @@ # complexity of the resulting vega spec. DUMMY_DF = pd.DataFrame({"a": [1, 1], "b": [0, 0]}) -metric_names = { - "Predictive Equality": "fpr_ratio", - "Equal Opportunity": "tpr_ratio", - "Demographic Parity": "pprev_ratio", - "TPR": "tpr", - "FPR": "fpr", - "FNR": "fnr", - "Accuracy": "accuracy", - "Precision": "precision", -} - def __get_scales(max_num_groups): """Creates an Altair scale for the color of the parity test result, and another @@ -596,8 +585,6 @@ def plot_summary_chart( :return: the full summary chart :rtype: Altair chart object """ - metrics_list = [metric_names[metric] for metric in metrics_list] - # If a specific list of attributes was not passed, use all from df ( metrics, From 232d996ef8da1152f6b5be243d3d90173c8fda3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Jesus?= Date: Fri, 2 Feb 2024 09:48:04 +0000 Subject: [PATCH 4/7] Update previous plots to support new naming scheme --- src/aequitas/plot/bubble_disparity_chart.py | 16 +++++++++------- src/aequitas/plot/bubble_metric_chart.py | 16 +++++++++------- src/aequitas/plot/commons/metrics.py | 20 ++++++++++++++++++++ src/aequitas/plot/summary_chart.py | 8 ++++++-- 4 files changed, 44 insertions(+), 16 deletions(-) create mode 100644 src/aequitas/plot/commons/metrics.py diff --git a/src/aequitas/plot/bubble_disparity_chart.py b/src/aequitas/plot/bubble_disparity_chart.py index ea2a45a8..33f6e4e9 100644 --- a/src/aequitas/plot/bubble_disparity_chart.py +++ b/src/aequitas/plot/bubble_disparity_chart.py @@ -28,6 +28,8 @@ from aequitas.plot.commons.style.sizes import Disparity_Chart from aequitas.plot.commons import initializers as Initializer from aequitas.plot.commons import labels as Label +from aequitas.plot.commons.metrics import possible_metrics, display_name + # Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need # it (like most annotations), we pass the following dummy dataframe to reduce the @@ -71,7 +73,7 @@ def max_column(x): y_range = get_chart_size_range(chart_height, Disparity_Chart.padding_y) if chart_height < 300: y_range[0] = 30 - y_domain = [metric.upper() for metric in metrics] + y_domain = [display_name.get(metric.lower(), metric).upper() for metric in metrics] position_scales["y"] = alt.Scale(domain=y_domain, range=y_range) return position_scales @@ -81,7 +83,7 @@ def __draw_metrics_rules(metrics, scales, concat_chart): """Draws an horizontal rule and the left-hand side label for each metric. The groups' bubbles will be positioned on this horizontal rule.""" - metrics_labels = [metric.upper() for metric in metrics] + metrics_labels = [display_name.get(metric.lower(), metric).upper() for metric in metrics] metrics_axis = alt.Axis( domain=False, @@ -118,7 +120,6 @@ def __draw_metrics_rules(metrics, scales, concat_chart): x2="x2:Q", ) ) - return metrics_rules @@ -502,7 +503,7 @@ def __draw_bubbles( field=f"{metric}", type="quantitative", format=".2f", - title=f"{metric}".upper(), + title=f"{display_name.get(metric, metric)}".upper(), ), ] @@ -511,7 +512,7 @@ def __draw_bubbles( bubble_centers += ( alt.Chart(plot_table) - .transform_calculate(metric_variable=f"'{metric.upper()}'") + .transform_calculate(metric_variable=f"'{display_name.get(metric.lower(), metric).upper()}'") .mark_point(filled=True, size=Bubble.center_size, cursor=Bubble.cursor) .encode( x=alt.X(f"{metric}_disparity_scaled:Q", scale=scales["x"], axis=x_axis), @@ -531,7 +532,7 @@ def __draw_bubbles( bubble_areas += ( alt.Chart(plot_table) .mark_circle(opacity=Bubble.opacity, cursor=Bubble.cursor) - .transform_calculate(metric_variable=f"'{metric.upper()}'") + .transform_calculate(metric_variable=f"'{display_name.get(metric.lower(), metric).upper()}'") .encode( x=alt.X(f"{metric}_disparity_scaled:Q", scale=scales["x"], axis=x_axis), y=alt.Y("metric_variable:N", scale=scales["y"], axis=no_axis()), @@ -667,6 +668,7 @@ def plot_disparity_bubble_chart( :return: the full disparities chart :rtype: Altair chart object """ + metrics = [possible_metrics.get(metric.lower(), metric) for metric in metrics_list] ( plot_table, metrics, @@ -677,7 +679,7 @@ def plot_disparity_bubble_chart( selection, ) = Initializer.prepare_bubble_chart( disparity_df, - metrics_list, + metrics, attribute, fairness_threshold, chart_height, diff --git a/src/aequitas/plot/bubble_metric_chart.py b/src/aequitas/plot/bubble_metric_chart.py index d5d52715..b54594c6 100644 --- a/src/aequitas/plot/bubble_metric_chart.py +++ b/src/aequitas/plot/bubble_metric_chart.py @@ -26,6 +26,7 @@ from aequitas.plot.commons.style.text import FONT, FONT_SIZE_SMALL from aequitas.plot.commons import initializers as Initializer from aequitas.plot.commons import labels as Label +from aequitas.plot.commons.metrics import display_name, possible_metrics # Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need # it (like most annotations), we pass the following dummy dataframe to reduce the @@ -46,7 +47,7 @@ def __get_position_scales(metrics, chart_height, chart_width, concat_chart): # METRICS LABELS SCALE y_range = get_chart_size_range(chart_height, Metric_Chart.padding_y) - y_domain = [metric.upper() for metric in metrics] + y_domain = [display_name.get(metric.lower(), metric).upper() for metric in metrics] position_scales["y"] = alt.Scale(domain=y_domain, range=y_range) return position_scales @@ -55,7 +56,7 @@ def __get_position_scales(metrics, chart_height, chart_width, concat_chart): def __draw_metrics_rules(metrics, scales, concat_chart): """Draws an horizontal rule for each metric where the bubbles will be positioned.""" - metrics_labels = [metric.upper() for metric in metrics] + metrics_labels = [display_name.get(metric.lower(), metric).upper() for metric in metrics] if concat_chart: y_axis = no_axis() @@ -291,7 +292,7 @@ def __get_threshold_elements( lower_values.append(ref_group_value / fairness_threshold) # Convert to uppercase to match bubbles' Y axis - metrics_labels = [metric.upper() for metric in metrics] + metrics_labels = [display_name.get(metric.lower(), metric).upper() for metric in metrics] threshold_df = pd.DataFrame( { @@ -396,7 +397,7 @@ def __draw_bubbles( field=f"{metric}", type="quantitative", format=".2f", - title=f"{metric}".upper(), + title=display_name.get(metric.lower(), metric).upper(), ), ] @@ -405,7 +406,7 @@ def __draw_bubbles( bubble_centers += ( alt.Chart(plot_table) - .transform_calculate(metric_variable=f"'{metric.upper()}'") + .transform_calculate(metric_variable=f"'{display_name.get(metric.lower(), metric).upper()}'") .mark_point(filled=True, size=Bubble.center_size, cursor=Bubble.cursor) .encode( x=alt.X(f"{metric}:Q", scale=scales["x"], axis=x_axis), @@ -425,7 +426,7 @@ def __draw_bubbles( bubble_areas += ( alt.Chart(plot_table) .mark_circle(opacity=Bubble.opacity, cursor=Bubble.cursor) - .transform_calculate(metric_variable=f"'{metric.upper()}'") + .transform_calculate(metric_variable=f"'{display_name.get(metric.lower(), metric).upper()}'") .encode( x=alt.X(f"{metric}:Q", scale=scales["x"], axis=x_axis), y=alt.Y("metric_variable:N", scale=scales["y"], axis=no_axis()), @@ -548,6 +549,7 @@ def plot_metric_bubble_chart( :return: the full metrics chart :rtype: Altair chart object """ + metrics = [possible_metrics.get(metric.lower(), metric) for metric in metrics_list] ( plot_table, metrics, @@ -558,7 +560,7 @@ def plot_metric_bubble_chart( selection, ) = Initializer.prepare_bubble_chart( disparity_df, - metrics_list, + metrics, attribute, fairness_threshold, chart_height, diff --git a/src/aequitas/plot/commons/metrics.py b/src/aequitas/plot/commons/metrics.py new file mode 100644 index 00000000..ee3c2af8 --- /dev/null +++ b/src/aequitas/plot/commons/metrics.py @@ -0,0 +1,20 @@ +possible_metrics = { + "fpr": "fpr", + "fpr_ratio": "fpr", + "predictive equality": "fpr", + "predictive_equality": "fpr", + "tpr": "tpr", + "tpr_ratio": "tpr", + "equal opportunity": "tpr", + "equal_opportunity": "tpr", + "pprev": "pprev", + "pprev_ratio": "pprev", + "demographic parity": "pprev", + "demographic_parity": "pprev", +} + +display_name = { + "fpr": "Predictive Equality", + "tpr": "Equal Opportunity", + "pprev": "Demographic Parity", +} diff --git a/src/aequitas/plot/summary_chart.py b/src/aequitas/plot/summary_chart.py index e7617136..227052ac 100644 --- a/src/aequitas/plot/summary_chart.py +++ b/src/aequitas/plot/summary_chart.py @@ -24,6 +24,7 @@ from aequitas.plot.commons import initializers as Initializer from aequitas.plot.commons import validators as Validator from aequitas.plot.commons import labels as Label +from aequitas.plot.commons.metrics import possible_metrics, display_name # Altair 2.4.1 requires that all chart receive a dataframe, for charts that don't need # it (like most annotations), we pass the following dummy dataframe to reduce the @@ -586,6 +587,7 @@ def plot_summary_chart( :rtype: Altair chart object """ # If a specific list of attributes was not passed, use all from df + metrics = [possible_metrics.get(metric.lower(), metric) for metric in metrics_list] ( metrics, attributes, @@ -593,7 +595,7 @@ def plot_summary_chart( chart_width, ) = Initializer.prepare_summary_chart( disparity_df, - metrics_list, + metrics, attributes_list, fairness_threshold, chart_height, @@ -619,8 +621,10 @@ def plot_summary_chart( # SCALES scales = __get_scales(max_num_groups) + display_metrics = [display_name.get(metric, metric) for metric in metrics] + # METRIC TITLES - metric_titles = __draw_metric_line_titles(metrics, size_constants) + metric_titles = __draw_metric_line_titles(display_metrics, size_constants) # RELEVANT FIELDS viz_fields = [ From e5f2e7a149f062275b318ff990ca3c78ac361e02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Jesus?= Date: Fri, 2 Feb 2024 10:29:05 +0000 Subject: [PATCH 5/7] Remove dictionary from bubble concatenation chart --- src/aequitas/plot/bubble_concatenation_chart.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/aequitas/plot/bubble_concatenation_chart.py b/src/aequitas/plot/bubble_concatenation_chart.py index e62b3ad3..f7abe380 100644 --- a/src/aequitas/plot/bubble_concatenation_chart.py +++ b/src/aequitas/plot/bubble_concatenation_chart.py @@ -17,17 +17,6 @@ # complexity of the resulting vega spec. DUMMY_DF = pd.DataFrame({"a": [1, 1], "b": [0, 0]}) -metric_names = { - "Predictive Equality": "fpr_ratio", - "Equal Opportunity": "tpr_ratio", - "Demographic Parity": "pprev_ratio", - "TPR": "tpr", - "FPR": "fpr", - "FNR": "fnr", - "Accuracy": "accuracy", - "Precision": "precision", -} - def __get_chart_sizes(chart_width): """Calculates the widths of the disparity and metric charts that make-up the From 3479f4f06fc539d6d4945d17a6a88beb58ce509b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Jesus?= Date: Fri, 2 Feb 2024 11:40:45 +0000 Subject: [PATCH 6/7] Change reference to aequitas flow --- src/aequitas/flow/utils/colab.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/aequitas/flow/utils/colab.py b/src/aequitas/flow/utils/colab.py index 89b894df..9c9c0843 100644 --- a/src/aequitas/flow/utils/colab.py +++ b/src/aequitas/flow/utils/colab.py @@ -14,7 +14,7 @@ def get_examples( "methods/data_repair", ] ) -> None: - """Downloads the examples from the fairflow repository. + """Downloads the examples from the aequitas flow repository. Note that this should not be used outside Google Colab, as it clutters the directory with with the git files from Aequitas repository. @@ -22,11 +22,11 @@ def get_examples( Parameters ---------- directory : Literal["configs", "examples/data_repair", "experiment_results"] - The directory to download from the fairflow repository. + The directory to download from the aequitas flow repository. """ directory = "examples/" + directory logger = create_logger("utils.colab") - logger.info("Downloading examples from fairflow repository.") + logger.info("Downloading examples from aequitas flow repository.") # Create directory if it doesn't exist Path(directory).mkdir(parents=True, exist_ok=True) # Check if git repository already exists in folder From f717c3976aad8ed63208697b8c84b3cd32615100 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Jesus?= Date: Mon, 5 Feb 2024 09:52:25 +0000 Subject: [PATCH 7/7] Remove duplicated argument --- src/aequitas/plot/summary_chart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/aequitas/plot/summary_chart.py b/src/aequitas/plot/summary_chart.py index 227052ac..dc172d4e 100644 --- a/src/aequitas/plot/summary_chart.py +++ b/src/aequitas/plot/summary_chart.py @@ -216,7 +216,6 @@ def __draw_parity_result_text(parity_result, color_scale): "parity_result:O", scale=color_scale, legend=alt.Legend( - title=Label.TEST, title=Label.TEST, padding=Legend.margin, offset=0,