Skip to content
This repository has been archived by the owner on Oct 3, 2024. It is now read-only.

Commit

Permalink
update plots
Browse files Browse the repository at this point in the history
  • Loading branch information
nkrusch committed Jan 14, 2024
1 parent b9bfa23 commit 08c66f7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions plot/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,12 @@ def plot_acc(input_data, plot_name, data_labels,

class BarData(ResultData):

def get_acc_data(self, key_test):
nums = [BarData.fmt(r) for r in self.raw_rata if key_test(r)]
def get_acc_data(
self, key_test, key_label=None, att_label=None,
name_label=None):
nums = [BarData.fmt(r, key_label, att_label, name_label)
for r in self.raw_rata if
key_test(r)]
means = np.rint(np.mean(np.array(
[v for _, v in nums]), axis=0)).tolist()
cats = sorted(list(set([x for ((x, _, _), _) in nums])))
Expand All @@ -191,9 +195,10 @@ def fn_pattern(self, file_ext, pattern, out_dir=None, in_dirs=None):
return path.join(out_dir, f'{file_name}.{file_ext}')

@staticmethod
def fmt(r, key=None, att=None):
keys = (key or BarData.cls(r), BarData.name(r),
att or BarData.attack(r))
def fmt(r, key_label=None, att=None, cls=None):
keys = (key_label(r) if key_label else BarData.cls(r),
cls(r) if cls else BarData.name(r),
att(r) if att else BarData.attack(r))
valid = round(BarData.valid(r))
evades = round(BarData.evades(r)) - valid
accurate = round(BarData.acc(r)) - evades - valid
Expand All @@ -215,7 +220,7 @@ def match_bdata(x, y):
assert pair is not None


def attack_plot(bdata, out_dir, plot_name, dirs=None):
def attack_plot(bdata, out_dir, plot_name, dirs=None, comparison=False):
labels = ['valid', 'evasive', 'accurate', 'inaccurate']
key_test = lambda r: ResultData.attack(r) != 'CPGD'
bar_inputs = [d.get_acc_data(key_test) for d in bdata]
Expand All @@ -231,15 +236,19 @@ def attack_plot(bdata, out_dir, plot_name, dirs=None):
print("\n".join([f"{l:<13}: {x:.2f}" for (l, x) in
zip(labels, bar_inputs[-1][1])]))
print("=" * 40)
else:
key_test = lambda r: ResultData.attack(r) == 'CPGD'
bar_inputs = [d.get_acc_data(key_test) for d in bdata]
for b in bar_inputs[1:]:
match_bdata(bar_inputs[0], b)
if comparison:
args = lambda x: \
{'key_test': lambda r: ResultData.attack(r) == x,
'key_label': lambda _: x,
'att_label': lambda r: BarData.name(r),
'name_label': lambda _: ' '}
bar_inputs = [
bdata[0].get_acc_data(**args('CPGD')),
bdata[0].get_acc_data(**args('VPGD'))]
name = bdata[0].plot_name(
plot_name + '_cpgd', out_dir, dirs=dirs)
plot_acc(bar_inputs, overall_bar=True, data_labels=labels,
plot_name=name, sort_key=(lambda x: (x[0][0], x[0][1])))
plot_name=name, sort_key=(lambda x: x[0][1]))


def perf_plot(bdata, out_dir, plot_name):
Expand All @@ -264,4 +273,5 @@ def plot_bars(data_dir, out_dir=None):
else:
attack_plot(
bdata, out_dir, 'bar_acc',
dirs='_'.join(dirs) if len(dirs) > 1 else None)
dirs='_'.join(dirs) if len(dirs) > 1 else None,
comparison='attacks' in data_dir and len(dirs) == 1)
Binary file not shown.

0 comments on commit 08c66f7

Please sign in to comment.