-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathbrew_poison.py
executable file
·77 lines (61 loc) · 2.91 KB
/
brew_poison.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""General interface script to launch poisoning jobs."""
import torch
import datetime
import time
import forest
# Global torch configuration, applied at import time so it is in effect before
# any victim model or data loader is constructed.
_consts = forest.consts
torch.backends.cudnn.benchmark = _consts.BENCHMARK
torch.multiprocessing.set_sharing_strategy(_consts.SHARING_STRATEGY)

# Command-line options are parsed at module level, making `args` a module global.
args = forest.options().parse_args()

# Optionally switch to deterministic execution for fully reproducible runs.
if args.deterministic:
    forest.utils.set_deterministic()
def format_elapsed(seconds, *, strip_commas=False):
    """Render a duration in seconds as a ``datetime.timedelta`` string.

    Args:
        seconds: Elapsed time in seconds (int or float).
        strip_commas: If True, drop commas from the result. ``str(timedelta)``
            renders multi-day spans as e.g. ``"1 day, 2:03:04"``; the stripped
            form is what this script hands to ``forest.utils.record_results``
            (presumably so it cannot break a comma-separated results table —
            confirm against record_results).

    Returns:
        The formatted duration string.
    """
    text = str(datetime.timedelta(seconds=seconds))
    return text.replace(',', '') if strip_commas else text


if __name__ == "__main__":
    setup = forest.utils.system_startup(args)

    model = forest.Victim(args, setup=setup)
    data = forest.Kettle(args, model.defs.batch_size, model.defs.augmentations, setup=setup)
    witch = forest.Witch(args, setup=setup)

    # Durations are measured with a monotonic clock; time.time() follows the
    # wall clock, which can be adjusted mid-run and skew the reported times.
    start_time = time.perf_counter()
    if args.pretrained:
        print('Loading pretrained model...')
        stats_clean = None
    else:
        stats_clean = model.train(data, max_epoch=args.max_epoch)
    train_time = time.perf_counter()

    poison_delta = witch.brew(model, data)
    brew_time = time.perf_counter()

    if not args.pretrained and args.retrain_from_init:
        stats_rerun = model.retrain(data, poison_delta)
    else:
        # We don't know the initial seed for a pretrained model, so retraining makes no sense.
        stats_rerun = None

    stats_results = None
    if args.vruns > 0:
        if args.vnet is not None:
            # Validate the transfer model given by args.vnet: args.net is swapped
            # to args.vnet while the victim is built and validated, then restored.
            # NOTE(review): `model` is rebound here, so record_results below
            # receives this transfer victim's defs and init seed — confirm intended.
            train_net = args.net
            args.net = args.vnet
            try:
                model = forest.Victim(args, setup=setup)
                stats_results = model.validate(data, poison_delta)
            finally:
                args.net = train_net
        else:
            # Validate the main model.
            stats_results = model.validate(data, poison_delta)
    test_time = time.perf_counter()

    train_elapsed = train_time - start_time
    brew_elapsed = brew_time - train_time
    test_elapsed = test_time - brew_time

    timestamps = dict(train_time=format_elapsed(train_elapsed, strip_commas=True),
                      brew_time=format_elapsed(brew_elapsed, strip_commas=True),
                      test_time=format_elapsed(test_elapsed, strip_commas=True))
    # Save run to table
    results = (stats_clean, stats_rerun, stats_results)
    forest.utils.record_results(data, witch.stat_optimal_loss, results,
                                args, model.defs, model.model_init_seed, extra_stats=timestamps)

    # Export
    if args.save is not None:
        data.export_poison(poison_delta, path=args.poison_path, mode=args.save)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('---------------------------------------------------')
    print(f'Finished computations with train time: {format_elapsed(train_elapsed)}')
    print(f'--------------------------- brew time: {format_elapsed(brew_elapsed)}')
    print(f'--------------------------- test time: {format_elapsed(test_elapsed)}')
    print('-------------Job finished.-------------------------')