-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhypersearch.py
30 lines (27 loc) · 924 Bytes
/
hypersearch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from hyperparams import HyperParams, hyper_search, make_hyper_range
from ppo import PPO
import torch.multiprocessing as mp
if __name__ == "__main__":
    # Script entry point: run a hyperparameter search with a PPO trainer,
    # writing results to "<exp_name>_searchlog.txt".

    # Use the 'forkserver' start method for torch.multiprocessing.
    mp.set_start_method('forkserver')
    ppo_trainer = PPO()

    # Candidate values handed to hyper_search, keyed by hyperparameter name.
    hyp_ranges = {
        "lr": [9.5e-5, 1e-4, 1.5e-4],
        "val_coef": [.005, .0075, .01],
    }
    keys = list(hyp_ranges.keys())

    # Fixed settings passed to HyperParams for this experiment.
    hyps = {
        'lambda_': .93,
        'gamma': .985,
        'entr_coef': .008,
        'env_type': "Breakout-v0",
        'exp_name': "brkout2",
        'n_tsteps': 256,
        'n_rollouts': 11,
        'n_envs': 11,
        'max_tsteps': 5000000,
        'n_frame_stack': 3,
    }

    # The context manager guarantees the log is flushed and closed even if
    # HyperParams or hyper_search raises part-way through the search.
    with open(hyps['exp_name'] + "_searchlog.txt", 'w') as search_log:
        hyper_params = HyperParams(hyps)
        # NOTE(review): the literal 0 is presumably the starting index into
        # `keys` -- confirm against hyper_search's definition.
        hyper_search(hyper_params.hyps, hyp_ranges, keys, 0, ppo_trainer,
                     search_log)