[RLlib] Document and extend action mask example. (#20390)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Sven Mika <sven@anyscale.io>
Co-authored-by: sven1977 <svenmika1977@gmail.com>
4 people authored Nov 16, 2021
1 parent 3e6ba5d commit 2b3d0c6
Showing 2 changed files with 162 additions and 41 deletions.
178 changes: 141 additions & 37 deletions rllib/examples/action_masking.py
@@ -1,61 +1,129 @@
"""Example showing how to use "action masking" in RLlib.
"Action masking" allows the agent to select actions based on the current
observation: only actions marked as valid may be chosen. This is useful in
many practical scenarios in which different actions are available at
different time steps.
Blog post explaining action masking: https://boring-guy.sh/posts/masking-rl/
RLlib supports action masking, i.e., disallowing invalid actions based on the
observation, by slightly adjusting the environment and the model as shown in
this example.
The ActionMaskEnv used here wraps an underlying environment (RandomEnv in this
example), defining only a subset of all actions as valid based on the
environment's observations. If an invalid action is selected, the environment
raises an error - this must not happen!
The environment constructs Dict observations, where obs["observations"] holds
the original observations and obs["action_mask"] holds a mask of valid actions.
To avoid selecting invalid actions, the ActionMaskModel is used. This model
takes the original observations, computes the logits of the corresponding
actions, and then masks out the logits of all invalid actions (setting them to
a very large negative value), thus disabling them. This only works with
discrete actions.
---
Run this example with defaults (using Tune and action masking):
$ python action_masking.py
Then run again without action masking, which will likely lead to errors due to
invalid actions being selected (ValueError "Invalid action sent to env!"):
$ python action_masking.py --no-masking
Other options for running this example:
$ python action_masking.py --help
"""

import argparse
import os

from gym.spaces import Box, Discrete
import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.action_mask_env import ActionMaskEnv
from ray.rllib.examples.models.action_mask_model import \
    ActionMaskModel, TorchActionMaskModel
from ray.tune.logger import pretty_print

def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()

    # example-specific args
    parser.add_argument(
        "--no-masking",
        action="store_true",
        help="Do NOT mask invalid actions. This will likely lead to errors.")

    # general args
    parser.add_argument(
        "--run",
        type=str,
        default="APPO",
        help="The RLlib-registered algorithm to use.")
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument(
        "--framework",
        choices=["tf", "tf2", "tfe", "torch"],
        default="tf",
        help="The DL framework specifier.")
    parser.add_argument("--eager-tracing", action="store_true")
    parser.add_argument(
        "--stop-iters",
        type=int,
        default=10,
        help="Number of iterations to train.")
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=10000,
        help="Number of timesteps to train.")
    parser.add_argument(
        "--stop-reward",
        type=float,
        default=80.0,
        help="Reward at which we stop training.")
    parser.add_argument(
        "--no-tune",
        action="store_true",
        help="Run without Tune using a manual train loop instead. Here,"
        " there is no TensorBoard support.")
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Init Ray in local mode for easier debugging.")

    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    args = get_cli_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    # main part: configure the ActionMaskEnv and ActionMaskModel
    config = {
        # random env with 100 discrete actions and 5x [-1,1] observations
        # some actions are declared invalid and lead to errors
        "env": ActionMaskEnv,
        "env_config": {
            "action_space": Discrete(100),
            "observation_space": Box(-1.0, 1.0, (5, )),
        },
        # the ActionMaskModel retrieves the invalid actions and avoids them
        "model": {
            "custom_model": ActionMaskModel
            if args.framework != "torch" else TorchActionMaskModel,
            # disable action masking according to CLI
            "custom_model_config": {
                "no_masking": args.no_masking
            }
        },

        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
@@ -71,6 +139,42 @@
        "episode_reward_mean": args.stop_reward,
    }

    # manual training loop (no Ray Tune)
    if args.no_tune:
        if args.run not in {"APPO", "PPO"}:
            raise ValueError("This example only supports APPO and PPO.")
        ppo_config = ppo.DEFAULT_CONFIG.copy()
        ppo_config.update(config)
        trainer = ppo.PPOTrainer(config=ppo_config, env=ActionMaskEnv)
        # run manual training loop and print results after each iteration
        for _ in range(args.stop_iters):
            result = trainer.train()
            print(pretty_print(result))
            # stop training if the target train steps or reward are reached
            if result["timesteps_total"] >= args.stop_timesteps or \
                    result["episode_reward_mean"] >= args.stop_reward:
                break

        # manual test loop
        print("Finished training. Running manual test/inference loop.")
        # prepare environment with max 10 steps
        config["env_config"]["max_episode_len"] = 10
        env = ActionMaskEnv(config["env_config"])
        obs = env.reset()
        done = False
        # run one episode until done
        print(f"ActionMaskEnv with {config['env_config']}")
        while not done:
            action = trainer.compute_single_action(obs)
            next_obs, reward, done, _ = env.step(action)
            # observations contain original observations and the action mask
            # reward is random and irrelevant here and therefore not printed
            print(f"Obs: {obs}, Action: {action}")
            obs = next_obs

    # run with Tune for auto trainer creation, stopping, TensorBoard, etc.
    else:
        results = tune.run(args.run, config=config, stop=stop, verbose=2)

    print("Finished successfully without selecting invalid actions.")
    ray.shutdown()
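The masking trick described in the docstring and used by ActionMaskModel is plain addition in logit space: taking the log of a {0, 1} mask yields 0 for valid actions and -inf for invalid ones, and adding that to the logits drives the probability of invalid actions to (numerically) zero after the softmax. A minimal, framework-free sketch of this arithmetic, using numpy in place of tf/torch and made-up numbers:

import numpy as np

logits = np.array([2.0, 0.5, -1.0, 0.3])      # unmasked model outputs (illustrative)
action_mask = np.array([1.0, 0.0, 1.0, 0.0])  # 1 = valid, 0 = invalid

# log(1) = 0 for valid actions, log(0) = -inf for invalid ones; clipping to the
# float32 minimum keeps the value finite, mirroring tf.float32.min / FLOAT_MIN.
with np.errstate(divide="ignore"):
    inf_mask = np.maximum(np.log(action_mask), np.finfo(np.float32).min)
masked_logits = logits + inf_mask

# Softmax over the masked logits: invalid actions end up with ~0.0 probability.
probs = np.exp(masked_logits - masked_logits.max())
probs /= probs.sum()
print(probs)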
25 changes: 21 additions & 4 deletions rllib/examples/models/action_mask_model.py
@@ -25,8 +25,8 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,

        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict) and \
            "action_mask" in orig_space.spaces and \
            "observations" in orig_space.spaces

        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
@@ -35,6 +35,10 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
            orig_space["observations"], action_space, num_outputs,
            model_config, name + "_internal")

        # disable action masking --> will likely lead to invalid actions
        self.no_masking = model_config["custom_model_config"].get(
            "no_masking", False)

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]
@@ -44,6 +48,10 @@ def forward(self, input_dict, state, seq_lens):
            "obs": input_dict["obs"]["observations"]
        })

        # If action masking is disabled, directly return unmasked logits
        if self.no_masking:
            return logits, state

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        masked_logits = logits + inf_mask
@@ -69,8 +77,8 @@ def __init__(
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict) and \
            "action_mask" in orig_space.spaces and \
            "observations" in orig_space.spaces

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name, **kwargs)
@@ -80,6 +88,11 @@ def __init__(
            num_outputs, model_config,
            name + "_internal")

        # disable action masking --> will likely lead to invalid actions
        self.no_masking = False
        if "no_masking" in model_config["custom_model_config"]:
            self.no_masking = model_config["custom_model_config"]["no_masking"]

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]
@@ -89,6 +102,10 @@ def forward(self, input_dict, state, seq_lens):
            "obs": input_dict["obs"]["observations"]
        })

        # If action masking is disabled, directly return unmasked logits
        if self.no_masking:
            return logits, state

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        masked_logits = logits + inf_mask
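The commit only documents and extends the example script and the model; the ActionMaskEnv that produces the Dict observations asserted above lives elsewhere under rllib/examples/env and is not part of this diff. For orientation, a hypothetical, self-contained environment with the same observation contract could look like the sketch below (MaskedRandomEnv and its helpers are illustrative names, not RLlib's actual ActionMaskEnv implementation; it follows the old gym API used throughout this example):

import numpy as np
import gym
from gym.spaces import Box, Dict, Discrete


class MaskedRandomEnv(gym.Env):
    """Illustrative env emitting {"observations", "action_mask"} Dict obs."""

    def __init__(self, config=None):
        config = config or {}
        self.inner_obs_space = config.get("observation_space",
                                          Box(-1.0, 1.0, (5, )))
        self.action_space = config.get("action_space", Discrete(100))
        self.observation_space = Dict({
            "observations": self.inner_obs_space,
            "action_mask": Box(0.0, 1.0, (self.action_space.n, )),
        })
        self.max_episode_len = config.get("max_episode_len", 10)

    def _new_mask(self):
        # Mark roughly half of all actions valid; ensure at least one is.
        mask = (np.random.rand(self.action_space.n) > 0.5).astype(np.float32)
        mask[np.random.randint(self.action_space.n)] = 1.0
        return mask

    def _obs(self):
        return {
            "observations": self.inner_obs_space.sample(),
            "action_mask": self.mask,
        }

    def reset(self):
        self.t = 0
        self.mask = self._new_mask()
        return self._obs()

    def step(self, action):
        # Mirror the behavior the example relies on: invalid actions are fatal.
        if not self.mask[action]:
            raise ValueError("Invalid action sent to env!")
        self.t += 1
        self.mask = self._new_mask()
        done = self.t >= self.max_episode_len
        return self._obs(), np.random.rand(), done, {}

With a config of {"action_space": Discrete(100), "observation_space": Box(-1.0, 1.0, (5, ))}, such an env produces exactly the obs["observations"] / obs["action_mask"] structure that both ActionMaskModel variants above assert on.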
