-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainPPO.py
79 lines (65 loc) · 1.8 KB
/
trainPPO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gym
import time
import yaml
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback
from scripts.network import UAVPolicyCNN
# Load train environment configs
with open('scripts/env_config.yml', 'r') as f:
env_config = yaml.safe_load(f)
# Load inference configs
with open('config.yml', 'r') as f:
config = yaml.safe_load(f)
# Determine input image shape
image_shape = (50,50,1) if config["train_mode"]=="depth" else (50,50,3)
# Create a DummyVecEnv
env = DummyVecEnv([lambda: Monitor(
gym.make(
"scripts:airsim-env-v0",
ip_address="127.0.0.1",
image_shape=image_shape,
env_config=env_config["TrainEnv"],
input_mode=config["train_mode"]
)
)])
# Wrap env as VecTransposeImage (Channel last to channel first)
env = VecTransposeImage(env)
policy_kwargs = dict(
features_extractor_class=UAVPolicyCNN
)
model = PPO(
'CnnPolicy',
env,
# learning_rate=0.0001
batch_size=128,
clip_range=0.10,
max_grad_norm=0.5,
verbose=1,
seed=1,
device="cuda",
tensorboard_log="./tb_logs/",
policy_kwargs=policy_kwargs,
)
# model = PPO.load(path="best_model.zip", env=env)
# Evaluation callback
callbacks = []
eval_callback = EvalCallback(
env,
callback_on_new_best=None,
n_eval_episodes=20,
best_model_save_path="saved_policy",
log_path=".",
eval_freq=1024,
)
callbacks.append(eval_callback)
kwargs = {}
kwargs["callback"] = callbacks
log_name = "ppo_run_" + str(time.time())
model.learn(
total_timesteps=100000, #350000 takes 3 hours
tb_log_name=log_name,
**kwargs
)