Skip to content

Commit

Permalink
[RLlib] No Preprocessors; preparatory PR #1 (ray-project#18367)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Sep 9, 2021
1 parent 1520c3d commit 8a06647
Show file tree
Hide file tree
Showing 15 changed files with 268 additions and 96 deletions.
4 changes: 2 additions & 2 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2273,9 +2273,9 @@ py_test(
name = "examples/custom_observation_filters",
main = "examples/custom_observation_filters.py",
tags = ["team:ml", "examples", "examples_C"],
size = "small",
size = "medium",
srcs = ["examples/custom_observation_filters.py"],
args = ["--stop-iters=2"]
args = ["--stop-iters=3"]
)

py_test(
Expand Down
2 changes: 2 additions & 0 deletions rllib/agents/dqn/dqn_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ def postprocess_nstep_and_prio(policy: Policy,
if policy.config["n_step"] > 1:
adjust_nstep(policy.config["n_step"], policy.config["gamma"], batch)

# Create dummy prio-weights (1.0) in case we don't have any in
# the batch.
if PRIO_WEIGHTS not in batch:
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])

Expand Down
16 changes: 12 additions & 4 deletions rllib/agents/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,17 @@
# Tuple[value1, value2]: Clip at value1 and value2.
"clip_rewards": None,
# If True, RLlib will learn entirely inside a normalized action space
# (0.0 centered with small stddev; only affecting Box components) and
# only unsquash actions (and clip just in case) to the bounds of
# env's action space before sending actions back to the env.
# (0.0 centered with small stddev; only affecting Box components).
# We will unsquash actions (and clip, just in case) to the bounds of
# the env's action space before sending actions back to the env.
"normalize_actions": True,
# If True, RLlib will clip actions according to the env's bounds
# before sending them back to the env.
# TODO: (sven) This option should be obsoleted and always be False.
"clip_actions": False,
# Whether to use "rllib" or "deepmind" preprocessors by default.
# Set to None to use no preprocessor. In this case, the model will have
# to handle possibly complex observations from the environment.
"preprocessor_pref": "deepmind",

# === Debug Settings ===
Expand Down Expand Up @@ -1041,7 +1043,7 @@ def compute_single_action(

# Check the preprocessor and preprocess, if necessary.
pp = local_worker.preprocessors[policy_id]
if type(pp).__name__ != "NoPreprocessor":
if pp and type(pp).__name__ != "NoPreprocessor":
observation = pp.transform(observation)
filtered_observation = local_worker.filters[policy_id](
observation, update=False)
Expand Down Expand Up @@ -1511,6 +1513,12 @@ def _validate_config(config: PartialTrainerConfigDict,
config["input_evaluation"]))

# Check model config.
# If no preprocessing, propagate into model's config as well
# (so the model knows whether inputs are preprocessed or not).
if config["preprocessor_pref"] is None:
model_config["_no_preprocessor"] = True

