[rllib] Improve accessing model state docs (#5656)
* [rllib] better model docs

* fix

* s
ericl authored Sep 9, 2019
1 parent 87adb5a commit 74abeab
Showing 5 changed files with 125 additions and 32 deletions.
91 changes: 87 additions & 4 deletions doc/source/rllib-training.rst
@@ -207,21 +207,98 @@ Accessing Model State

Similar to accessing policy state, you may want to get a reference to the underlying neural network model being trained. For example, you may want to pre-train it separately, or otherwise update its weights outside of RLlib. This can be done by accessing the ``model`` of the policy:
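
For the "update its weights outside of RLlib" case mentioned above, here is a minimal added sketch (not part of the original page) that round-trips a policy's weights through ``get_weights``/``set_weights``; the CartPole trainer is only an illustrative choice:

.. code-block:: python

    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 0})
    policy = trainer.get_policy()

    # get_weights() returns the current variable values as numpy arrays;
    # set_weights() pushes (possibly modified) values back into the policy.
    weights = policy.get_weights()
    policy.set_weights(weights)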

**Example: Preprocessing observations for feeding into a model**

.. code-block:: python

>>> import gym
>>> env = gym.make("Pong-v0")
# RLlib uses preprocessors to implement transforms such as one-hot encoding
# and flattening of tuple and dict observations.
>>> from ray.rllib.models.preprocessors import get_preprocessor
>>> prep = get_preprocessor(env.observation_space)(env.observation_space)
<ray.rllib.models.preprocessors.GenericPixelPreprocessor object at 0x7fc4d049de80>
# Observations should be preprocessed prior to feeding into a model
>>> env.reset().shape
(210, 160, 3)
>>> prep.transform(env.reset()).shape
(84, 84, 3)

**Example: Querying a policy's action distribution**

.. code-block:: python

# Get a reference to the policy
>>> import numpy as np
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> trainer = PPOTrainer(env="CartPole-v0", config={"eager": True, "num_workers": 0})
>>> policy = trainer.get_policy()
<ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>
# Run a forward pass to get model output logits. Note that complex observations
# must be preprocessed as in the above code block.
>>> logits, _ = policy.model.from_batch({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
(<tf.Tensor: id=1274, shape=(1, 2), dtype=float32, numpy=...>, [])
# Compute action distribution given logits
>>> policy.dist_class
<class 'ray.rllib.models.tf.tf_action_dist.Categorical'>
>>> dist = policy.dist_class(logits, policy.model)
<ray.rllib.models.tf.tf_action_dist.Categorical object at 0x7fd02301d710>
# Query the distribution for samples, sample logps
>>> dist.sample()
<tf.Tensor: id=661, shape=(1,), dtype=int64, numpy=...>
>>> dist.logp([1])
<tf.Tensor: id=1298, shape=(1,), dtype=float32, numpy=...>
# Get the estimated values for the most recent forward pass
>>> policy.model.value_function()
<tf.Tensor: id=670, shape=(1,), dtype=float32, numpy=...>
>>> policy.model.base_model.summary()
Model: "model"
_____________________________________________________________________
Layer (type) Output Shape Param # Connected to
=====================================================================
observations (InputLayer) [(None, 4)] 0
_____________________________________________________________________
fc_1 (Dense) (None, 256) 1280 observations[0][0]
_____________________________________________________________________
fc_value_1 (Dense) (None, 256) 1280 observations[0][0]
_____________________________________________________________________
fc_2 (Dense) (None, 256) 65792 fc_1[0][0]
_____________________________________________________________________
fc_value_2 (Dense) (None, 256) 65792 fc_value_1[0][0]
_____________________________________________________________________
fc_out (Dense) (None, 2) 514 fc_2[0][0]
_____________________________________________________________________
value_out (Dense) (None, 1) 257 fc_value_2[0][0]
=====================================================================
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0
_____________________________________________________________________
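
The ``logits`` computed above can also be turned into explicit action probabilities. This is an added sketch (not in the original example), reusing the eager-mode tensor returned by ``policy.model.from_batch()``:

.. code-block:: python

    import tensorflow as tf

    # Softmax over the two CartPole actions; works directly on the eager
    # logits tensor produced by policy.model.from_batch() above.
    probs = tf.nn.softmax(logits)
    print(probs.numpy())  # e.g. array([[0.52, 0.48]], dtype=float32)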

**Example: Getting Q values from a DQN model**

.. code-block:: python

# Get a reference to the model through the policy
>>> import numpy as np
>>> from ray.rllib.agents.dqn import DQNTrainer
>>> trainer = DQNTrainer(env="CartPole-v0")
>>> trainer = DQNTrainer(env="CartPole-v0", config={"eager": True})
>>> model = trainer.get_policy().model
<ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>
# List of all model variables
>>> model.variables()
[<tf.Variable 'default_policy/fc_1/kernel:0' shape=(4, 256) dtype=float32>, ...]
- # Run a forward pass to get logits, can run with policy.get_session()
- >>> model.from_batch({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
- (<tf.Tensor 'model_3/fc_out/Tanh:0' shape=(1, 256) dtype=float32>, [])
+ # Run a forward pass to get base model output. Note that complex observations
+ # must be preprocessed. An example of preprocessing is examples/saving_experiences.py
+ >>> model_out = model.from_batch({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
+ (<tf.Tensor: id=832, shape=(1, 256), dtype=float32, numpy=...)
# Access the base Keras models (all default models have a base)
>>> model.base_model.summary()
@@ -243,6 +320,9 @@ Similar to accessing policy state, you may want to get a reference to the underl
______________________________________________________________________________
# Access the Q value model (specific to DQN)
>>> model.get_q_value_distributions(model_out)
[<tf.Tensor: id=891, shape=(1, 2)>, <tf.Tensor: id=896, shape=(1, 2, 1)>]
>>> model.q_value_head.summary()
Model: "model_1"
_________________________________________________________________
@@ -258,6 +338,9 @@ Similar to accessing policy state, you may want to get a reference to the underl
_________________________________________________________________
# Access the state value model (specific to DQN)
>>> model.get_state_value(model_out)
<tf.Tensor: id=913, shape=(1, 1), dtype=float32>
>>> model.state_value_head.summary()
Model: "model_2"
_________________________________________________________________
8 changes: 6 additions & 2 deletions rllib/models/preprocessors.py
@@ -29,8 +29,12 @@ class Preprocessor(object):
def __init__(self, obs_space, options=None):
legacy_patch_shapes(obs_space)
self._obs_space = obs_space
- self._options = options or {}
- self.shape = self._init_shape(obs_space, options)
+ if not options:
+     from ray.rllib.models.catalog import MODEL_DEFAULTS
+     self._options = MODEL_DEFAULTS.copy()
+ else:
+     self._options = options
+ self.shape = self._init_shape(obs_space, self._options)
self._size = int(np.product(self.shape))
self._i = 0
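
As an illustration of this change (an added sketch, not part of the commit): constructing a preprocessor without ``options`` now falls back to ``MODEL_DEFAULTS`` instead of an empty dict, so the model config seen by ``_init_shape`` is always fully populated:

.. code-block:: python

    import gym
    from ray.rllib.models.preprocessors import get_preprocessor

    env = gym.make("CartPole-v0")
    # No options passed: the preprocessor now copies MODEL_DEFAULTS internally.
    prep = get_preprocessor(env.observation_space)(env.observation_space)
    print(prep.transform(env.reset()).shape)  # e.g. (4,) for CartPole's Box space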

19 changes: 11 additions & 8 deletions rllib/policy/dynamic_tf_policy.py
@@ -31,6 +31,13 @@ class DynamicTFPolicy(TFPolicy):
placeholders.
Initialization defines the static graph.
Attributes:
observation_space (gym.Space): observation space of the policy.
action_space (gym.Space): action space of the policy.
config (dict): config of the policy
model (ModelV2): TF model instance
dist_class (type): TF action distribution class
"""

def __init__(self,
@@ -78,10 +85,6 @@ def __init__(self,
the divisibility requirement for sample batches
obs_include_prev_action_reward (bool): whether to include the
previous action and reward in the model input
- Attributes:
-     config: config of the policy
-     model: model instance, if any
"""
self.config = config
self._loss_fn = loss_fn
@@ -122,9 +125,9 @@ def __init__(self,
if not make_model:
raise ValueError(
"make_model is required if action_sampler_fn is given")
- self._dist_class = None
+ self.dist_class = None
else:
- self._dist_class, logit_dim = ModelCatalog.get_action_dist(
+ self.dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])

if existing_model:
@@ -161,7 +164,7 @@ def __init__(self,
self, self.model, self._input_dict, obs_space, action_space,
config)
else:
- action_dist = self._dist_class(model_out, self.model)
+ action_dist = self.dist_class(model_out, self.model)
action_sampler = action_dist.sample()
action_logp = action_dist.sampled_action_logp()

@@ -346,7 +349,7 @@ def fake_array(tensor):
self._sess.run(tf.global_variables_initializer())

def _do_loss_init(self, train_batch):
- loss = self._loss_fn(self, self.model, self._dist_class, train_batch)
+ loss = self._loss_fn(self, self.model, self.dist_class, train_batch)
if self._stats_fn:
self._stats_fetches.update(self._stats_fn(self, train_batch))
# override the update ops to be those of the model
12 changes: 6 additions & 6 deletions rllib/policy/eager_tf_policy.py
@@ -78,9 +78,9 @@ def __init__(self, observation_space, action_space, config):
if not make_model:
raise ValueError(
"make_model is required if action_sampler_fn is given")
- self._dist_class = None
+ self.dist_class = None
else:
- self._dist_class, logit_dim = ModelCatalog.get_action_dist(
+ self.dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])

if make_model:
@@ -176,8 +176,8 @@ def compute_actions(self,
model_out, state_out = self.model(
self._input_dict, state_batches, self._seq_lens)

- if self._dist_class:
-     action_dist = self._dist_class(model_out, self.model)
+ if self.dist_class:
+     action_dist = self.dist_class(model_out, self.model)
action = action_dist.sample().numpy()
logp = action_dist.sampled_action_logp()
else:
@@ -252,7 +252,7 @@ def _compute_gradients(self, samples):
self._state_in = []
model_out, _ = self.model(samples, self._state_in,
self._seq_lens)
- loss = loss_fn(self, self.model, self._dist_class, samples)
+ loss = loss_fn(self, self.model, self.dist_class, samples)

variables = self.model.trainable_variables()

@@ -369,7 +369,7 @@ def tile_to(tensor, n):
for k, v in postprocessed_batch.items()
}

- loss_fn(self, self.model, self._dist_class, postprocessed_batch)
+ loss_fn(self, self.model, self.dist_class, postprocessed_batch)
if stats_fn:
stats_fn(self, postprocessed_batch)

27 changes: 15 additions & 12 deletions rllib/policy/torch_policy.py
@@ -27,6 +27,9 @@ class TorchPolicy(Policy):
action_space (gym.Space): action space of the policy.
lock (Lock): Lock that must be held around PyTorch ops on this graph.
This is necessary when using the async sampler.
config (dict): config of the policy
model (TorchModel): Torch model instance
dist_class (type): Torch action distribution class
"""

def __init__(self, observation_space, action_space, model, loss,
@@ -53,10 +56,10 @@ def __init__(self, observation_space, action_space, model, loss,
self.device = (torch.device("cuda")
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
else torch.device("cpu"))
- self._model = model.to(self.device)
+ self.model = model.to(self.device)
self._loss = loss
self._optimizer = self.optimizer()
- self._action_dist_class = action_distribution_class
+ self.dist_class = action_distribution_class

@override(Policy)
def compute_actions(self,
@@ -76,14 +79,14 @@ def compute_actions(self,
input_dict["prev_actions"] = prev_action_batch
if prev_reward_batch:
input_dict["prev_rewards"] = prev_reward_batch
- model_out = self._model(input_dict, state_batches, [1])
+ model_out = self.model(input_dict, state_batches, [1])
logits, state = model_out
- action_dist = self._action_dist_class(logits, self._model)
+ action_dist = self.dist_class(logits, self.model)
actions = action_dist.sample()
return (actions.cpu().numpy(),
[h.cpu().numpy() for h in state],
self.extra_action_out(input_dict, state_batches,
- self._model))
+ self.model))

@override(Policy)
def learn_on_batch(self, postprocessed_batch):
@@ -117,7 +120,7 @@ def compute_gradients(self, postprocessed_batch):
# Note that return values are just references;
# calling zero_grad will modify the values
grads = []
- for p in self._model.parameters():
+ for p in self.model.parameters():
if p.grad is not None:
grads.append(p.grad.data.cpu().numpy())
else:
@@ -130,24 +133,24 @@ def compute_gradients(self, postprocessed_batch):
@override(Policy)
def apply_gradients(self, gradients):
with self.lock:
- for g, p in zip(gradients, self._model.parameters()):
+ for g, p in zip(gradients, self.model.parameters()):
if g is not None:
p.grad = torch.from_numpy(g).to(self.device)
self._optimizer.step()

@override(Policy)
def get_weights(self):
with self.lock:
- return {k: v.cpu() for k, v in self._model.state_dict().items()}
+ return {k: v.cpu() for k, v in self.model.state_dict().items()}

@override(Policy)
def set_weights(self, weights):
with self.lock:
- self._model.load_state_dict(weights)
+ self.model.load_state_dict(weights)

@override(Policy)
def get_initial_state(self):
- return [s.numpy() for s in self._model.get_initial_state()]
+ return [s.numpy() for s in self.model.get_initial_state()]

def extra_grad_process(self):
"""Allow subclass to do extra processing on gradients and
@@ -172,9 +175,9 @@ def optimizer(self):
"""Custom PyTorch optimizer to use."""
if hasattr(self, "config"):
return torch.optim.Adam(
self._model.parameters(), lr=self.config["lr"])
self.model.parameters(), lr=self.config["lr"])
else:
- return torch.optim.Adam(self._model.parameters())
+ return torch.optim.Adam(self.model.parameters())

def _lazy_tensor_dict(self, postprocessed_batch):
train_batch = UsageTrackingDict(postprocessed_batch)
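
Together with the TF policy changes above, ``model`` and ``dist_class`` are now public attributes on Torch policies as well. A minimal added sketch (assuming ``trainer`` is any RLlib trainer backed by a ``TorchPolicy`` and a CartPole-sized observation), mirroring the calls made in ``compute_actions`` above:

.. code-block:: python

    import torch

    policy = trainer.get_policy()  # `trainer` is assumed to use a TorchPolicy

    # `model` and `dist_class` replace the former private `_model` and
    # `_action_dist_class` attributes.
    obs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])           # batch of one observation
    logits, state = policy.model({"obs": obs}, [], [1])  # same call as compute_actions
    dist = policy.dist_class(logits, policy.model)
    action = dist.sample()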
