Skip to content

Commit

Permalink
Multiple small changes to the TF Privacy Report:
Browse files Browse the repository at this point in the history
     - Fix the legend to the bottom right
     - Manually set the size of the plot figure.
     - Fix a typo in the subplot title.

PiperOrigin-RevId: 337064528
  • Loading branch information
CdavM authored and tensorflower-gardener committed Oct 14, 2020
1 parent a8aa0d5 commit d1a8a6c
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ the model are used (e.g., losses, logits, predictions). Neither model internals

### Codelab

The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb).
The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelabs/codelab.ipynb).
This trains a simple image classification model and tests it against a series
of membership inference attacks.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,8 @@ def _get_attack_results_filename(attack_results: AttackResults, index: int):
"""Creates a filename for a specific set of AttackResults."""
metadata = attack_results.privacy_report_metadata
if metadata is not None:
return '%s_%s_%s.pickle' % (metadata.model_variant_label,
metadata.epoch_num, index)
return '%s_%s_epoch_%s.pickle' % (metadata.model_variant_label, index,
metadata.epoch_num)
return '%s.pickle' % index


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def main(unused_argv):
epoch_figure = privacy_report.plot_by_epochs(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
epoch_figure.show()
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model(
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
privacy_utility_figure.show()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@

def plot_by_epochs(results: AttackResultsCollection,
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant.
"""Plots privacy vulnerabilities vs epoch numbers.
In case multiple privacy metrics are specified, the plot will feature
multiple subplots (one subplot per metrics).
multiple subplots (one subplot per metrics). Multiple model variants
are supported.
Args:
results: AttackResults for the plot
privacy_metrics: List of enumerated privacy metrics that should be plotted.
Expand All @@ -54,12 +55,13 @@ def plot_by_epochs(results: AttackResultsCollection,
privacy_metrics=privacy_metrics)


def plot_privacy_vs_accuracy_single_model(
results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]):
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant.
def plot_privacy_vs_accuracy(results: AttackResultsCollection,
privacy_metrics: Iterable[PrivacyMetric]):
"""Plots privacy vulnerabilities vs accuracy plots.
In case multiple privacy metrics are specified, the plot will feature
multiple subplots (one subplot per metrics).
multiple subplots (one subplot per metrics). Multiple model variants
are supported.
Args:
results: AttackResults for the plot
privacy_metrics: List of enumerated privacy metrics that should be plotted.
Expand Down Expand Up @@ -106,7 +108,8 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
figure_title: str,
privacy_metrics: Iterable[PrivacyMetric]):
"""Create one subplot per privacy metric for a specified x_axis_metric."""
fig, axes = plt.subplots(1, len(privacy_metrics))
fig, axes = plt.subplots(
1, len(privacy_metrics), figsize=(5 * len(privacy_metrics), 5))
# Set a title for the entire group of subplots.
fig.suptitle(figure_title)
if len(privacy_metrics) == 1:
Expand All @@ -116,11 +119,12 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
for legend_label in legend_labels:
single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR]
== legend_label]
axes[i].plot(single_label_results[x_axis_metric],
single_label_results[str(privacy_metric)])
axes[i].legend(legend_labels)
sorted_label_results = single_label_results.sort_values(x_axis_metric)
axes[i].plot(sorted_label_results[x_axis_metric],
sorted_label_results[str(privacy_metric)])
axes[i].legend(legend_labels, loc='lower right')
axes[i].set_xlabel(x_axis_metric)
axes[i].set_title('%s for Entire dataset' % ENTIRE_DATASET_SLICE_STR)
axes[i].set_title('%s for %s' % (privacy_metric, ENTIRE_DATASET_SLICE_STR))

return fig

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ def test_multiple_metrics_plot_by_epochs_multiple_models(self):
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
# Raise error if metadata is missing
self.assertRaises(
ValueError, privacy_report.plot_privacy_vs_accuracy_single_model,
ValueError, privacy_report.plot_privacy_vs_accuracy,
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])

def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
fig = privacy_report.plot_privacy_vs_accuracy(
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC'])
# extract data from figure.
Expand All @@ -158,7 +158,7 @@ def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')

def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
fig = privacy_report.plot_privacy_vs_accuracy(
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC', 'Attacker advantage'])
# extract data from figure.
Expand All @@ -174,7 +174,7 @@ def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')

def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
fig = privacy_report.plot_privacy_vs_accuracy(
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
self.results_epoch_15_model_2)),
['AUC', 'Attacker advantage'])
Expand Down

0 comments on commit d1a8a6c

Please sign in to comment.