-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
100 lines (83 loc) · 2.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import json
import os
import pprint
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
import utilities
import utilities.utils as utils
import argparse
import torch.multiprocessing
# Use the 'file_system' strategy for sharing tensors between processes
# (e.g. DataLoader worker processes).
# NOTE(review): presumably chosen to avoid "too many open files" errors that
# the default 'file_descriptor' strategy can hit — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
def seed_everything(seed):
    """Seed the NumPy, PyTorch and Python ``random`` global generators.

    Args:
        seed: Value passed unchanged to every generator's seeding function.
    """
    print("Setting seed to {}".format(seed))
    # Same order as before: NumPy, then PyTorch, then the stdlib module.
    seeders = (np.random.seed, torch.manual_seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
def store_experiment_id(configs):
    """Write the wandb run id to ``<checkpoint_path>/id.json``.

    The file contains ``{"run_id": configs["wandb_run_id"]}``.

    Args:
        configs: Dict providing ``"checkpoint_path"`` (a directory that must
            already exist) and ``"wandb_run_id"``.
    """
    # Build the payload before opening the file so that a missing
    # "wandb_run_id" raises KeyError without leaving an empty id.json behind.
    payload = {"run_id": configs["wandb_run_id"]}
    id_path = os.path.join(configs["checkpoint_path"], "id.json")
    # The context manager guarantees the file is flushed and closed; previously
    # a bare open() handle was passed to json.dump and never closed.
    with open(id_path, "w") as id_file:
        json.dump(payload, id_file)
def init_wandb(configs):
    """Start (or resume) a wandb run and record its identity in ``configs``.

    If ``configs["wandb_id_resume"]`` is set, that id is reused with
    ``resume="allow"``; otherwise a fresh id is generated. The run name and id
    are written back into ``configs``. A checkpoint path is created when one is
    not already present, and the run id is then stored via
    ``store_experiment_id``.

    Args:
        configs: Experiment configuration dict. Reads ``"wandb_project"``,
            ``"wandb_entity"`` and (optionally) ``"wandb_id_resume"``.

    Returns:
        The same ``configs`` dict, mutated in place.
    """
    # `run_id` rather than `id` to avoid shadowing the builtin.
    run_id = configs.get("wandb_id_resume")
    if run_id is None:
        run_id = wandb.util.generate_id()

    wandb.init(
        project=configs["wandb_project"],
        entity=configs["wandb_entity"],
        config=configs,
        id=run_id,
        resume="allow",
    )
    configs["wandb_run_name"] = wandb.run.name
    configs["wandb_run_id"] = run_id

    if "checkpoint_path" not in configs:
        configs["checkpoint_path"] = utils.create_checkpoint_path(configs)
    store_experiment_id(configs)
    return configs
def init_offline_experiment(configs):
    """Prepare ``configs`` for a run that does not log to wandb.

    Sets placeholder run metadata (name ``"offline"``, id ``"None"``) and, if
    no ``"checkpoint_path"`` is present, fills one in from
    ``utils.create_checkpoint_path``.

    Returns:
        The same ``configs`` dict, mutated in place.
    """
    # The placeholders go in first because create_checkpoint_path receives the
    # whole configs dict and may read them.
    configs.update({"wandb_run_name": "offline", "wandb_run_id": "None"})
    if "checkpoint_path" in configs:
        return configs
    configs["checkpoint_path"] = utils.create_checkpoint_path(configs)
    return configs
if __name__ == "__main__":
    # All flags are handed to utils.load_configs via `args`;
    # --wandb_id_resume is used by init_wandb to resume an existing run.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=None)
    parser.add_argument("--training_config", default=None)
    parser.add_argument("--dataset_config", default=None)
    parser.add_argument("--method_config", default=None)
    parser.add_argument("--augmentation_config", default=None)
    parser.add_argument("--wandb_id_resume", default=None)
    # default=None (previously 0): with a non-None default the override below
    # always fired and silently replaced any seed set in the config files.
    parser.add_argument("--seed", default=None, type=int)
    args = parser.parse_args()

    # Setup configurations
    configs = utils.load_configs(args)
    pprint.pprint(configs)
    if args.seed is not None:
        # An explicit --seed takes precedence over the configs.
        configs["seed"] = args.seed
    elif configs.get("seed") is None:
        # No seed from the CLI or the configs: keep the old default of 0.
        configs["seed"] = 0
    seed_everything(configs["seed"])

    # Setup wandb (online logging is skipped for distributed runs)
    if configs["wandb"] and not configs["distributed"]:
        configs = init_wandb(configs)
    else:
        configs = init_offline_experiment(configs)

    trainer, tester = utils.create_procedures(configs)
    if configs["phase"] == "train":
        trainer(configs)

    if tester is not None:
        # Only the third loader returned by create_dataloaders is evaluated
        # here (presumably the test split — confirm in utils).
        _, _, loader = utils.create_dataloaders(configs)
        if isinstance(loader, list):
            # One loader per dataset, paired by position with the
            # comma-separated "dataset_names" entry.
            dataset_names = configs["dataset_names"].split(",")
            if len(dataset_names) < len(loader):
                # Fail before running any tester rather than mid-way through.
                raise ValueError(
                    "configs['dataset_names'] lists {} names but "
                    "create_dataloaders returned {} loaders".format(
                        len(dataset_names), len(loader)
                    )
                )
            for sub_loader, dataset_name in zip(loader, dataset_names):
                tester(configs, loader=sub_loader, phase="test", dataset_name=dataset_name)
        else:
            tester(configs, loader=loader, phase="test")