[RLlib] JAXPolicy (working discrete-actions PPO prototype). #13014

Closed · wants to merge 24 commits · showing changes from 18 commits
17 changes: 17 additions & 0 deletions rllib/BUILD
@@ -270,6 +270,16 @@ py_test(
)

# PPO
py_test(
name = "run_regression_tests_cartpole_ppo_jax",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_torch", "learning_tests_cartpole"],
size = "medium",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/ppo/cartpole-ppo.yaml"],
args = ["--yaml-dir=tuned_examples/ppo", "--framework=jax"]
)
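# NOTE: Outside of Bazel, this target should be roughly equivalent to running
# "python tests/run_regression_tests.py --yaml-dir=tuned_examples/ppo --framework=jax"
# from the rllib/ directory (assuming JAX and flax are installed).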

py_test(
name = "run_regression_tests_cartpole_ppo_tf",
main = "tests/run_regression_tests.py",
@@ -1089,6 +1099,13 @@ py_test(
srcs = ["models/tests/test_distributions.py"]
)

py_test(
name = "test_jax_models",
tags = ["models"],
size = "small",
srcs = ["models/tests/test_jax_models.py"]
)

py_test(
name = "test_preprocessors",
tags = ["models"],
3 changes: 2 additions & 1 deletion rllib/agents/ddpg/ddpg_torch_model.py
@@ -2,7 +2,8 @@

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()

2 changes: 1 addition & 1 deletion rllib/agents/maml/maml_tf_policy.py
@@ -5,10 +5,10 @@
vf_preds_fetches, compute_and_clip_gradients, setup_config, \
ValueNetworkMixin
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.framework import get_activation_fn

tf1, tf, tfv = try_import_tf()

2 changes: 2 additions & 0 deletions rllib/agents/ppo/__init__.py
@@ -1,4 +1,5 @@
from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG
from ray.rllib.agents.ppo.ppo_jax_policy import PPOJAXPolicy
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.agents.ppo.appo import APPOTrainer
@@ -8,6 +9,7 @@
"APPOTrainer",
"DDPPOTrainer",
"DEFAULT_CONFIG",
"PPOJAXPolicy",
"PPOTFPolicy",
"PPOTorchPolicy",
"PPOTrainer",
5 changes: 4 additions & 1 deletion rllib/agents/ppo/ppo.py
@@ -131,7 +131,7 @@ def validate_config(config: TrainerConfigDict) -> None:
"trajectory). Consider setting batch_mode=complete_episodes.")

# Multi-gpu not supported for PyTorch and tf-eager.
if config["framework"] in ["tf2", "tfe", "torch"]:
if config["framework"] != "tf":
config["simple_optimizer"] = True
# Performance warning, if "simple" optimizer used with (static-graph) tf.
elif config["simple_optimizer"]:
@@ -159,6 +159,9 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
return PPOTorchPolicy
elif config["framework"] == "jax":
from ray.rllib.agents.ppo.ppo_jax_policy import PPOJAXPolicy
return PPOJAXPolicy


class UpdateKL:
Expand Down
191 changes: 191 additions & 0 deletions rllib/agents/ppo/ppo_jax_policy.py
@@ -0,0 +1,191 @@
"""
JAX policy class used for PPO.
"""
import gym
import logging
from typing import List, Type, Union

import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
setup_config
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
KLCoeffMixin, kl_and_loss_stats
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.jax_policy import LearningRateSchedule
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule
from ray.rllib.utils.framework import try_import_jax
from ray.rllib.utils.jax_ops import explained_variance
from ray.rllib.utils.typing import TensorType, TrainerConfigDict

jax, flax = try_import_jax()
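# The `if jax:` guard below keeps this module importable even when
# try_import_jax() cannot find a JAX/flax installation (and thus
# returns None entries).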
jnp = None
if jax:
import jax.numpy as jnp

logger = logging.getLogger(__name__)


def ppo_surrogate_loss(
policy: Policy,
model: ModelV2,
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch,
vars=None,
) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss for Proximal Policy Objective.

Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[ActionDistribution]): The action distr. class.
train_batch (SampleBatch): The training data.

Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
if vars:
for k, v in vars.items():
setattr(model, k, v)

logits, state = model.from_batch(train_batch, is_training=True)
curr_action_dist = dist_class(logits, model)

prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
model)

logp_ratio = jnp.exp(
curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
train_batch[SampleBatch.ACTION_LOGP])
action_kl = prev_action_dist.kl(curr_action_dist)
mean_kl = jnp.mean(action_kl)

curr_entropy = curr_action_dist.entropy()
mean_entropy = jnp.mean(curr_entropy)

surrogate_loss = jnp.minimum(
train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
train_batch[Postprocessing.ADVANTAGES] * jnp.clip(
logp_ratio, 1 - policy.config["clip_param"],
1 + policy.config["clip_param"]))
mean_policy_loss = jnp.mean(-surrogate_loss)

if policy.config["use_gae"]:
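# PPO2-style clipped value-function loss: VF updates that move the
# prediction more than vf_clip_param away from the previous prediction
# are not credited any further.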
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
value_fn_out = model.value_function()
vf_loss1 = jnp.square(value_fn_out -
train_batch[Postprocessing.VALUE_TARGETS])
vf_clipped = prev_value_fn_out + jnp.clip(
value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
policy.config["vf_clip_param"])
vf_loss2 = jnp.square(vf_clipped -
train_batch[Postprocessing.VALUE_TARGETS])
vf_loss = jnp.maximum(vf_loss1, vf_loss2)
mean_vf_loss = jnp.mean(vf_loss)
total_loss = jnp.mean(-surrogate_loss + policy.kl_coeff * action_kl +
policy.config["vf_loss_coeff"] * vf_loss -
policy.entropy_coeff * curr_entropy)
else:
mean_vf_loss = 0.0
total_loss = jnp.mean(-surrogate_loss + policy.kl_coeff * action_kl -
policy.entropy_coeff * curr_entropy)

# Store stats in policy for stats_fn.
policy._total_loss = total_loss
policy._mean_policy_loss = mean_policy_loss
policy._mean_vf_loss = mean_vf_loss
policy._vf_explained_var = explained_variance(
train_batch[Postprocessing.VALUE_TARGETS],
policy.model.value_function())
policy._mean_entropy = mean_entropy
policy._mean_kl = mean_kl

if vars:
policy._total_loss = policy._total_loss.primal
policy._mean_policy_loss = policy._mean_policy_loss.primal
policy._mean_vf_loss = policy._mean_vf_loss.primal
policy._vf_explained_var = policy._vf_explained_var.primal
policy._mean_entropy = policy._mean_entropy.primal
policy._mean_kl = policy._mean_kl.primal

return total_loss


class ValueNetworkMixin:
"""Assigns the `_value()` method to the PPOPolicy.

This way, Policy can call `_value()` to get the current VF estimate on a
single(!) observation (as done in `postprocess_trajectory_fn`).
Note: When doing this, an actual forward pass is being performed.
This is different from only calling `model.value_function()`, where
the result of the most recent forward pass is being used to return an
already calculated tensor.
"""

def __init__(self, obs_space, action_space, config):
# When doing GAE, we need the value function estimate on the
# observation.
if config["use_gae"]:

# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
assert config["_use_trajectory_view_api"]

def value(**input_dict):
model_out, _ = self.model.from_batch(
input_dict, is_training=False)
# [0] = remove the batch dim.
return self.model.value_function()[0]

# When not doing GAE, we do not require the value function's output.
else:

def value(*args, **kwargs):
return 0.0

self._value = value


def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
"""Call all mixin classes' constructors before PPOPolicy initialization.

Args:
policy (Policy): The Policy object.
obs_space (gym.spaces.Space): The Policy's observation space.
action_space (gym.spaces.Space): The Policy's action space.
config (TrainerConfigDict): The Policy's config.
"""
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
KLCoeffMixin.__init__(policy, config)
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
config["entropy_coeff_schedule"])
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


# Build a child class of `JAXPolicy`, given the custom functions defined
# above.
PPOJAXPolicy = build_policy_class(
name="PPOJAXPolicy",
framework="jax",
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
loss_fn=ppo_surrogate_loss,
stats_fn=kl_and_loss_stats,
extra_action_out_fn=vf_preds_fetches,
postprocess_fn=postprocess_ppo_gae,
extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config,
before_loss_init=setup_mixins,
mixins=[
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
ValueNetworkMixin
],
)
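For reference, the core of ppo_surrogate_loss() above is the standard clipped surrogate objective; a self-contained JAX sketch (function and argument names here are illustrative; the 0.3 default mirrors RLlib's usual clip_param):

import jax
import jax.numpy as jnp

def clipped_surrogate_loss(logp_new, logp_old, advantages, clip_param=0.3):
    # Importance-sampling ratio pi_new(a|s) / pi_old(a|s), in log space for stability.
    ratio = jnp.exp(logp_new - logp_old)
    unclipped = advantages * ratio
    clipped = advantages * jnp.clip(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # PPO maximizes the surrogate, so the loss is the negated mean.
    return -jnp.mean(jnp.minimum(unclipped, clipped))

# Gradients w.r.t. the new log-probs can then be taken with jax.grad:
grad_fn = jax.grad(clipped_surrogate_loss)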
25 changes: 20 additions & 5 deletions rllib/agents/ppo/tests/test_ppo.py
@@ -44,6 +44,12 @@ def _check_lr_torch(policy, policy_id):
for p in opt.param_groups:
assert p["lr"] == policy.cur_lr, "LR scheduling error!"

@staticmethod
def _check_lr_jax(policy, policy_id):
for j, opt in enumerate(policy._optimizers):
assert opt.optimizer_def.hyper_params.learning_rate == \
policy.cur_lr, "LR scheduling error!"

@staticmethod
def _check_lr_tf(policy, policy_id):
lr = policy.cur_lr
@@ -57,8 +63,10 @@ def _check_lr_tf(policy, policy_id):
assert lr == optim_lr, "LR scheduling error!"

def on_train_result(self, *, trainer, result: dict, **kwargs):
trainer.workers.foreach_policy(self._check_lr_torch if trainer.config[
"framework"] == "torch" else self._check_lr_tf)
fw = trainer.config["framework"]
trainer.workers.foreach_policy(
self._check_lr_tf if fw.startswith("tf") else self._check_lr_torch
if fw == "torch" else self._check_lr_jax)


class TestPPO(unittest.TestCase):
@@ -84,10 +92,17 @@ def test_ppo_compilation_and_lr_schedule(self):
config["train_batch_size"] = 128
num_iterations = 2

for _ in framework_iterator(config):
for env in ["CartPole-v0", "MsPacmanNoFrameskip-v4"]:
for fw in framework_iterator(
config, frameworks=("jax", "tf2", "tf", "torch")):
envs = ["CartPole-v0"]
if fw != "jax":
envs.append("MsPacmanNoFrameskip-v4")
for env in envs:
print("Env={}".format(env))
for lstm in [True, False]:
lstms = [False]
if fw != "jax":
lstms.append(True)
for lstm in lstms:
print("LSTM={}".format(lstm))
config["model"]["use_lstm"] = lstm
config["model"]["lstm_use_prev_action"] = lstm
3 changes: 2 additions & 1 deletion rllib/agents/sac/sac_torch_model.py
@@ -5,7 +5,8 @@

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.typing import ModelConfigDict, TensorType