# Prev_a/r settings.
prev_a_r = model_config.get("lstm_use_prev_action_reward",
DEPRECATED_VALUE)
if prev_a_r != DEPRECATED_VALUE:
Expand Down
2 changes: 1 addition & 1 deletion rllib/env/multi_agent_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def make_multi_agent(env_name_or_creator):
Returns:
Type[MultiAgentEnv]: New MultiAgentEnv class to be used as env.
The constructor takes a config dict with `num_agents` key
(default=1). The reset of the config dict will be passed on to the
(default=1). The rest of the config dict will be passed on to the
underlying single-agent env's constructor.
Examples:
Expand Down
13 changes: 8 additions & 5 deletions rllib/evaluation/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(
count_steps_by: str = "env_steps",
batch_mode: str = "truncate_episodes",
episode_horizon: int = None,
preprocessor_pref: str = "deepmind",
preprocessor_pref: Optional[str] = "deepmind",
sample_async: bool = False,
compress_observations: bool = False,
num_envs: int = 1,
Expand Down Expand Up @@ -257,8 +257,9 @@ class to use.
until the episode completes, and hence batches may contain
significant amounts of off-policy data.
episode_horizon (int): Horizon at which to stop episodes.
preprocessor_pref (str): Whether to prefer RLlib preprocessors
("rllib") or deepmind ("deepmind") when applicable.
preprocessor_pref (Optional[str]): Whether to use no preprocessor
(None), RLlib preprocessors ("rllib") or deepmind ("deepmind"),
when applicable.
sample_async (bool): Whether to compute samples asynchronously in
the background, which improves throughput but can cause samples
to be slightly off-policy.
Expand Down Expand Up @@ -419,7 +420,8 @@ def gen_rollouts():
self.count_steps_by: str = count_steps_by
self.batch_mode: str = batch_mode
self.compress_observations: bool = compress_observations
self.preprocessing_enabled: bool = True
self.preprocessing_enabled: bool = False \
if preprocessor_pref is None else True
self.observation_filter = observation_filter
self.last_batch: SampleBatchType = None
self.global_vars: dict = None
Expand Down Expand Up @@ -1363,7 +1365,8 @@ def _build_policy_map(
preprocessor = ModelCatalog.get_preprocessor_for_space(
obs_space, merged_conf.get("model"))
self.preprocessors[name] = preprocessor
obs_space = preprocessor.observation_space
if preprocessor is not None:
obs_space = preprocessor.observation_space
else:
self.preprocessors[name] = NoPreprocessor(obs_space)

Expand Down
20 changes: 14 additions & 6 deletions rllib/evaluation/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,10 +812,13 @@ def _process_observations(

policy_id: PolicyID = episode.policy_for(agent_id)

prep_obs: EnvObsType = _get_or_raise(worker.preprocessors,
policy_id).transform(raw_obs)
if log_once("prep_obs"):
logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
preprocessor = _get_or_raise(worker.preprocessors, policy_id)
prep_obs: EnvObsType = raw_obs
if preprocessor is not None:
prep_obs = preprocessor.transform(raw_obs)
if log_once("prep_obs"):
logger.info("Preprocessed obs: {}".format(
summarize(prep_obs)))
filtered_obs: EnvObsType = _get_or_raise(worker.filters,
policy_id)(prep_obs)
if log_once("filtered_obs"):
Expand Down Expand Up @@ -955,10 +958,15 @@ def _process_observations(
# types: AgentID, EnvObsType
for agent_id, raw_obs in resetted_obs.items():
policy_id: PolicyID = new_episode.policy_for(agent_id)
prep_obs: EnvObsType = _get_or_raise(
worker.preprocessors, policy_id).transform(raw_obs)
preproccessor = _get_or_raise(worker.preprocessors,
policy_id)

prep_obs: EnvObsType = raw_obs
if preproccessor is not None:
prep_obs = preproccessor.transform(raw_obs)
filtered_obs: EnvObsType = _get_or_raise(
worker.filters, policy_id)(prep_obs)
new_episode._set_last_raw_obs(agent_id, raw_obs)
new_episode._set_last_observation(agent_id, filtered_obs)

# Add initial obs to buffer.
Expand Down
5 changes: 1 addition & 4 deletions rllib/examples/custom_observation_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ def __repr__(self):
}

results = tune.run(
"PG",
args.run,
config=config,
stop={"training_iteration": args.stop_iters})
args.run, config=config, stop={"training_iteration": args.stop_iters})

ray.shutdown()
36 changes: 26 additions & 10 deletions rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchDeterministic, TorchDiagGaussian, \
TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, \
deprecation_warning
from ray.rllib.utils.error import UnsupportedSpaceException
Expand All @@ -44,6 +44,11 @@
# 2) fully connected and CNN default networks as well as
# auto-wrapped LSTM- and attention nets.
"_use_default_native_models": False,
# Experimental flag.
# If True, user specified no preprocessor to be created
# (via config.preprocessor_pref=None). If True, observations will arrive
# in model as they are returned by the env.
"_no_preprocessing": False,

# === Built-in options ===
# FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
Expand Down Expand Up @@ -693,12 +698,13 @@ def get_preprocessor_for_space(observation_space: gym.Space,
cls = get_preprocessor(observation_space)
prep = cls(observation_space, options)

logger.debug("Created preprocessor {}: {} -> {}".format(
prep, observation_space, prep.shape))
if prep is not None:
logger.debug("Created preprocessor {}: {} -> {}".format(
prep, observation_space, prep.shape))
return prep

@staticmethod
@PublicAPI
@Deprecated(error=False)
def register_custom_preprocessor(preprocessor_name: str,
preprocessor_class: type) -> None:
"""Register a custom preprocessor class by name.
Expand Down Expand Up @@ -796,14 +802,15 @@ def _get_v2_model_class(input_space: gym.Space,
"framework={} not supported in `ModelCatalog._get_v2_model_"
"class`!".format(framework))

# Tuple space, where at least one sub-space is image.
# -> Complex input model.
# Complex space, where at least one sub-space is image.
# -> Complex input model (which auto-flattens everything, but correctly
# processes image components with default CNN stacks).
space_to_check = input_space if not hasattr(
input_space, "original_space") else input_space.original_space
if isinstance(input_space,
Tuple) or (isinstance(space_to_check, Tuple) and any(
isinstance(s, Box) and len(s.shape) >= 2
for s in space_to_check.spaces)):
if isinstance(input_space, (Dict, Tuple)) or (isinstance(
space_to_check, (Dict, Tuple)) and any(
isinstance(s, Box) and len(s.shape) >= 2
for s in tree.flatten(space_to_check.spaces))):
return ComplexNet

# Single, flattenable/one-hot-able space -> Simple FCNet.
Expand Down Expand Up @@ -860,6 +867,15 @@ def _validate_config(config: ModelConfigDict, framework: str) -> None:
Raises:
ValueError: If something is wrong with the given config.
"""
# Soft-deprecate custom preprocessors.
if config.get("custom_preprocessor") is not None:
deprecation_warning(
old="model.custom_preprocessor",
new="gym.ObservationWrapper around your env or handle complex "
"inputs inside your Model",
error=False,
)

if config.get("use_attention") and config.get("use_lstm"):
raise ValueError("Only one of `use_lstm` or `use_attention` may "
"be set to True!")
Expand Down
32 changes: 23 additions & 9 deletions rllib/models/modelv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,32 @@ def __call__(
input_dict["is_training"] = input_dict.is_training
else:
restored = input_dict.copy()
restored["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, self.framework)
try:
if len(input_dict["obs"].shape) > 2:
restored["obs_flat"] = flatten(input_dict["obs"],
self.framework)
else:
restored["obs_flat"] = input_dict["obs"]
except AttributeError:

# No Preprocessor used: `config.preprocessor_pref`=None.
# TODO: This is unnecessary when no preprocessor is used.
# Obs are no longer flat in that case. However, we'll keep this
# here for backward-compatibility until Preprocessors have
# been fully deprecated.
if self.model_config.get("_no_preprocessing"):
restored["obs_flat"] = input_dict["obs"]
# Input to this Model went through a Preprocessor.
# Generate extra keys: "obs_flat" (vs "obs", which will hold the
# original obs).
else:
restored["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, self.framework)
try:
if len(input_dict["obs"].shape) > 2:
restored["obs_flat"] = flatten(input_dict["obs"],
self.framework)
else:
restored["obs_flat"] = input_dict["obs"]
except AttributeError:
restored["obs_flat"] = input_dict["obs"]

with self.context():
res = self.forward(restored, state or [], seq_lens)

if ((not isinstance(res, list) and not isinstance(res, tuple))
or len(res) != 2):
raise ValueError(
Expand Down
18 changes: 14 additions & 4 deletions rllib/models/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,14 @@ def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
for i in range(len(self._obs_space.spaces)):
space = self._obs_space.spaces[i]
logger.debug("Creating sub-preprocessor for {}".format(space))
preprocessor = get_preprocessor(space)(space, self._options)
preprocessor_class = get_preprocessor(space)
if preprocessor_class is not None:
preprocessor = preprocessor_class(space, self._options)
size += preprocessor.size
else:
preprocessor = None
size += int(np.product(space.shape))
self.preprocessors.append(preprocessor)
size += preprocessor.size
return (size, )

@override(Preprocessor)
Expand Down Expand Up @@ -247,9 +252,14 @@ def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
self.preprocessors = []
for space in self._obs_space.spaces.values():
logger.debug("Creating sub-preprocessor for {}".format(space))
preprocessor = get_preprocessor(space)(space, self._options)
preprocessor_class = get_preprocessor(space)
if preprocessor_class is not None:
preprocessor = preprocessor_class(space, self._options)
size += preprocessor.size
else:
preprocessor = None
size += int(np.product(space.shape))
self.preprocessors.append(preprocessor)
size += preprocessor.size
return (size, )

@override(Preprocessor)
Expand Down
29 changes: 19 additions & 10 deletions rllib/models/tf/complex_input_net.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from gym.spaces import Box, Discrete, Tuple
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import tree # pip install dm_tree

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
Expand All @@ -9,6 +10,7 @@
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.tf_ops import one_hot

tf1, tf, tfv = try_import_tf()
Expand All @@ -31,21 +33,22 @@ class ComplexInputNetwork(TFModelV2):

def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
# TODO: (sven) Support Dicts as well.
self.original_space = obs_space.original_space if \
hasattr(obs_space, "original_space") else obs_space
assert isinstance(self.original_space, (Tuple)), \
"`obs_space.original_space` must be Tuple!"
assert isinstance(self.original_space, (Dict, Tuple)), \
"`obs_space.original_space` must be [Dict|Tuple]!"

super().__init__(self.original_space, action_space, num_outputs,
model_config, name)

self.flattened_input_space = flatten_space(self.original_space)

# Build the CNN(s) given obs_space's image components.
self.cnns = {}
self.one_hot = {}
self.flatten = {}
concat_size = 0
for i, component in enumerate(self.original_space):
for i, component in enumerate(self.flattened_input_space):
# Image space.
if len(component.shape) == 3:
config = {
Expand All @@ -64,11 +67,13 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
name="cnn_{}".format(i))
concat_size += cnn.num_outputs
self.cnns[i] = cnn
# Discrete inputs -> One-hot encode.
# Discrete|MultiDiscrete inputs -> One-hot encode.
elif isinstance(component, Discrete):
self.one_hot[i] = True
concat_size += component.n
# TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
elif isinstance(component, MultiDiscrete):
self.one_hot[i] = True
concat_size += sum(component.nvec)
# Everything else (1D Box).
else:
self.flatten[i] = int(np.product(component.shape))
Expand Down Expand Up @@ -123,18 +128,22 @@ def forward(self, input_dict, state, seq_lens):
self.obs_space, "tf")
# Push image observations through our CNNs.
outs = []
for i, component in enumerate(orig_obs):
for i, component in enumerate(tree.flatten(orig_obs)):
if i in self.cnns:
cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
outs.append(cnn_out)
elif i in self.one_hot:
if component.dtype in [tf.int32, tf.int64, tf.uint8]:
outs.append(
one_hot(component, self.original_space.spaces[i]))
one_hot(component, self.flattened_input_space[i]))
else:
outs.append(component)
else:
outs.append(tf.reshape(component, [-1, self.flatten[i]]))
outs.append(
tf.cast(
tf.reshape(component, [-1, self.flatten[i]]),
dtype=tf.float32,
))
# Concat all outputs and the non-image inputs.
out = tf.concat(outs, axis=1)
# Push through (optional) FC-stack (this may be an empty stack).
Expand Down
Loading

0 comments on commit 8a06647

Please sign in to comment.