-
Notifications
You must be signed in to change notification settings - Fork 1
/
advanced_rnn.py
271 lines (230 loc) · 15.4 KB
/
advanced_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
from __future__ import division
import warnings
from itertools import zip_longest
from stable_baselines.common.tf_util import batch_to_seq, seq_to_batch
from stable_baselines.common.tf_layers import linear, lstm
from stable_baselines.common.policies import RecurrentActorCriticPolicy,LstmPolicy
import numpy as np
import tensorflow as tf
# Switch TensorFlow to graph (non-eager) execution at import time: the policies
# defined below evaluate their tensors through ``self.sess.run(...)``.
tf.compat.v1.disable_eager_execution()
def ortho_init(scale=1.0):
    """
    Create an orthogonal weight initializer.

    :param scale: (float) Factor the orthogonal weights are multiplied by.
    :return: (function) an initializer that maps a shape to a float32 weight array
    """
    def _ortho_init(shape, *_, **_kwargs):
        """
        Draw a scaled orthogonal weight array of the requested shape.

        Only rank-2 and rank-4 shapes are handled. A rank-4 shape is collapsed
        to a matrix whose rows are the product of the three leading axes and
        whose columns are the last axis; the orthogonal factor is then reshaped
        back. Extra positional and keyword arguments are accepted and ignored.

        :param shape: (sequence of int) shape of the weights to create
        :return: (np.ndarray) float32 array of shape ``shape``
        :raises NotImplementedError: if the rank of ``shape`` is neither 2 nor 4
        """
        dims = tuple(shape)
        rank = len(dims)
        if rank not in (2, 4):
            raise NotImplementedError

        if rank == 2:
            matrix_shape = dims
        else:
            # Flatten every axis except the last one into the row dimension.
            matrix_shape = (np.prod(dims[:-1]), dims[-1])

        noise = np.random.normal(0.0, 1.0, matrix_shape)
        left, _singular, right = np.linalg.svd(noise, full_matrices=False)
        # With reduced SVD exactly one factor matches a non-square target;
        # for a square target both do and the left factor is taken.
        basis = left if left.shape == matrix_shape else right
        basis = basis.reshape(dims)
        return (scale * basis[:dims[0], :dims[1]]).astype(np.float32)

    return _ortho_init
