-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrunexpwb.py
88 lines (69 loc) · 2.12 KB
/
runexpwb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
'''
Run an experiment with wandb logging.

Usage:
    python runexpwb.py --setting tfbind8

Available settings: tfbind8, qm9str, sehstr, rna.

Note: wandb is not compatible with running experiment scripts as modules
from subdirectories (e.g., python -m exps.chess.chessgfn), so wandb.init
is called here instead.
'''
import random
import torch
import wandb
import options
import numpy as np
from attrdict import AttrDict
from exps.tfbind8 import tfbind8_oracle
from exps.qm9str import qm9str
from exps.sehstr import sehstr
from exps.rna import rna
# Dispatch table: maps each ``--setting`` value to a callable that runs the
# corresponding experiment module's ``main`` with the parsed arguments.
# Each lambda looks up ``<module>.main`` only when it is called.
setting_calls = dict(
    tfbind8=lambda args: tfbind8_oracle.main(args),
    qm9str=lambda args: qm9str.main(args),
    sehstr=lambda args: sehstr.main(args),
    rna=lambda args: rna.main(args),
)
def main(args):
    """Run the experiment selected by ``args.setting``.

    Args:
        args: parsed run options. ``args.setting`` must be a key of the
            module-level ``setting_calls`` table; the whole object is
            forwarded unchanged to the chosen experiment entry point.

    Raises:
        KeyError: if ``args.setting`` is not a registered setting. The
            message lists the valid choices (same exception type as a plain
            dict lookup, so existing callers are unaffected).
    """
    print(f'Using {args.setting} ...')
    try:
        exp_f = setting_calls[args.setting]
    except KeyError:
        # A bare KeyError('foo') gives no hint of what is accepted.
        valid = ', '.join(setting_calls)
        raise KeyError(
            f"Unknown setting {args.setting!r}; expected one of: {valid}"
        ) from None
    exp_f(args)
def set_seed(seed=0):
    """Seed the Python, NumPy and PyTorch random number generators.

    PyTorch is seeded via both ``torch.manual_seed`` and
    ``torch.cuda.manual_seed``.

    Args:
        seed: value passed to every generator's seeding function
            (default 0).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
def wandb_project_name(args):
    """Return the wandb project name to log this run under.

    For the ``rna`` setting the RNA sequence length and task are appended to
    ``args.wandb_project`` (``<project>-L<length>-<task>``); every other
    setting uses ``args.wandb_project`` unchanged.
    """
    if args.setting == "rna":
        return f"{args.wandb_project}-L{args.rna_length}-{args.rna_task}"
    return args.wandb_project


def build_run_name(args):
    """Return the upper-cased run name encoding the key hyperparameters.

    The name starts with ``args.model`` (followed directly by ``args.lamda``
    for ``subtb``). It then gets ``_``-separated suffixes for the optional
    flags that are set, and always ends with ``k``, ``i``, ``beta`` and
    ``seed``.

    NOTE(review): the indentation of this script was lost, so each suffix
    below is reconstructed as an independent condition. Confirm whether the
    ``deterministic`` / ``k`` / ``i`` suffixes were meant to apply only when
    ``args.ls`` is set.
    """
    head = args.model
    if args.model == 'subtb':
        head += f"{args.lamda}"
    parts = [head]
    if args.offline_select == "prt":
        parts.append(args.offline_select)
    if args.sa_or_ssr == "ssr":
        parts.append(args.sa_or_ssr)
    if args.ls:
        parts.append("ls")
    if args.deterministic:
        parts.append("deterministic")
    parts += [f"k{args.k}", f"i{args.i}", f"beta{args.beta}", f"seed{args.seed}"]
    return "_".join(parts).upper()


if __name__ == '__main__':
    args = options.parse_args()
    set_seed(args.seed)
    if args.setting == "rna":
        # RNA models are saved in a per-(length, task) subdirectory.
        args.saved_models_dir = f"{args.saved_models_dir}/L{args.rna_length}_RNA{args.rna_task}/"
    # wandb.init is called here rather than inside the experiment module
    # (see the module docstring). A single call is used; only the project
    # name differs between settings.
    wandb.init(project=wandb_project_name(args),
               entity=args.wandb_entity,
               config=args,
               mode=args.wandb_mode)
    # Rebuild args from wandb.config (presumably so values injected by wandb,
    # e.g. sweep overrides, take effect — confirm).
    args = AttrDict(wandb.config)
    args.run_name = build_run_name(args)
    print(f"Save model into {args.run_name}")
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    main(args)