# test_policy.py
# (Web-page extraction residue removed here: repository navigation text and
# the line-number gutter 1-77 that preceded the actual source.)
import gym
import numpy as np
import tensorflow as tf
import argparse
from network_models.policy_net import Policy_net
def argparser():
    """Parse the command-line options of the policy test script.

    Options:
        --modeldir: directory of model (default 'trained_models').
        --alg: algorithm name, one of gail, ppo, bc (default 'gail').
        --model: checkpoint number to test; '' selects the un-numbered
            'model.ckpt'.
        --logdir: log directory (default 'log/test').
        --iteration: upper bound on the number of test episodes
            (default 1000). Parsed as int so that a value given on the
            command line can be passed straight to range().
        --stochastic: a store_false flag, i.e. the value defaults to True
            and *passing* the flag sets it to False.

    Returns:
        argparse.Namespace with the attributes listed above.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--modeldir', help='directory of model', default='trained_models')
    parser.add_argument('--alg', help='chose algorithm one of gail, ppo, bc', default='gail')
    parser.add_argument('--model', help='number of model to test. model.ckpt-number', default='')
    parser.add_argument('--logdir', help='log directory', default='log/test')
    # Without type=int a user-supplied value would stay a str and break
    # range(args.iteration) in main().
    parser.add_argument('--iteration', default=int(1e3), type=int)
    parser.add_argument('--stochastic', action='store_false')
    return parser.parse_args()
def _run_episode(env, policy, obs, stochastic):
    """Roll `policy` forward from `obs` until env.step reports done.

    Args:
        env: environment exposing step(action) -> (obs, reward, done, info).
        policy: object exposing act(obs=..., stochastic=...) whose first
            return value is the action (a one-element array is expected).
        obs: the observation returned by the preceding env.reset().
        stochastic: forwarded unchanged to policy.act.

    Returns:
        (steps, total_reward): number of env.step calls made and the sum of
        the rewards those calls returned.
    """
    steps = 0
    total_reward = 0
    done = False
    while not done:
        steps += 1
        # Add a leading batch axis and cast to float32 before feeding the policy.
        batch = np.stack([obs]).astype(dtype=np.float32)
        act, _ = policy.act(obs=batch, stochastic=stochastic)
        # np.asscalar was removed from NumPy; .item() is the equivalent.
        act = np.asarray(act).item()
        obs, reward, done, _info = env.step(act)
        # Count the reward of the step that was just taken, so the total
        # covers exactly this episode (nothing carried over, nothing dropped).
        total_reward += reward
    return steps, total_reward


def main(args):
    """Restore a saved policy and evaluate it on CartPole-v0.

    Builds the policy network, restores
    '<modeldir>/<alg>/model.ckpt[-<model>]', then runs up to
    args.iteration episodes. For every episode the length and the summed
    reward are written as TensorBoard summaries under '<logdir>/<alg>'.
    The run stops early, printing 'Clear!!', once 100 consecutive episodes
    each reach a summed reward of at least 195.

    Args:
        args: namespace with modeldir, alg, model, logdir, iteration and
            stochastic attributes (see argparser()).
    """
    env = gym.make('CartPole-v0')
    env.seed(0)
    Policy = Policy_net('policy', env)
    saver = tf.train.Saver()
    # Tolerate a string value (e.g. parsed from the command line without a type).
    max_episodes = int(args.iteration)

    # NOTE(review): Policy.act is called inside the session block; presumably
    # it relies on the default session - confirm against Policy_net.
    with tf.Session() as sess:
        writer = tf.summary.FileWriter(args.logdir+'/'+args.alg, sess.graph)
        try:
            sess.run(tf.global_variables_initializer())
            checkpoint = args.modeldir+'/'+args.alg+'/'+'model.ckpt'
            if args.model != '':
                checkpoint += '-'+args.model
            saver.restore(sess, checkpoint)

            success_num = 0
            for iteration in range(max_episodes):
                obs = env.reset()
                run_policy_steps, episode_reward = _run_episode(
                    env, Policy, obs, args.stochastic)

                writer.add_summary(
                    tf.Summary(value=[tf.Summary.Value(tag='episode_length',
                                                       simple_value=run_policy_steps)]),
                    iteration)
                writer.add_summary(
                    tf.Summary(value=[tf.Summary.Value(tag='episode_reward',
                                                       simple_value=episode_reward)]),
                    iteration)

                # end condition of test: 100 successful episodes in a row
                if episode_reward >= 195:
                    success_num += 1
                    if success_num >= 100:
                        print('Iteration: ', iteration)
                        print('Clear!!')
                        break
                else:
                    success_num = 0
        finally:
            # Flush/close the event file and release the environment even
            # when restoring or stepping raises.
            writer.close()
            env.close()
if __name__ == '__main__':
    # Script entry point: parse the CLI options and hand them to main().
    main(argparser())