def mlp_extractor(flat_observations, net_arch, act_fun):
    """
    Build the MLP feature extractor shared by (and then split between) the policy and value networks.

    ``net_arch`` is read from left to right:

    1. Every leading integer adds one shared fully connected layer with that many units
       (there may be none).
    2. The first dict encountered ends the shared part. Its optional ``pi`` and ``vf`` entries
       are lists of layer sizes for the policy-only and value-only layers; a missing key means
       no extra layers for that network. Anything after this dict is not read.

    Example: ``[55, dict(vf=[255, 255], pi=[128])]`` gives one shared layer of 55 units, two
    value-only layers of 255 units and one policy-only layer of 128 units, while ``[128, 128]``
    gives two shared layers and no separate ones.

    :param flat_observations: (tf.Tensor) The observations to base policy and value function on.
    :param net_arch: ([int or dict]) The specification of the policy and value networks.
    :param act_fun: (tf function) The activation function applied after every layer.
    :return: (tf.Tensor, tf.Tensor) latent_policy, latent_value of the specified network.
        When there are no non-shared layers both are the same tensor.
    """
    def _dense(inputs, scope, width):
        # One fully connected layer followed by the activation.
        return act_fun(linear(inputs, scope, width, init_scale=np.sqrt(2)))

    shared = flat_observations
    pi_sizes = []  # widths of the layers owned by the policy network only
    vf_sizes = []  # widths of the layers owned by the value network only

    # Shared trunk: consume ints until the first non-int entry.
    for position, spec in enumerate(net_arch):
        if isinstance(spec, int):
            shared = _dense(shared, "shared_fc{}".format(position), spec)
            continue

        assert isinstance(spec, dict), "Error: the net_arch list can only contain ints and dicts"
        if 'pi' in spec:
            assert isinstance(spec['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
            pi_sizes = spec['pi']
        if 'vf' in spec:
            assert isinstance(spec['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
            vf_sizes = spec['vf']
        break  # the network splits here; later entries are ignored

    # Separate heads, created depth by depth (policy layer first, then value layer).
    policy_latent = shared
    value_latent = shared
    for depth, (pi_width, vf_width) in enumerate(zip_longest(pi_sizes, vf_sizes)):
        if pi_width is not None:
            assert isinstance(pi_width, int), "Error: net_arch[-1]['pi'] must only contain integers."
            policy_latent = _dense(policy_latent, "pi_fc{}".format(depth), pi_width)
        if vf_width is not None:
            assert isinstance(vf_width, int), "Error: net_arch[-1]['vf'] must only contain integers."
            value_latent = _dense(value_latent, "vf_fc{}".format(depth), vf_width)

    return policy_latent, value_latent
class AuxLstmPolicy(RecurrentActorCriticPolicy):
    """
    Policy object that implements actor critic, using LSTMs, with an auxiliary prediction head.

    Besides the policy and value outputs, the model exposes ``pred_pred``: the sigmoid of a
    one-unit linear layer (scope ``'pred'``). In ``net_arch`` mode the head reads the shared
    latent at the point where the network splits into policy and value parts; in legacy mode
    it reads the LSTM output.

    :param sess: (TensorFlow session) The current TensorFlow session
    :param ob_space: (Gym Space) The observation space of the environment
    :param ac_space: (Gym Space) The action space of the environment
    :param n_env: (int) The number of environments to run
    :param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The number of batch to run (n_envs * n_steps)
    :param n_lstm: (int) The number of LSTM cells (for recurrent policies)
    :param reuse: (bool) If the policy is reusable or not
    :param layers: ([int]) The size of the Neural network before the LSTM layer (if None, default to [64, 64])
    :param net_arch: (list) Specification of the actor-critic policy network architecture. Notation similar to the
        format described in mlp_extractor but with additional support for a 'lstm' entry in the shared network part.
    :param act_fun: (tf.func) the activation function to use in the neural network.
    :param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
    :param layer_norm: (bool) Whether or not to use layer normalizing LSTMs
    :param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
    :param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
    """

    recurrent = True

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, layers=None,
                 net_arch=None, act_fun=tf.tanh, cnn_extractor=None, layer_norm=False, feature_extraction="mlp",
                 **kwargs):
        # state_shape = [n_lstm * 2] dim because of the cell and hidden states of the LSTM
        super(AuxLstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                                            state_shape=(2 * n_lstm, ), reuse=reuse,
                                            scale=(feature_extraction == "cnn"))

        self._kwargs_check(feature_extraction, kwargs)

        if net_arch is None:  # Legacy mode
            if layers is None:
                layers = [64, 64]
            else:
                warnings.warn("The layers parameter is deprecated. Use the net_arch parameter instead.")

            with tf.variable_scope("model", reuse=reuse):
                if feature_extraction == "cnn":
                    extracted_features = cnn_extractor(self.processed_obs, **kwargs)
                else:
                    extracted_features = tf.layers.flatten(self.processed_obs)
                    for i, layer_size in enumerate(layers):
                        extracted_features = act_fun(linear(extracted_features, 'pi_fc' + str(i), n_hidden=layer_size,
                                                            init_scale=np.sqrt(2)))
                input_sequence = batch_to_seq(extracted_features, self.n_env, n_steps)
                masks = batch_to_seq(self.dones_ph, self.n_env, n_steps)
                rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1', n_hidden=n_lstm,
                                             layer_norm=layer_norm)
                rnn_output = seq_to_batch(rnn_output)
                value_fn = linear(rnn_output, 'vf', 1)
                # Auxiliary head, mirroring the net_arch branch. Without it ``step`` fails with
                # AttributeError in legacy mode, which is the default (net_arch=None).
                self.pred_pred = tf.math.sigmoid(linear(rnn_output, 'pred', 1))

                self._proba_distribution, self._policy, self.q_value = \
                    self.pdtype.proba_distribution_from_latent(rnn_output, rnn_output)

            self._value_fn = value_fn
        else:  # Use the new net_arch parameter
            if layers is not None:
                warnings.warn("The new net_arch parameter overrides the deprecated layers parameter.")
            if feature_extraction == "cnn":
                raise NotImplementedError()

            with tf.variable_scope("model", reuse=reuse):
                latent = tf.layers.flatten(self.processed_obs)
                policy_only_layers = []  # Layer sizes of the network that only belongs to the policy network
                value_only_layers = []  # Layer sizes of the network that only belongs to the value network

                # Iterate through the shared layers and build the shared parts of the network
                lstm_layer_constructed = False
                for idx, layer in enumerate(net_arch):
                    if isinstance(layer, int):  # Check that this is a shared layer
                        layer_size = layer
                        latent = act_fun(linear(latent, "shared_fc{}".format(idx), layer_size, init_scale=np.sqrt(2)))
                    elif layer == "lstm":
                        if lstm_layer_constructed:
                            raise ValueError("The net_arch parameter must only contain one occurrence of 'lstm'!")
                        input_sequence = batch_to_seq(latent, self.n_env, n_steps)
                        masks = batch_to_seq(self.dones_ph, self.n_env, n_steps)
                        rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1', n_hidden=n_lstm,
                                                     layer_norm=layer_norm)
                        latent = seq_to_batch(rnn_output)
                        lstm_layer_constructed = True
                    else:
                        assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
                        if 'pi' in layer:
                            assert isinstance(layer['pi'],
                                              list), "Error: net_arch[-1]['pi'] must contain a list of integers."
                            policy_only_layers = layer['pi']

                        if 'vf' in layer:
                            assert isinstance(layer['vf'],
                                              list), "Error: net_arch[-1]['vf'] must contain a list of integers."
                            value_only_layers = layer['vf']
                        break  # From here on the network splits up in policy and value network

                # Build the non-shared part of the policy-network
                latent_policy = latent
                for idx, pi_layer_size in enumerate(policy_only_layers):
                    if pi_layer_size == "lstm":
                        raise NotImplementedError("LSTMs are only supported in the shared part of the policy network.")
                    assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                    latent_policy = act_fun(
                        linear(latent_policy, "pi_fc{}".format(idx), pi_layer_size, init_scale=np.sqrt(2)))

                # Build the non-shared part of the value-network
                latent_value = latent
                for idx, vf_layer_size in enumerate(value_only_layers):
                    if vf_layer_size == "lstm":
                        raise NotImplementedError("LSTMs are only supported in the shared part of the value function "
                                                  "network.")
                    assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
                    latent_value = act_fun(
                        linear(latent_value, "vf_fc{}".format(idx), vf_layer_size, init_scale=np.sqrt(2)))

                if not lstm_layer_constructed:
                    raise ValueError("The net_arch parameter must contain at least one occurrence of 'lstm'!")

                self._value_fn = linear(latent_value, 'vf', 1)
                # Auxiliary head on the shared latent (before the policy/value split).
                self.pred_pred = tf.math.sigmoid(linear(latent, 'pred', 1))
                # TODO: why not init_scale = 0.001 here like in the feedforward
                self._proba_distribution, self._policy, self.q_value = \
                    self.pdtype.proba_distribution_from_latent(latent_policy, latent_value)
        self._setup_init()

    def step(self, obs, state=None, mask=None, deterministic=False):
        """
        Run one forward pass of the policy.

        :param obs: the observations, fed to ``obs_ph``
        :param state: the LSTM states, fed to ``states_ph``
        :param mask: the episode-done mask, fed to ``dones_ph``
        :param deterministic: (bool) if True fetch ``deterministic_action``, otherwise fetch the
            sampled ``action``
        :return: (list) ``[(actions, pred_pred), values, new_states, neglogp]``
        """
        # The two branches used to be identical and always fetched the deterministic action, so
        # stochastic stepping never sampled. Fetch the sampled action unless asked otherwise.
        action_op = self.deterministic_action if deterministic else self.action
        feed = {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask}
        return self.sess.run([(action_op, self.pred_pred), self.value_flat, self.snew, self.neglogp], feed)

    def proba_step(self, obs, state=None, mask=None):
        """
        Evaluate ``policy_proba`` for the given observations, LSTM states and done mask.
        """
        return self.sess.run(self.policy_proba, {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})

    def value(self, obs, state=None, mask=None):
        """
        Evaluate ``value_flat`` for the given observations, LSTM states and done mask.
        """
        return self.sess.run(self.value_flat, {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})
class AuxLstmPolicyv2(LstmPolicy):
    """
    ``LstmPolicy`` variant whose ``step`` additionally fetches an auxiliary ``pred_pred`` tensor.

    The constructor forwards every argument unchanged to ``LstmPolicy``.

    NOTE(review): nothing in this class creates ``self.pred_pred``. Unless the base class or the
    caller attaches that tensor to the instance, ``step`` raises ``AttributeError``. Confirm where
    the auxiliary head is supposed to be built.
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, layers=None,
                 net_arch=None, act_fun=tf.tanh, cnn_extractor=None, layer_norm=False, feature_extraction="mlp",
                 **kwargs):
        super(AuxLstmPolicyv2, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=n_lstm,
                                              reuse=reuse, layers=layers, net_arch=net_arch, act_fun=act_fun,
                                              cnn_extractor=cnn_extractor, layer_norm=layer_norm,
                                              feature_extraction=feature_extraction, **kwargs)

    def step(self, obs, state=None, mask=None, deterministic=False):
        """
        Run one forward pass of the policy.

        :param obs: the observations, fed to ``obs_ph``
        :param state: the LSTM states, fed to ``states_ph``
        :param mask: the episode-done mask, fed to ``dones_ph``
        :param deterministic: (bool) if True fetch ``deterministic_action``, otherwise fetch the
            sampled ``action``
        :return: (list) ``[(actions, pred_pred), values, new_states, neglogp]``
        :raises AttributeError: if no ``pred_pred`` tensor is attached to the policy
        """
        if not hasattr(self, "pred_pred"):
            # Same exception type as the bare attribute access, with an actionable message.
            raise AttributeError("AuxLstmPolicyv2 has no 'pred_pred' tensor: the auxiliary prediction head "
                                 "must be attached to the policy before calling step().")
        # The two branches used to be identical and always fetched the deterministic action, so
        # stochastic stepping never sampled. Fetch the sampled action unless asked otherwise.
        action_op = self.deterministic_action if deterministic else self.action
        feed = {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask}
        return self.sess.run([(action_op, self.pred_pred), self.value_flat, self.snew, self.neglogp], feed)