replay.py
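"""Replay a trained policy with rendering.

Loads the training configuration and the saved actor-critic checkpoint for
the lunarlander_v2 run, rebuilds a single environment with the observation
normalization statistics used during training, and renders deterministic
rollouts indefinitely.
"""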
import os
import sys
import torch
from core.envs import make_vec_envs
from core.training.training_args import TrainingArgs
from core.utils.env_utils import get_render_func, get_vec_normalize
sys.path.append('core')
config_json = f'{os.path.dirname(__file__)}/results/lunarlander_v2/config.json'
args = TrainingArgs.from_json(config_json)
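# Build a single evaluation environment on CPU. The seed is offset from the
# training seed (presumably so replay does not reuse the training episodes).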
env = make_vec_envs(
    args.env_name,
    args.random_seed + 1000,
    1,
    None,
    None,
    device='cpu',
    allow_early_resets=False
)
# get a render function
render_func = get_render_func(env)
# we need to use the same statistics for normalization as used in training
actor_critic, obs_rms = torch.load(
    os.path.join(args.checkpoint_dir, "model.pt"),
    map_location='cpu'
)
vec_norm = get_vec_normalize(env)
if vec_norm is not None:
    vec_norm.eval()
    vec_norm.obs_rms = obs_rms
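# Hidden state and episode mask for a (possibly recurrent) policy; the mask
# is zeroed at episode boundaries inside the replay loop below.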
recurrent_hidden_states = torch.zeros(1, actor_critic.recurrent_state_size)
masks = torch.zeros(1, 1)
obs = env.reset()
if render_func is not None:
    render_func('human')
while True:
    with torch.no_grad():
        value, action, _, recurrent_hidden_states = actor_critic.act(
            obs, recurrent_hidden_states, masks, deterministic=True
        )

    # Observe reward and next obs
    obs, reward, done, _ = env.step(action)
    masks.fill_(0.0 if done else 1.0)

    if render_func is not None:
        render_func('human')