From 8841ad8cb534878cae234ae63cc8c44694c8f312 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sat, 31 Jul 2021 19:02:13 -0400 Subject: [PATCH 01/45] wip --- rllib/agents/trainer.py | 10 +- rllib/evaluation/rollout_worker.py | 19 ++-- rllib/evaluation/sampler.py | 19 +++- rllib/examples/custom_observation_filters.py | 1 - rllib/examples/preprocessing_disabled.py | 111 +++++++++++++++++++ rllib/models/catalog.py | 23 ++-- rllib/models/modelv2.py | 22 ++-- rllib/models/preprocessors.py | 5 +- rllib/models/tests/test_preprocessors.py | 5 +- rllib/models/tf/complex_input_net.py | 28 +++-- rllib/policy/dynamic_tf_policy.py | 7 +- rllib/policy/sample_batch.py | 16 ++- rllib/tests/test_catalog.py | 14 +-- rllib/utils/annotations.py | 16 +++ rllib/utils/tf_ops.py | 15 ++- 15 files changed, 235 insertions(+), 76 deletions(-) create mode 100644 rllib/examples/preprocessing_disabled.py diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index ff5bfff2bc258..3877d0d974028 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -175,15 +175,17 @@ # Tuple[value1, value2]: Clip at value1 and value2. "clip_rewards": None, # If True, RLlib will learn entirely inside a normalized action space - # (0.0 centered with small stddev; only affecting Box components) and - # only unsquash actions (and clip just in case) to the bounds of - # env's action space before sending actions back to the env. + # (0.0 centered with small stddev; only affecting Box components). + # We will unsquash actions (and clip, just in case) to the bounds of + # the env's action space before sending actions back to the env. "normalize_actions": True, # If True, RLlib will clip actions according to the env's bounds # before sending them back to the env. # TODO: (sven) This option should be obsoleted and always be False. "clip_actions": False, # Whether to use "rllib" or "deepmind" preprocessors by default + # Set to None for using no preprocessor. In this case, the model will have + # to handle possibly complex observations from the environment. "preprocessor_pref": "deepmind", # === Debug Settings === @@ -990,7 +992,7 @@ def compute_single_action( state = [] # Check the preprocessor and preprocess, if necessary. 
pp = self.workers.local_worker().preprocessors[policy_id] - if type(pp).__name__ != "NoPreprocessor": + if pp and type(pp).__name__ != "NoPreprocessor": observation = pp.transform(observation) filtered_observation = self.workers.local_worker().filters[policy_id]( observation, update=False) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index f3a5fa72926a6..9829822679496 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -20,7 +20,7 @@ from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler from ray.rllib.evaluation.rollout_metrics import RolloutMetrics from ray.rllib.models import ModelCatalog -from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor +from ray.rllib.models.preprocessors import Preprocessor from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ OffPolicyEstimate @@ -150,7 +150,7 @@ def __init__( count_steps_by: str = "env_steps", batch_mode: str = "truncate_episodes", episode_horizon: int = None, - preprocessor_pref: str = "deepmind", + preprocessor_pref: Optional[str] = "deepmind", sample_async: bool = False, compress_observations: bool = False, num_envs: int = 1, @@ -232,8 +232,9 @@ class to use. until the episode completes, and hence batches may contain significant amounts of off-policy data. episode_horizon (int): Whether to stop episodes at this horizon. - preprocessor_pref (str): Whether to prefer RLlib preprocessors - ("rllib") or deepmind ("deepmind") when applicable. + preprocessor_pref (Optional[str]): Whether to use no preprocessor + (None), RLlib preprocessors ("rllib") or deepmind ("deepmind"), + when applicable. sample_async (bool): Whether to compute samples asynchronously in the background, which improves throughput but can cause samples to be slightly off-policy. @@ -387,7 +388,7 @@ def gen_rollouts(): self.count_steps_by: str = count_steps_by self.batch_mode: str = batch_mode self.compress_observations: bool = compress_observations - self.preprocessing_enabled: bool = True + self.preprocessing_enabled: bool = False if preprocessor_pref is None else True self.observation_filter = observation_filter self.last_batch: SampleBatchType = None self.global_vars: dict = None @@ -1337,13 +1338,7 @@ def _build_policy_map( self.preprocessors[name] = preprocessor obs_space = preprocessor.observation_space else: - self.preprocessors[name] = NoPreprocessor(obs_space) - - if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)): - raise ValueError( - "Found raw Tuple|Dict space as input to policy. 
" - "Please preprocess these observations with a " - "Tuple|DictFlatteningPreprocessor.") + self.preprocessors[name] = None self.policy_map.create_policy(name, orig_cls, obs_space, act_space, conf, merged_conf) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 38073abe1afff..3fdd02ebcfaa2 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -811,10 +811,12 @@ def _process_observations( policy_id: PolicyID = episode.policy_for(agent_id) - prep_obs: EnvObsType = _get_or_raise(worker.preprocessors, - policy_id).transform(raw_obs) - if log_once("prep_obs"): - logger.info("Preprocessed obs: {}".format(summarize(prep_obs))) + preprocessor = _get_or_raise(worker.preprocessors, policy_id) + prep_obs: EnvObsType = raw_obs + if preprocessor is not None: + prep_obs = preprocessor.transform(raw_obs) + if log_once("prep_obs"): + logger.info("Preprocessed obs: {}".format(summarize(prep_obs))) filtered_obs: EnvObsType = _get_or_raise(worker.filters, policy_id)(prep_obs) if log_once("filtered_obs"): @@ -946,10 +948,15 @@ def _process_observations( # types: AgentID, EnvObsType for agent_id, raw_obs in resetted_obs.items(): policy_id: PolicyID = new_episode.policy_for(agent_id) - prep_obs: EnvObsType = _get_or_raise( - worker.preprocessors, policy_id).transform(raw_obs) + preproccessor = _get_or_raise( + worker.preprocessors, policy_id) + + prep_obs: EnvObsType = raw_obs + if preproccessor is not None: + prep_obs = preproccessor.transform(raw_obs) filtered_obs: EnvObsType = _get_or_raise( worker.filters, policy_id)(prep_obs) + new_episode._set_last_raw_obs(agent_id, raw_obs) new_episode._set_last_observation(agent_id, filtered_obs) # Add initial obs to buffer. diff --git a/rllib/examples/custom_observation_filters.py b/rllib/examples/custom_observation_filters.py index bcec665200a56..24f4323066b77 100644 --- a/rllib/examples/custom_observation_filters.py +++ b/rllib/examples/custom_observation_filters.py @@ -137,7 +137,6 @@ def __repr__(self): } results = tune.run( - "PG", args.run, config=config, stop={"training_iteration": args.stop_iters}) diff --git a/rllib/examples/preprocessing_disabled.py b/rllib/examples/preprocessing_disabled.py new file mode 100644 index 0000000000000..4cd96a7848d12 --- /dev/null +++ b/rllib/examples/preprocessing_disabled.py @@ -0,0 +1,111 @@ +"""Example of setting preprocessor_pref=None to disable all preprocessing. + +This example shows: + - How a complex observation space from the env is handled directly by the + model. + - Complex observations are flattened into lists of tensors and as such + stored by the SampleCollectors. + - This has the advantage that preprocessing happens in batched fashion + (in the model). +""" +import argparse +from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple +import numpy as np +import os + +import ray +from ray import tune + + +def get_cli_args(): + """Create CLI parser and return parsed arguments""" + parser = argparse.ArgumentParser() + + # example-specific args + parser.add_argument( + "--no-attention", + action="store_true", + help="Do NOT use attention. 
For comparison: The agent will not learn.") + + # general args + parser.add_argument( + "--run", default="PPO", help="The RLlib-registered algorithm to use.") + parser.add_argument("--num-cpus", type=int, default=3) + parser.add_argument( + "--framework", + choices=["tf", "tf2", "tfe", "torch"], + default="tf", + help="The DL framework specifier.") + parser.add_argument( + "--stop-iters", + type=int, + default=200, + help="Number of iterations to train.") + parser.add_argument( + "--stop-timesteps", + type=int, + default=500000, + help="Number of timesteps to train.") + parser.add_argument( + "--stop-reward", + type=float, + default=80.0, + help="Reward at which we stop training.") + parser.add_argument( + "--as-test", + action="store_true", + help="Whether this script should be run as a test: --stop-reward must " + "be achieved within --stop-timesteps AND --stop-iters.") + parser.add_argument( + "--no-tune", + action="store_true", + help="Run without Tune using a manual train loop instead. Here," + "there is no TensorBoard support.") + parser.add_argument( + "--local-mode", + action="store_true", + help="Init Ray in local mode for easier debugging.") + + args = parser.parse_args() + print(f"Running with following CLI args: {args}") + return args + + +if __name__ == "__main__": + args = get_cli_args() + + ray.init(local_mode=args.local_mode) + + config = { + "env": "ray.rllib.examples.env.random_env.RandomEnv", + "env_config": { + "config": { + "observation_space": Dict({ + "a": Discrete(2), + "b": Dict({ + "ba": Discrete(3), + "bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)}), + "c": Tuple((MultiDiscrete([2, 3]), Discrete(2))), + "d": Box(-1.0, 1.0, (2, ), dtype=np.int32), + }), + }, + }, + # Set this to None to enforce no preprocessors being used. + # Complex observations now arrive directly in the model as + # structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]} + # for obs-space=Dict(a=..., b=Tuple(..., ...)). + "preprocessor_pref": None, + # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. + "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", 0)), + "framework": args.framework, + } + + stop = { + "training_iteration": args.stop_iters, + "timesteps_total": args.stop_timesteps, + "episode_reward_mean": args.stop_reward, + } + + results = tune.run(args.run, config=config, stop=stop, verbose=2) + + ray.shutdown() diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index bb84b92785464..0aa7dffbc8308 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -18,7 +18,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ TorchDeterministic, TorchDiagGaussian, \ TorchMultiActionDistribution, TorchMultiCategorical -from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI from ray.rllib.utils.deprecation import DEPRECATED_VALUE, \ deprecation_warning from ray.rllib.utils.error import UnsupportedSpaceException @@ -670,7 +670,7 @@ def get_preprocessor(env: gym.Env, options) @staticmethod - @DeveloperAPI + @Deprecated def get_preprocessor_for_space(observation_space: gym.Space, options: dict = None) -> Preprocessor: """Returns a suitable preprocessor for the given observation space. @@ -709,7 +709,7 @@ def get_preprocessor_for_space(observation_space: gym.Space, return prep @staticmethod - @PublicAPI + @Deprecated def register_custom_preprocessor(preprocessor_name: str, preprocessor_class: type) -> None: """Register a custom preprocessor class by name. 
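The two decorator changes above deprecate RLlib's custom-preprocessor hooks (`get_preprocessor_for_space`, `register_custom_preprocessor`); the `_validate_config` warning further below points users at `gym.ObservationWrapper` instead. A minimal sketch of that replacement pattern (illustrative only, not part of the patch; the wrapper class, its transform, and the registered env name are made up here):

import gym
import numpy as np
from ray import tune


class ScaledObsEnv(gym.ObservationWrapper):
    # Hypothetical stand-in for logic that used to live in a custom
    # preprocessor: transform each raw observation before RLlib sees it.
    def observation(self, obs):
        return np.asarray(obs, dtype=np.float32) / 10.0


# Register an env creator that applies the wrapper, then point the
# Trainer config's "env" key at the registered name.
tune.register_env("scaled_cartpole",
                  lambda env_config: ScaledObsEnv(gym.make("CartPole-v0")))
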
@@ -811,14 +811,15 @@ def _get_v2_model_class(input_space: gym.Space, # disabled. num_framestacks = model_config.get("num_framestacks", "auto") - # Tuple space, where at least one sub-space is image. - # -> Complex input model. + # Complex space, where at least one sub-space is image. + # -> Complex input model (which auto-flattens everything, but correctly + # processes image components with default CNN stacks). space_to_check = input_space if not hasattr( input_space, "original_space") else input_space.original_space if isinstance(input_space, - Tuple) or (isinstance(space_to_check, Tuple) and any( + (Dict, Tuple)) or (isinstance(space_to_check, (Dict, Tuple)) and any( isinstance(s, Box) and len(s.shape) >= 2 - for s in space_to_check.spaces)): + for s in tree.flatten(space_to_check.spaces))): return ComplexNet # Single, flattenable/one-hot-able space -> Simple FCNet. @@ -876,6 +877,14 @@ def _validate_config(config: ModelConfigDict, framework: str) -> None: Raises: ValueError: If something is wrong with the given config. """ + # Soft-deprecate custom preprocessors. + if config.get("custom_preprocessor") is not None: + deprecation_warning( + old="model.custom_preprocessor", + new="gym.ObservationWrapper around your env", + error=False, + ) + if config.get("use_attention") and config.get("use_lstm"): raise ValueError("Only one of `use_lstm` or `use_attention` may " "be set to True!") diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 5921c8b2c2662..f5e0709a45e76 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -216,18 +216,24 @@ def __call__( input_dict["is_training"] = input_dict.is_training else: restored = input_dict.copy() - restored["obs"] = restore_original_dimensions( - input_dict["obs"], self.obs_space, self.framework) - try: - if len(input_dict["obs"].shape) > 2: - restored["obs_flat"] = flatten(input_dict["obs"], - self.framework) - else: + + if hasattr(self.obs_space, "original_space"): + restored["obs"] = restore_original_dimensions( + input_dict["obs"], self.obs_space, self.framework) + try: + if len(input_dict["obs"].shape) > 2: + restored["obs_flat"] = flatten(input_dict["obs"], + self.framework) + else: + restored["obs_flat"] = input_dict["obs"] + except AttributeError: restored["obs_flat"] = input_dict["obs"] - except AttributeError: + else: restored["obs_flat"] = input_dict["obs"] + with self.context(): res = self.forward(restored, state or [], seq_lens) + if ((not isinstance(res, list) and not isinstance(res, tuple)) or len(res) != 2): raise ValueError( diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py index 9976edbf10b86..06ecbd6508fb4 100644 --- a/rllib/models/preprocessors.py +++ b/rllib/models/preprocessors.py @@ -4,7 +4,7 @@ import gym from typing import Any, List -from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.annotations import Deprecated, override, PublicAPI from ray.rllib.utils.spaces.repeated import Repeated from ray.rllib.utils.typing import TensorType from ray.rllib.utils.images import resize @@ -177,6 +177,7 @@ def write(self, observation: TensorType, array: np.ndarray, array[offset:offset + self.size] = self.transform(observation) +@Deprecated class NoPreprocessor(Preprocessor): @override(Preprocessor) def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]: @@ -331,7 +332,7 @@ def get_preprocessor(space: gym.Space) -> type: elif isinstance(space, Repeated): preprocessor = RepeatedValuesPreprocessor else: - preprocessor = NoPreprocessor + 
preprocessor = None return preprocessor diff --git a/rllib/models/tests/test_preprocessors.py b/rllib/models/tests/test_preprocessors.py index 4ce7b73e7e749..db7629fee5863 100644 --- a/rllib/models/tests/test_preprocessors.py +++ b/rllib/models/tests/test_preprocessors.py @@ -5,16 +5,13 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.preprocessors import DictFlatteningPreprocessor, \ - get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \ + get_preprocessor, TupleFlatteningPreprocessor, \ OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor from ray.rllib.utils.test_utils import check class TestPreprocessors(unittest.TestCase): def test_gym_preprocessors(self): - p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0")) - self.assertEqual(type(p1), NoPreprocessor) - p2 = ModelCatalog.get_preprocessor(gym.make("FrozenLake-v0")) self.assertEqual(type(p2), OneHotPreprocessor) diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index 235701854b33c..cddfc93510512 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -1,5 +1,6 @@ -from gym.spaces import Box, Discrete, Tuple +from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple import numpy as np +import tree from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions @@ -9,6 +10,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.spaces.space_utils import flatten_space from ray.rllib.utils.tf_ops import one_hot tf1, tf, tfv = try_import_tf() @@ -31,21 +33,22 @@ class ComplexInputNetwork(TFModelV2): def __init__(self, obs_space, action_space, num_outputs, model_config, name): - # TODO: (sven) Support Dicts as well. self.original_space = obs_space.original_space if \ hasattr(obs_space, "original_space") else obs_space - assert isinstance(self.original_space, (Tuple)), \ - "`obs_space.original_space` must be Tuple!" + assert isinstance(self.original_space, (Dict, Tuple)), \ + "`obs_space.original_space` must be [Dict|Tuple]!" super().__init__(self.original_space, action_space, num_outputs, model_config, name) + self.flattened_input_space = flatten_space(self.original_space) + # Build the CNN(s) given obs_space's image components. self.cnns = {} self.one_hot = {} self.flatten = {} concat_size = 0 - for i, component in enumerate(self.original_space): + for i, component in enumerate(self.flattened_input_space): # Image space. if len(component.shape) == 3: config = { @@ -64,11 +67,13 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name="cnn_{}".format(i)) concat_size += cnn.num_outputs self.cnns[i] = cnn - # Discrete inputs -> One-hot encode. + # Discrete|MultiDiscrete inputs -> One-hot encode. elif isinstance(component, Discrete): self.one_hot[i] = True concat_size += component.n - # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers). + elif isinstance(component, MultiDiscrete): + self.one_hot[i] = True + concat_size += sum(component.nvec) # Everything else (1D Box). else: self.flatten[i] = int(np.product(component.shape)) @@ -123,18 +128,21 @@ def forward(self, input_dict, state, seq_lens): self.obs_space, "tf") # Push image observations through our CNNs. 
outs = [] - for i, component in enumerate(orig_obs): + for i, component in enumerate(tree.flatten(orig_obs)): if i in self.cnns: cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component}) outs.append(cnn_out) elif i in self.one_hot: if component.dtype in [tf.int32, tf.int64, tf.uint8]: outs.append( - one_hot(component, self.original_space.spaces[i])) + one_hot(component, self.flattened_input_space[i])) else: outs.append(component) else: - outs.append(tf.reshape(component, [-1, self.flatten[i]])) + outs.append(tf.cast( + tf.reshape(component, [-1, self.flatten[i]]), + dtype=tf.float32, + )) # Concat all outputs and the non-image inputs. out = tf.concat(outs, axis=1) # Push through (optional) FC-stack (this may be an empty stack). diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 21e6ab82bf22c..10de14e50d0e7 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -539,6 +539,7 @@ def _get_input_dict_and_dummy_batch(self, view_requirements, # Skip action dist inputs placeholder (do later). elif view_col == SampleBatch.ACTION_DIST_INPUTS: continue + # This is a tower, input placeholders already exist. elif view_col in existing_inputs: input_dict[view_col] = existing_inputs[view_col] # All others. @@ -547,10 +548,14 @@ def _get_input_dict_and_dummy_batch(self, view_requirements, if view_req.used_for_training: # Create a +time-axis placeholder if the shift is not an # int (range or list of ints). + flatten = view_col != SampleBatch.OBS or \ + self.config["preprocessor_pref"] is not None input_dict[view_col] = get_placeholder( space=view_req.space, name=view_col, - time_axis=time_axis) + time_axis=time_axis, + flatten=flatten, + ) dummy_batch = self._get_dummy_batch_from_view_requirements( batch_size=32) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index a56f11e7f7611..14762242d81dc 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -97,12 +97,20 @@ def __init__(self, *args, **kwargs): copy_ = {k: v for k, v in self.items() if k != "seq_lens"} for k, v in copy_.items(): assert isinstance(k, str), self + + # Convert lists of int|float into numpy arrays. + if isinstance(v, list) and isinstance(v[0], (int, float)): + self[k] = np.array(v) + + # Try to infer the "length" of the SampleBatch by finding the first + # value that is actually a ndarray/tensor. This would fail if + # all values are nested dicts/tuples of more complex underlying + # structures. len_ = len(v) if isinstance( v, (list, np.ndarray)) or (torch and torch.is_tensor(v)) else None - lengths.append(len_) - if isinstance(v, list): - self[k] = np.array(v) + if len_: + lengths.append(len_) if self.get("seq_lens") is not None and \ not (tf and tf.is_tensor(self["seq_lens"])) and \ @@ -145,7 +153,7 @@ def concat_samples(samples: List["SampleBatch"]) -> \ if s.get("seq_lens") is not None: seq_lens.extend(s["seq_lens"]) - # If we don't have any samples (no or only empty SampleBatches), + # If we don't have any samples (0 or only empty SampleBatches), # return an empty SampleBatch here. 
if len(concat_samples) == 0: return SampleBatch() diff --git a/rllib/tests/test_catalog.py b/rllib/tests/test_catalog.py index bbd1ec1bbbaad..b096bd9ae5852 100644 --- a/rllib/tests/test_catalog.py +++ b/rllib/tests/test_catalog.py @@ -6,7 +6,7 @@ import ray from ray.rllib.models import ActionDistribution, ModelCatalog, MODEL_DEFAULTS -from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor +from ray.rllib.models.preprocessors import Preprocessor from ray.rllib.models.tf.tf_action_dist import MultiActionDistribution, \ TFActionDistribution from ray.rllib.models.tf.tf_modelv2 import TFModelV2 @@ -72,18 +72,6 @@ class TestModelCatalog(unittest.TestCase): def tearDown(self): ray.shutdown() - def test_custom_preprocessor(self): - ray.init(object_store_memory=1000 * 1024 * 1024) - ModelCatalog.register_custom_preprocessor("foo", CustomPreprocessor) - ModelCatalog.register_custom_preprocessor("bar", CustomPreprocessor2) - env = gym.make("CartPole-v0") - p1 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "foo"}) - self.assertEqual(str(type(p1)), str(CustomPreprocessor)) - p2 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "bar"}) - self.assertEqual(str(type(p2)), str(CustomPreprocessor2)) - p3 = ModelCatalog.get_preprocessor(env) - self.assertEqual(type(p3), NoPreprocessor) - def test_default_models(self): ray.init(object_store_memory=1000 * 1024 * 1024) diff --git a/rllib/utils/annotations.py b/rllib/utils/annotations.py index 4713e2a3146c0..c6408c949b881 100644 --- a/rllib/utils/annotations.py +++ b/rllib/utils/annotations.py @@ -45,3 +45,19 @@ def DeveloperAPI(obj): """ return obj + + +def Deprecated(obj): + """Annotation for documenting a (soon-to-be) deprecated method. + + Methods tagged with this decorator should produce a + `ray.rllib.utils.deprecation.deprecation_warning(old=..., error=False)` + to not break existing code at this point. + In a next major release, this warning can then be made an error + (error=True), which means at this point that the method is already + no longer supported but will still inform the user about the + deprecation event. + In a further major release, the method should be erased. + """ + + return obj diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index 7ce9bfdf44cf5..6bfad17fc9d36 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -4,6 +4,7 @@ import tree # pip install dm_tree from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space tf1, tf, tfv = try_import_tf() @@ -54,12 +55,18 @@ def get_gpu_devices(): return gpus -def get_placeholder(*, space=None, value=None, name=None, time_axis=False): +def get_placeholder(*, space=None, value=None, name=None, time_axis=False, flatten=True): from ray.rllib.models.catalog import ModelCatalog if space is not None: if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): - return ModelCatalog.get_action_placeholder(space, None) + if flatten: + return ModelCatalog.get_action_placeholder(space, None) + else: + return tree.map_structure_with_path( + lambda path, component: get_placeholder(space=component, name=name + "." 
+ ".".join([str(p) for p in path])), + get_base_struct_from_space(space), + ) return tf1.placeholder( shape=(None, ) + ((None, ) if time_axis else ()) + space.shape, dtype=tf.float32 if space.dtype == np.float64 else space.dtype, @@ -111,10 +118,10 @@ def huber_loss(x, delta=1.0): def one_hot(x, space): if isinstance(space, Discrete): - return tf.one_hot(x, space.n) + return tf.one_hot(x, space.n, dtype=tf.float32) elif isinstance(space, MultiDiscrete): return tf.concat( - [tf.one_hot(x[:, i], n) for i, n in enumerate(space.nvec)], + [tf.one_hot(x[:, i], n, dtype=tf.float32) for i, n in enumerate(space.nvec)], axis=-1) else: raise ValueError("Unsupported space for `one_hot`: {}".format(space)) From 122bcbc2aeed5c60e7a3c311d330d33d45ba7a51 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sun, 1 Aug 2021 08:12:06 -0400 Subject: [PATCH 02/45] wip --- rllib/evaluation/rollout_worker.py | 6 +- rllib/policy/sample_batch.py | 475 ++++++++++++++++++------ rllib/policy/tests/test_sample_batch.py | 100 ++++- 3 files changed, 458 insertions(+), 123 deletions(-) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index f3a5fa72926a6..f35dabbc94448 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -769,10 +769,8 @@ def sample(self) -> SampleBatchType: logger.info("Completed sample batch:\n\n{}\n".format( summarize(batch))) - if self.compress_observations == "bulk": - batch.compress(bulk=True) - elif self.compress_observations: - batch.compress() + if self.compress_observations: + batch.compress(bulk=self.compress_observations == "bulk") if self.fake_sampler: self.last_batch = batch diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index a56f11e7f7611..d164e66328077 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -2,6 +2,7 @@ import numpy as np import sys import itertools +import tree # pip install dm_tree from typing import Dict, List, Optional, Set, Union from ray.util import log_once @@ -79,16 +80,18 @@ def __init__(self, *args, **kwargs): self.get_interceptor = None # Clear out None seq-lens. - if self.get("seq_lens") is None or self.get("seq_lens") == []: + seq_lens_ = self.get("seq_lens") + if seq_lens_ is None or \ + (isinstance(seq_lens_, list) and len(seq_lens_) == 0): self.pop("seq_lens", None) # Numpyfy seq_lens if list. - elif isinstance(self.get("seq_lens"), list): - self["seq_lens"] = np.array(self["seq_lens"], dtype=np.int32) + elif isinstance(seq_lens_, list): + self["seq_lens"] = seq_lens_ = np.array(seq_lens_, dtype=np.int32) - if self.max_seq_len is None and self.get("seq_lens") is not None and \ - not (tf and tf.is_tensor(self["seq_lens"])) and \ - len(self["seq_lens"]) > 0: - self.max_seq_len = max(self["seq_lens"]) + if self.max_seq_len is None and seq_lens_ is not None and \ + not (tf and tf.is_tensor(seq_lens_)) and \ + len(seq_lens_) > 0: + self.max_seq_len = max(seq_lens_) if self.is_training is None: self.is_training = self.pop("is_training", False) @@ -97,12 +100,22 @@ def __init__(self, *args, **kwargs): copy_ = {k: v for k, v in self.items() if k != "seq_lens"} for k, v in copy_.items(): assert isinstance(k, str), self + + # TODO: Drop support for lists as values. + # Convert lists of int|float into numpy arrays make sure all data + # has same length. + if isinstance(v, list) and isinstance(v[0], (int, float)): + self[k] = np.array(v) + + # Try to infer the "length" of the SampleBatch by finding the first + # value that is actually a ndarray/tensor. 
This would fail if + # all values are nested dicts/tuples of more complex underlying + # structures. len_ = len(v) if isinstance( v, (list, np.ndarray)) or (torch and torch.is_tensor(v)) else None - lengths.append(len_) - if isinstance(v, list): - self[k] = np.array(v) + if len_: + lengths.append(len_) if self.get("seq_lens") is not None and \ not (tf and tf.is_tensor(self["seq_lens"])) and \ @@ -118,17 +131,26 @@ def __len__(self): @staticmethod @PublicAPI - def concat_samples(samples: List["SampleBatch"]) -> \ - Union["SampleBatch", "MultiAgentBatch"]: - """Concatenates n data dicts or MultiAgentBatches. + def concat_samples( + samples: Union[List["SampleBatch"], List["MultiAgentBatch"]], + ) -> Union["SampleBatch", "MultiAgentBatch"]: + """Concatenates n SampleBatches or MultiAgentBatches. Args: - samples (List[Dict[str, TensorType]]]): List of dicts of data - (numpy). + samples (Union[List[SampleBatch], List[MultiAgentBatch]]): List of + SampleBatches or MultiAgentBatches to be concatenated. Returns: - Union[SampleBatch, MultiAgentBatch]: A new (compressed) + Union[SampleBatch, MultiAgentBatch]: A new (concatenated) SampleBatch or MultiAgentBatch. + + Examples: + >>> b1 = SampleBatch({"a": np.array([1, 2]), + ... "b": np.array([10, 11])}) + >>> b2 = SampleBatch({"a": np.array([3]), + ... "b": np.array([12])}) + >>> print(SampleBatch.concat_samples([b1, b2])) + {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])} """ if isinstance(samples[0], MultiAgentBatch): return MultiAgentBatch.concat_samples(samples) @@ -151,11 +173,13 @@ def concat_samples(samples: List["SampleBatch"]) -> \ return SampleBatch() # Collect the concat'd data. - concatd_data = {} - for k in concat_samples[0].keys(): - concatd_data[k] = concat_aligned( - [s[k] for s in concat_samples], - time_major=concat_samples[0].time_major) + try: + concatd_data = tree.map_structure( + lambda *s: concat_aligned(s, concat_samples[0].time_major), + *concat_samples) + except TypeError: + raise TypeError(f"Cannot concat `samples` ({samples})! " + "Structures don't match.") # Return a new (concat'd) SampleBatch. return SampleBatch( @@ -168,7 +192,7 @@ def concat_samples(samples: List["SampleBatch"]) -> \ @PublicAPI def concat(self, other: "SampleBatch") -> "SampleBatch": - """Returns a new SampleBatch with each data column concatenated. + """Concatenates `other` to this one and returns a new SampleBatch. Args: other (SampleBatch): The other SampleBatch object to concat to this @@ -179,20 +203,12 @@ def concat(self, other: "SampleBatch") -> "SampleBatch": to `self`. Examples: - >>> b1 = SampleBatch({"a": [1, 2]}) - >>> b2 = SampleBatch({"a": [3, 4, 5]}) + >>> b1 = SampleBatch({"a": np.array([1, 2])}) + >>> b2 = SampleBatch({"a": np.array([3, 4, 5])}) >>> print(b1.concat(b2)) - {"a": [1, 2, 3, 4, 5]} + {"a": np.array([1, 2, 3, 4, 5])} """ - - if self.keys() != other.keys(): - raise ValueError( - "SampleBatches to concat must have same columns! {} vs {}". - format(list(self.keys()), list(other.keys()))) - out = {} - for k in self.keys(): - out[k] = concat_aligned([self[k], other[k]]) - return SampleBatch(out) + return self.concat_samples([self, other]) @PublicAPI def copy(self, shallow: bool = False) -> "SampleBatch": @@ -204,13 +220,16 @@ def copy(self, shallow: bool = False) -> "SampleBatch": Returns: SampleBatch: A (deep) copy of this SampleBatch object. 
""" + data = tree.map_structure( + lambda v: (np.array(v, copy=not shallow) + if isinstance(v, np.ndarray) else v), + self) + copy_ = SampleBatch( - { - k: np.array(v, copy=not shallow) - if isinstance(v, np.ndarray) else v - for (k, v) in self.items() - }, - seq_lens=self.get("seq_lens"), + data, + _time_major=self.time_major, + _zero_padded=self.zero_padded, + _max_seq_len=self.max_seq_len, ) copy_.set_get_interceptor(self.get_interceptor) return copy_ @@ -219,24 +238,33 @@ def copy(self, shallow: bool = False) -> "SampleBatch": def rows(self) -> Dict[str, TensorType]: """Returns an iterator over data rows, i.e. dicts with column values. + Note that if `seq_lens` is set in self, we set it to [1] in the rows. + Yields: Dict[str, TensorType]: The column values of the row in this iteration. Examples: - >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> batch = SampleBatch({ + ... "a": [1, 2, 3], + ... "b": [4, 5, 6], + ... "seq_lens": [1, 2] + ... }) >>> for row in batch.rows(): print(row) - {"a": 1, "b": 4} - {"a": 2, "b": 5} - {"a": 3, "b": 6} + {"a": 1, "b": 4, "seq_lens": [1]} + {"a": 2, "b": 5, "seq_lens": [1]} + {"a": 3, "b": 6, "seq_lens": [1]} """ + # Do we add seq_lens=[1] to each row? + seq_lens = None if self.get("seq_lens") is None else np.array([1]) + for i in range(self.count): - row = {} - for k in self.keys(): - row[k] = self[k][i] - yield row + yield tree.map_structure_with_path( + lambda p, v: v[i] if p[0] != "seq_lens" else seq_lens, + self, + ) @PublicAPI def columns(self, keys: List[str]) -> List[any]: @@ -255,6 +283,7 @@ def columns(self, keys: List[str]) -> List[any]: [[1], [2]] """ + # TODO: (sven) Make this work for nested data as well. out = [] for k in keys: out.append(self[k]) @@ -262,45 +291,92 @@ def columns(self, keys: List[str]) -> List[any]: @PublicAPI def shuffle(self) -> None: - """Shuffles the rows of this batch in-place.""" + """Shuffles the rows of this batch in-place. + + Returns: + SampleBatch: This very (now shuffled) SampleBatch. + + Raises: + ValueError: If self["seq_lens"] is defined. + + Examples: + >>> batch = SampleBatch({"a": [1, 2, 3, 4]}) + >>> print(batch.shuffle()) + {"a": [4, 1, 3, 2]} + """ + + # Shuffling the data when we have `seq_lens` defined is probably + # a bad idea! + if self.get("seq_lens") is not None: + raise ValueError( + "SampleBatch.shuffle not possible when your data has " + "`seq_lens` defined!") + # Get a permutation over the single items once and use the same + # permutation for all the data (otherwise, data would become + # meaningless). permutation = np.random.permutation(self.count) - for key, val in self.items(): - self[key] = val[permutation] + + def _permutate_in_place(path, value): + curr = self + for i, p in enumerate(path): + if i == len(path) - 1: + curr[p] = value[permutation] + curr = curr[p] + + tree.map_structure_with_path(_permutate_in_place, self) + + return self @PublicAPI def split_by_episode(self) -> List["SampleBatch"]: - """Splits this batch's data by `eps_id`. + """Splits by `eps_id` column and returns list of new batches. Returns: List[SampleBatch]: List of batches, one per distinct episode. + + Raises: + KeyError: If the `eps_id` AND `dones` columns are not present. + + Examples: + >>> batch = SampleBatch({"a": [1, 2, 3], "eps_id": [0, 0, 1]}) + >>> print(batch.split_by_episode()) + [{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}] """ # No eps_id in data -> Make sure there are no "dones" in the middle # and add eps_id automatically. 
if SampleBatch.EPS_ID not in self: + # TODO: (sven) Shouldn't we rather split by DONEs then and not + # add fake eps-ids (0s) at all? if SampleBatch.DONES in self: assert not any(self[SampleBatch.DONES][:-1]) self[SampleBatch.EPS_ID] = np.repeat(0, self.count) return [self] + # Produce a new slice whenever we find a new episode ID. slices = [] cur_eps_id = self[SampleBatch.EPS_ID][0] offset = 0 for i in range(self.count): next_eps_id = self[SampleBatch.EPS_ID][i] if next_eps_id != cur_eps_id: - slices.append(self.slice(offset, i)) + slices.append(self[offset:i]) offset = i cur_eps_id = next_eps_id - slices.append(self.slice(offset, self.count)) + # Add final slice. + slices.append(self[offset:self.count]) + + # TODO: (sven) Are these checks necessary? Should be all ok according + # to above logic. for s in slices: slen = len(set(s[SampleBatch.EPS_ID])) assert slen == 1, (s, slen) assert sum(s.count for s in slices) == self.count, (slices, self.count) + return slices - @PublicAPI + # TODO: (sven) Deprecated method. def slice(self, start: int, end: int, state_start=None, state_end=None) -> "SampleBatch": """Returns a slice of the row data of this batch (w/o copying). @@ -313,6 +389,10 @@ def slice(self, start: int, end: int, state_start=None, SampleBatch: A new SampleBatch, which has a slice of this batch's data. """ + if log_once("SampleBatch.slice"): + deprecation_warning( + "SampleBatch.slice()", "SampleBatch[start:stop]", error=False) + if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0: if start < 0: data = { @@ -387,90 +467,157 @@ def slice(self, start: int, end: int, state_start=None, _time_major=self.time_major) @PublicAPI - def timeslices(self, size=None, num_slices=None, + def timeslices(self, + size: Optional[int] = None, + num_slices: Optional[int] = None, k: Optional[int] = None) -> List["SampleBatch"]: """Returns SampleBatches, each one representing a k-slice of this one. Will start from timestep 0 and produce slices of size=k. Args: - size (int): The size (in timesteps) of each returned SampleBatch. - num_slices (int): The number of slices to produce. + size (Optional[int]): The size (in timesteps) of each returned + SampleBatch. + num_slices (Optional[int]): The number of slices to produce. k (int): Obsoleted: Use size or num_slices instead! The size (in timesteps) of each returned SampleBatch. Returns: - List[SampleBatch]: The list of (new) SampleBatches (each one of - size k). + List[SampleBatch]: The list of `num_slices` (new) SampleBatches + or n (new) SampleBatches each one of size `size`. """ if size is None and num_slices is None: - deprecation_warning("k", "size and num_slices") + deprecation_warning("k", "size or num_slices") assert k is not None size = k - slices, state_slices = self._get_slice_indices(size) - if len(state_slices) == 0: - timeslices = [self.slice(i, j) for i, j in slices] + if size is None: + assert isinstance(num_slices, int) + + slices = [] + left = len(self) + start = 0 + while left: + len_ = left // (num_slices - len(slices)) + stop = start + len_ + slices.append(self[start:stop]) + left -= len_ + start = stop + + return slices + else: - timeslices = [ - self.slice(i, j, si, sj) for (i, j), (si, sj) in slices - ] - return timeslices + assert isinstance(size, int) + + slices = [] + left = len(self) + start = 0 + while left: + stop = start + size + slices.append(self[start:stop]) + left -= size + start = stop + + return slices + + # TODO: (sven) Deprecate in favor of right_zero_pad. 
+ def zero_pad(self, max_seq_len, exclude_states=True): + if log_once("SampleBatch.zero_pad"): + deprecation_warning( + old="SampleBatch.zero_pad", + new="SampleBatch.right_zero_pad", + error=False, + ) + return self.right_zero_pad(max_seq_len, exclude_states) - def zero_pad(self, max_seq_len: int, exclude_states: bool = True): - """Left zero-pad the data in this SampleBatch in place. + def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True): + """Right (adding zeros at end) zero-pads this SampleBatch in-place. This will set the `self.zero_padded` flag to True and `self.max_seq_len` to the given `max_seq_len` value. Args: max_len (int): The max (total) length to zero pad to. - exclude_states (bool): If False, also zero-pad all `state_in_x` - data. If False, leave `state_in_x` keys as-is. + exclude_states (bool): If True, also right-zero-pad all + `state_in_x` data. If False, leave `state_in_x` keys + as-is. + + Returns: + SampleBatch: This very (now right-zero-padded) SampleBatch. + + Raises: + ValueError: If self.seq_lens is None (not defined). + + Examples: + >>> batch = SampleBatch({"a": [1, 2, 3], "seq_lens": [1, 2]}) + >>> print(batch.right_zero_pad(max_seq_len=4)) + {"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]} + + >>> batch = SampleBatch({"a": [1, 2, 3], + ... "state_in_0": [1.0, 3.0], + ... "seq_lens": [1, 2]}) + >>> print(batch.right_zero_pad(max_seq_len=5)) + {"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0], + "state_in_0": [1.0, 3.0], # <- all state-ins remain as-is + "seq_lens": [1, 2]} """ - for col in self.keys(): + seq_lens = self.get("seq_lens") + if seq_lens is None: + raise ValueError( + "Cannot right-zero-pad SampleBatch if no `seq_lens` field " + "present! SampleBatch={self}") + + length = len(seq_lens) * max_seq_len + + def _zero_pad_in_place(path, value): # Skip "state_in_..." columns and "seq_lens". - if (exclude_states is True and col.startswith("state_in_")) or \ - col == "seq_lens": - continue - - f = self[col] - # Save unnecessary copy. - if not isinstance(f, np.ndarray): - f = np.array(f) - # Already good length, can skip. - if f.shape[0] == max_seq_len: - continue + if (exclude_states is True and path[0].startswith("state_in_")) \ + or path[0] == "seq_lens": + return # Generate zero-filled primer of len=max_seq_len. - length = len(self["seq_lens"]) * max_seq_len - if f.dtype == np.object or f.dtype.type is np.str_: + if value.dtype == np.object or value.dtype.type is np.str_: f_pad = [None] * length else: # Make sure type doesn't change. - f_pad = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype) + f_pad = np.zeros( + (length, ) + np.shape(value)[1:], dtype=value.dtype) # Fill primer with data. f_pad_base = f_base = 0 for len_ in self["seq_lens"]: - f_pad[f_pad_base:f_pad_base + len_] = f[f_base:f_base + len_] + f_pad[f_pad_base:f_pad_base + len_] = value[f_base:f_base + + len_] f_pad_base += max_seq_len f_base += len_ - assert f_base == len(f), f - # Update our data. - self[col] = f_pad + assert f_base == len(value), value + + # Update our data in-place. + curr = self + for i, p in enumerate(path): + if i == len(path) - 1: + curr[p] = f_pad + curr = curr[p] + + tree.map_structure_with_path(_zero_pad_in_place, self) # Set flags to indicate, we are now zero-padded (and to what extend). self.zero_padded = True self.max_seq_len = max_seq_len + return self + @PublicAPI def size_bytes(self) -> int: - """ + """Returns sum over number of bytes of all data buffers. + + For numpy arrays, we use `.nbytes`. 
For all other value types, we use + sys.getsizeof(...). + Returns: int: The overall size in bytes of the data buffer (all columns). """ return sum( v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v) - for v in self.values()) + for v in tree.flatten(self)) def get(self, key, default=None): try: @@ -479,15 +626,20 @@ def get(self, key, default=None): return default @PublicAPI - def __getitem__(self, key: str) -> TensorType: - """Returns one column (by key) from the data. + def __getitem__(self, key: Union[str, slice]) -> TensorType: + """Returns one column (by key) from the data or a sliced new batch. Args: - key (str): The key (column name) to return. + key (Union[str, slice]): The key (column name) to return or + a slice object for slicing this SampleBatch. Returns: - TensorType: The data under the given key. + TensorType: The data under the given key or a sliced version of + this batch. """ + if isinstance(key, slice): + return self._slice(key) + if not hasattr(self, key) and key in self: self.accessed_keys.add(key) @@ -545,13 +697,26 @@ def compress(self, is the batch size. columns (Set[str]): The columns to compress. Default: Only compress the obs and new_obs columns. + + Returns: + SampleBatch: This very (now compressed) SampleBatch. """ - for key in columns: - if key in self.keys(): - if bulk: - self[key] = pack(self[key]) - else: - self[key] = np.array([pack(o) for o in self[key]]) + + def _compress_in_place(path, value): + if path[0] not in columns: + return + curr = self + for i, p in enumerate(path): + if i == len(path) - 1: + if bulk: + curr[p] = pack(value) + else: + curr[p] = np.array([pack(o) for o in value]) + curr = curr[p] + + tree.map_structure_with_path(_compress_in_place, self) + + return self @DeveloperAPI def decompress_if_needed(self, @@ -564,25 +729,106 @@ def decompress_if_needed(self, decompress the obs and new_obs columns. Returns: - SampleBatch: This very SampleBatch. + SampleBatch: This very (now uncompressed) SampleBatch. """ - for key in columns: - if key in self.keys(): - arr = self[key] - if is_compressed(arr): - self[key] = unpack(arr) - elif len(arr) > 0 and is_compressed(arr[0]): - self[key] = np.array([unpack(o) for o in self[key]]) + + def _decompress_in_place(path, value): + if path[0] not in columns: + return + curr = self + for i, p in enumerate(path): + if i == len(path) - 1: + # Bulk compressed. + if is_compressed(value): + curr[p] = unpack(value) + # Non bulk compressed. + elif len(value) > 0 and is_compressed(value): + curr[p] = np.array([unpack(o) for o in value]) + curr = curr[p] + + tree.map_structure_with_path(_decompress_in_place, self) + return self @DeveloperAPI def set_get_interceptor(self, fn): + # If get-interceptor changes, must erase old intercepted values. + if fn is not self.get_interceptor: + self.intercepted_values = {} self.get_interceptor = fn def __repr__(self): - return "SampleBatch({})".format(list(self.keys())) + keys = list(self.keys()) + if self.get("seq_lens") is None: + return f"SampleBatch({self.count}: {keys})" + else: + keys.remove("seq_lens") + return f"SampleBatch({self.count} " \ + f"(seqs={len(self['seq_lens'])}): {keys})" + + def _slice(self, slice_: slice): + """Helper method to handle SampleBatch slicing using a slice object. + + The returned SampleBatch uses the same underlying data object as + `self`, so changing the slice will also change `self`. + Note that only zero or positive bounds are allowed for both start + and stop values. The slice step must be 1 (or None, which is the + same). 
+ + Args: + slice_ (slice): The python slice object to slice by. + + Returns: + SampleBatch: A new SampleBatch, however "linking" into the same + data (sliced) as self. + """ + assert slice_.start >= 0 and slice_.stop >= 0 and \ + slice_.step in [1, None] + if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0: + # Build our slice-map if not done already. + if not self._slice_map: + sum_ = 0 + for i, l in enumerate(self["seq_lens"]): + for _ in range(l): + self._slice_map.append((i, sum_)) + sum_ += l + self._slice_map.append((len(self["seq_lens"]), sum_)) + + start_seq_len, start = self._slice_map[slice_.start] + stop_seq_len, stop = self._slice_map[slice_.stop] + if self.zero_padded: + start = start_seq_len * self.max_seq_len + stop = stop_seq_len * self.max_seq_len + + def map_(path, value): + if path[0] != "seq_lens" and not path[0].startswith( + "state_in_"): + return value[start:stop] + else: + return value[start_seq_len:stop_seq_len] + + data = tree.map_structure_with_path(map_, self) + return SampleBatch( + data, + _is_training=self.is_training, + _time_major=self.time_major, + ) + else: + start, stop = slice_.start, slice_.stop + data = tree.map_structure(lambda value: value[start:stop], self) + return SampleBatch( + data, + _is_training=self.is_training, + _time_major=self.time_major, + ) + + # TODO: (sven) Deprecated method. def _get_slice_indices(self, slice_size): + + if log_once("SampleBatch._get_slice_indices"): + deprecation_warning("SampleBatch._get_slice_indices", error=False) + data_slices = [] data_slices_states = [] if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0: @@ -630,9 +876,8 @@ def _get_slice_indices(self, slice_size): # TODO: deprecate @property def data(self): - if log_once("SampleBatch.data"): - deprecation_warning( - old="SampleBatch.data[..]", new="SampleBatch[..]", error=False) + deprecation_warning( + old="SampleBatch.data[..]", new="SampleBatch[..]", error=True) return self # TODO: (sven) Experimental method. 
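The sample_batch.py rewrite above treats every column leaf-wise with dm_tree, so nested (Dict/Tuple) columns get concatenated, sliced, zero-padded, and compressed exactly like flat ones. A minimal sketch of that leaf-wise pattern outside of SampleBatch (plain dicts and np.concatenate stand in here for the batch type and for concat_aligned):

import numpy as np
import tree  # pip install dm_tree

# Two batches with a nested column, written as plain dicts.
b1 = {"a": np.array([1, 2]), "b": {"c": np.array([4, 5])}}
b2 = {"a": np.array([3]), "b": {"c": np.array([6])}}

# map_structure applies the function to corresponding leaves of all inputs,
# so the nested "b"/"c" column is concatenated just like the flat "a" one.
concatd = tree.map_structure(
    lambda *leaves: np.concatenate(leaves, axis=0), b1, b2)
assert list(concatd["a"]) == [1, 2, 3]
assert list(concatd["b"]["c"]) == [4, 5, 6]
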
diff --git a/rllib/policy/tests/test_sample_batch.py b/rllib/policy/tests/test_sample_batch.py index c858fc3c4198c..cf17e0d8d2e23 100644 --- a/rllib/policy/tests/test_sample_batch.py +++ b/rllib/policy/tests/test_sample_batch.py @@ -3,6 +3,8 @@ import ray from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.compression import is_compressed +from ray.rllib.utils.test_utils import check class TestSampleBatch(unittest.TestCase): @@ -14,6 +16,18 @@ def setUpClass(cls) -> None: def tearDownClass(cls) -> None: ray.shutdown() + def test_len_and_size_bytes(self): + s1 = SampleBatch({ + "a": np.array([1, 2, 3]), + "b": { + "c": np.array([4, 5, 6]) + }, + "seq_lens": [1, 2], + }) + check(len(s1), 3) + check(s1.size_bytes(), + s1["a"].nbytes + s1["b"]["c"].nbytes + s1["seq_lens"].nbytes) + def test_dict_properties_of_sample_batches(self): base_dict = { "a": np.array([1, 2, 3]), @@ -21,10 +35,6 @@ def test_dict_properties_of_sample_batches(self): "c": True, } batch = SampleBatch(base_dict) - try: - SampleBatch(base_dict) - except AssertionError: - pass # expected keys_ = list(base_dict.keys()) values_ = list(base_dict.values()) items_ = list(base_dict.items()) @@ -43,6 +53,88 @@ def test_dict_properties_of_sample_batches(self): del batch["c"] assert batch.deleted_keys == {"c"}, batch.deleted_keys + def test_right_zero_padding(self): + """Tests, whether right-zero-padding work properly.""" + s1 = SampleBatch({ + "a": np.array([1, 2, 3]), + "b": { + "c": np.array([4, 5, 6]) + }, + "seq_lens": [1, 2], + }) + s1.right_zero_pad(max_seq_len=5) + check( + s1, { + "a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0], + "b": { + "c": [4, 0, 0, 0, 0, 5, 6, 0, 0, 0] + }, + "seq_lens": [1, 2] + }) + + def test_concat(self): + """Tests, SampleBatches.concat() and ...concat_samples().""" + s1 = SampleBatch({ + "a": np.array([1, 2, 3]), + "b": { + "c": np.array([4, 5, 6]) + }, + }) + s2 = SampleBatch({ + "a": np.array([2, 3, 4]), + "b": { + "c": np.array([5, 6, 7]) + }, + }) + concatd = SampleBatch.concat_samples([s1, s2]) + check(concatd["a"], [1, 2, 3, 2, 3, 4]) + check(concatd["b"]["c"], [4, 5, 6, 5, 6, 7]) + check(next(concatd.rows()), {"a": 1, "b": {"c": 4}}) + + concatd_2 = s1.concat(s2) + check(concatd, concatd_2) + + def test_rows(self): + s1 = SampleBatch({ + "a": np.array([[1, 1], [2, 2], [3, 3]]), + "b": { + "c": np.array([[4, 4], [5, 5], [6, 6]]) + }, + "seq_lens": np.array([1, 2]), + }) + check( + next(s1.rows()), + { + "a": [1, 1], + "b": { + "c": [4, 4] + }, + "seq_lens": [1] + }, + ) + + def test_compression(self): + """Tests, whether compression and decompression work properly.""" + s1 = SampleBatch({ + "a": np.array([1, 2, 3, 2, 3, 4]), + "b": { + "c": np.array([4, 5, 6, 5, 6, 7]) + }, + }) + # Test, whether compressing happens in-place. + s1.compress(columns={"a", "b"}, bulk=True) + self.assertTrue(is_compressed(s1["a"])) + self.assertTrue(is_compressed(s1["b"]["c"])) + self.assertTrue(isinstance(s1["b"], dict)) + + # Test, whether de-compressing happens in-place. 
+ s1.decompress_if_needed(columns={"a", "b"}) + check(s1["a"], [1, 2, 3, 2, 3, 4]) + check(s1["b"]["c"], [4, 5, 6, 5, 6, 7]) + it = s1.rows() + next(it) + check(next(it), {"a": 2, "b": {"c": 5}}) + if __name__ == "__main__": import pytest From bbc806bfe9306a9dc3899019eb2e15aedfdf467b Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sun, 1 Aug 2021 08:46:00 -0400 Subject: [PATCH 03/45] wip --- rllib/policy/sample_batch.py | 19 ++++-- rllib/policy/tests/test_sample_batch.py | 87 +++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 6 deletions(-) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index d164e66328077..a574f48c7c982 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -124,6 +124,12 @@ def __init__(self, *args, **kwargs): else: self.count = lengths[0] if lengths else 0 + # A convenience map for slicing this batch into sub-batches along + # the time axis. This helps reduce repeated iterations through the + # batch's seq_lens array to find good slicing points. Built lazily + # when needed. + self._slice_map = [] + @PublicAPI def __len__(self): """Returns the amount of samples in the sample batch.""" @@ -783,10 +789,12 @@ def _slice(self, slice_: slice): SampleBatch: A new SampleBatch, however "linking" into the same data (sliced) as self. """ - assert slice_.start >= 0 and slice_.stop >= 0 and \ - slice_.step in [1, None] + start = slice_.start or 0 + stop = slice_.stop or len(self) + assert start >= 0 and stop >= 0 and slice_.step in [1, None] + if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0: - # Build our slice-map if not done already. + # Build our slice-map, if not done already. if not self._slice_map: sum_ = 0 for i, l in enumerate(self["seq_lens"]): @@ -795,8 +803,8 @@ def _slice(self, slice_: slice): sum_ += l self._slice_map.append((len(self["seq_lens"]), sum_)) - start_seq_len, start = self._slice_map[slice_.start] - stop_seq_len, stop = self._slice_map[slice_.stop] + start_seq_len, start = self._slice_map[start] + stop_seq_len, stop = self._slice_map[stop] if self.zero_padded: start = start_seq_len * self.max_seq_len stop = stop_seq_len * self.max_seq_len @@ -815,7 +823,6 @@ def map_(path, value): _time_major=self.time_major, ) else: - start, stop = slice_.start, slice_.stop data = tree.map_structure(lambda value: value[start:stop], self) return SampleBatch( data, diff --git a/rllib/policy/tests/test_sample_batch.py b/rllib/policy/tests/test_sample_batch.py index cf17e0d8d2e23..6d95d6a46274d 100644 --- a/rllib/policy/tests/test_sample_batch.py +++ b/rllib/policy/tests/test_sample_batch.py @@ -135,6 +135,93 @@ def test_compression(self): next(it) check(next(it), {"a": 2, "b": {"c": 5}}) + def test_slicing(self): + """Tests, whether slicing can be done on SampleBatches.""" + s1 = SampleBatch({ + "a": np.array([1, 2, 3, 2, 3, 4]), + "b": { + "c": np.array([4, 5, 6, 5, 6, 7]) + }, + }) + check(s1[:3], { + "a": [1, 2, 3], + "b": { + "c": [4, 5, 6] + }, + }) + check(s1[0:3], { + "a": [1, 2, 3], + "b": { + "c": [4, 5, 6] + }, + }) + check(s1[1:4], { + "a": [2, 3, 2], + "b": { + "c": [5, 6, 5] + }, + }) + check(s1[1:], { + "a": [2, 3, 2, 3, 4], + "b": { + "c": [5, 6, 5, 6, 7] + }, + }) + check(s1[3:4], { + "a": [2], + "b": { + "c": [5] + }, + }) + + # When we change the slice, the original SampleBatch should also + # change (shared underlying data). 
+ s1[:3]["a"][0] = 100 + s1[1:2]["a"][0] = 200 + check(s1["a"][0], 100) + check(s1["a"][1], 200) + + # Seq-len batches should be auto-sliced along sequences, + # no matter what. + s2 = SampleBatch({ + "a": np.array([1, 2, 3, 2, 3, 4]), + "b": { + "c": np.array([4, 5, 6, 5, 6, 7]) + }, + "seq_lens": [2, 3, 1], + "state_in_0": [1.0, 3.0, 4.0], + }) + # We would expect a=[1, 2, 3] now, but due to the sequence + # boundary, we stop earlier. + check( + s2[:3], { + "a": [1, 2], + "b": { + "c": [4, 5] + }, + "seq_lens": [2], + "state_in_0": [1.0], + }) + # Split exactly at a seq-len boundary. + check( + s2[:5], { + "a": [1, 2, 3, 2, 3], + "b": { + "c": [4, 5, 6, 5, 6] + }, + "seq_lens": [2, 3], + "state_in_0": [1.0, 3.0], + }) + check( + s2[:], { + "a": [1, 2, 3, 2, 3, 4], + "b": { + "c": [4, 5, 6, 5, 6, 7] + }, + "seq_lens": [2, 3, 1], + "state_in_0": [1.0, 3.0, 4.0], + }) + if __name__ == "__main__": import pytest From 7b9c86e5f6e9472047e14ed4662c2d983b0cee07 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 2 Aug 2021 11:36:08 -0400 Subject: [PATCH 04/45] wip --- .../collectors/simple_list_collector.py | 58 +++++++++---- rllib/models/tf/complex_input_net.py | 12 +-- rllib/policy/dynamic_tf_policy.py | 2 +- rllib/policy/policy.py | 34 ++++---- rllib/policy/sample_batch.py | 6 +- rllib/policy/tf_policy.py | 9 +- rllib/utils/filter.py | 7 +- rllib/utils/spaces/flexdict.py | 42 ++++------ rllib/utils/spaces/space_utils.py | 66 +++++++++++++++ rllib/utils/tf_ops.py | 83 ++++++++++++------- 10 files changed, 215 insertions(+), 104 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index c1df01d922b05..d467858641127 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -3,6 +3,7 @@ import logging import math import numpy as np +import tree # pip install dm_tree from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union from ray.rllib.env.base_env import _DUMMY_AGENT_ID @@ -14,6 +15,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.debug import summarize from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \ TensorType, ViewRequirementsDict from ray.util.debug import log_once @@ -58,7 +60,7 @@ def __init__(self, view_reqs): if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0) for k, vr in view_reqs.items()) # The actual data buffers (lists holding each timestep's data). - self.buffers: Dict[str, List] = {} + self.buffers = {} # The episode ID for the agent for which we collect data. self.episode_id = None # The simple timestep count for this agent. Gets increased by one @@ -80,6 +82,8 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, init_obs (TensorType): The initial observation tensor (after `env.reset()`). """ + # Seems to be the first time, we call this method. Build our + # (list-based) buffers first. if SampleBatch.OBS not in self.buffers: self._build_buffers( single_row={ @@ -88,12 +92,20 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, "env_id": env_id, "t": t, }) - self.buffers[SampleBatch.OBS].append(init_obs) + + # Append data to existing buffers. 
+ tree.map_structure_with_path(self._add_obs_helper, init_obs) self.episode_id = episode_id self.buffers[SampleBatch.AGENT_INDEX].append(agent_index) self.buffers["env_id"].append(env_id) self.buffers["t"].append(t) + def _add_obs_helper(self, path, value): + curr = self.buffers[SampleBatch.OBS] + for p in path[:-1]: + curr = curr[p] + curr[path[-1]].append(value) + def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ None: """Adds the given dictionary (row) of values to the Agent's trajectory. @@ -116,7 +128,10 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ for k, v in values.items(): if k not in self.buffers: self._build_buffers(single_row=values) - self.buffers[k].append(v) + if k == SampleBatch.OBS: + tree.map_structure_with_path(self._add_obs_helper, v) + else: + self.buffers[k].append(v) self.agent_steps += 1 def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: @@ -282,7 +297,8 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: "env_id", "t" ] else 0) # Python primitive, tensor, or dict (e.g. INFOs). - self.buffers[col] = [data for _ in range(shift)] + self.buffers[col] = tree.map_structure( + lambda v: [v for _ in range(shift)], data) class _PolicyCollector: @@ -556,29 +572,43 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0]. else: time_indices = view_req.shift + delta - data_list = [] - # Loop through agents and add-up their data (batch). + + # Loop through agents and add up their data (batch). + data = [[] for _ in range(len(buffers[keys[0]][data_col]))] for k in keys: if data_col == SampleBatch.EPS_ID: - data_list.append(self.agent_collectors[k].episode_id) + data[0].append(self.agent_collectors[k].episode_id) else: + # Buffer for the data does not exist yet: Create dummy + # (zero) data. if data_col not in buffers[k]: - fill_value = np.zeros_like(view_req.space.sample()) \ + fill_value = get_dummy_batch_for_space(view_req.space, batch_size=0) \ if isinstance(view_req.space, Space) else \ view_req.space self.agent_collectors[k]._build_buffers({ data_col: fill_value }) + + # `shift_from` and `shift_to` are defined: User wants a + # view with some time-range. if isinstance(time_indices, tuple): + # `shift_to` == -1: Until the end (including(!) the last + # item). if time_indices[1] == -1: - data_list.append( - buffers[k][data_col][time_indices[0]:]) + for d, b in zip(data, buffers[k][data_col]): + d.append(b[time_indices[0]:]) + # `shift_to` != -1: "Normal" range. else: - data_list.append(buffers[k][data_col][time_indices[ - 0]:time_indices[1] + 1]) + for d, b in zip(data, buffers[k][data_col]): + d.append(b[time_indices[ + 0]:time_indices[1] + 1]) + # Single index. 
else: - data_list.append(buffers[k][data_col][time_indices]) - input_dict[view_col] = np.array(data_list) + for d, b in zip(data, buffers[k][data_col]): + d.append(b[time_indices]) + + data = [np.array(d) for d in data] + input_dict[view_col] = tree.unflatten_as(, data) self._reset_inference_calls(policy_id) diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index cddfc93510512..d16adc4c2c71a 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -1,6 +1,6 @@ from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple import numpy as np -import tree +import tree # pip install dm_tree from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions @@ -133,11 +133,11 @@ def forward(self, input_dict, state, seq_lens): cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component}) outs.append(cnn_out) elif i in self.one_hot: - if component.dtype in [tf.int32, tf.int64, tf.uint8]: - outs.append( - one_hot(component, self.flattened_input_space[i])) - else: - outs.append(component) + #if component.dtype in [tf.int32, tf.int64, tf.uint8]: + outs.append( + one_hot(component, self.flattened_input_space[i])) + #else: + # outs.append(tf.cast(component, tf.float32)) else: outs.append(tf.cast( tf.reshape(component, [-1, self.flatten[i]]), diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 10de14e50d0e7..03e8c74327e1a 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -548,7 +548,7 @@ def _get_input_dict_and_dummy_batch(self, view_requirements, if view_req.used_for_training: # Create a +time-axis placeholder if the shift is not an # int (range or list of ints). - flatten = view_col != SampleBatch.OBS or \ + flatten = view_col not in [SampleBatch.OBS, SampleBatch.NEXT_OBS] or \ self.config["preprocessor_pref"] is not None input_dict[view_col] = get_placeholder( space=view_req.space, diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index ce211387b638a..2b3629e2fca8c 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -15,7 +15,8 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.from_config import from_config from ray.rllib.utils.spaces.space_utils import clip_action, \ - get_base_struct_from_space, unbatch, unsquash_action + get_base_struct_from_space, get_dummy_batch_for_space, unbatch, \ + unsquash_action from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ TensorType, TrainerConfigDict, Tuple, Union @@ -292,7 +293,7 @@ def compute_single_action( @DeveloperAPI def compute_actions_from_input_dict( self, - input_dict: Dict[str, TensorType], + input_dict: SampleBatch, explore: bool = None, timestep: Optional[int] = None, episodes: Optional[List["MultiAgentEpisode"]] = None, @@ -304,10 +305,10 @@ def compute_actions_from_input_dict( to construct the input_dict for the Model. Args: - input_dict (Dict[str, TensorType]): An input dict mapping str - keys to Tensors. `input_dict` already abides to the Policy's - as well as the Model's view requirements and can be passed - to the Model as-is. + input_dict (SampleBatch): A SampleBatch containing the Tensors + to compute actions. `input_dict` already abides to the + Policy's as well as the Model's view requirements and can + thus be passed to the Model as-is. 
explore (bool): Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]). timestep (Optional[int]): The current (sampling) time step. @@ -852,7 +853,8 @@ def _get_dummy_batch_from_view_requirements( """ ret = {} for view_col, view_req in self.view_requirements.items(): - if isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)): + if self.config["preprocessor_pref"] is not None and \ + isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)): _, shape = ModelCatalog.get_action_shape( view_req.space, framework=self.config["framework"]) ret[view_col] = \ @@ -860,23 +862,15 @@ def _get_dummy_batch_from_view_requirements( else: # Range of indices on time-axis, e.g. "-50:-1". if view_req.shift_from is not None: - ret[view_col] = np.zeros_like([[ - view_req.space.sample() - for _ in range(view_req.shift_to - - view_req.shift_from + 1) - ] for _ in range(batch_size)]) - # Set of (probably non-consecutive) indices. + ret[view_col] = get_dummy_batch_for_space(view_req.space, batch_size, time_size=view_req.shift_to - view_req.shift_from + 1) + # Sequence of (probably non-consecutive) indices. elif isinstance(view_req.shift, (list, tuple)): - ret[view_col] = np.zeros_like([[ - view_req.space.sample() - for t in range(len(view_req.shift)) - ] for _ in range(batch_size)]) + ret[view_col] = get_dummy_batch_for_space(view_req.space, batch_size, time_size=len(view_req.shift)) # Single shift int value. else: if isinstance(view_req.space, gym.spaces.Space): - ret[view_col] = np.zeros_like([ - view_req.space.sample() for _ in range(batch_size) - ]) + ret[view_col] = get_dummy_batch_for_space( + view_req.space, batch_size, fill_value=0.0) else: ret[view_col] = [ view_req.space for _ in range(batch_size) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index fc317ebab0eec..e0aa3a33f8229 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -937,8 +937,10 @@ def get_single_step_input_dict(self, view_requirements, index="last"): ]) # Single index. else: - data = self[data_col][-1] - input_dict[view_col] = np.array([data]) + input_dict[view_col] = tree.map_structure( + lambda v: v[-1:], # keep as array (w/ 1 element) + self[data_col], + ) else: # Index range. if isinstance(index, tuple): diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index ab6a7f8047fcb..ba8fb634705d9 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -3,6 +3,7 @@ import logging import numpy as np import os +import tree # pip install dm_tree from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING import ray @@ -382,7 +383,7 @@ def compute_actions( @override(Policy) def compute_actions_from_input_dict( self, - input_dict: Dict[str, TensorType], + input_dict: SampleBatch, explore: bool = None, timestep: Optional[int] = None, episodes: Optional[List["MultiAgentEpisode"]] = None, @@ -403,6 +404,7 @@ def compute_actions_from_input_dict( # Update our global timestep by the batch size. self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \ + else len(input_dict) if isinstance(input_dict, SampleBatch) \ else obs_batch.shape[0] return fetched @@ -860,7 +862,10 @@ def _build_compute_actions(self, if hasattr(self, "_input_dict"): for key, value in input_dict.items(): if key in self._input_dict: - builder.add_feed_dict({self._input_dict[key]: value}) + # Handle complex/nested spaces as well. 
+ tree.map_structure( + lambda k, v: builder.add_feed_dict({k: v}), self._input_dict[key], value + ) # For policies that inherit directly from TFPolicy. else: builder.add_feed_dict({ diff --git a/rllib/utils/filter.py b/rllib/utils/filter.py index 683503f30131e..e7e9f1b9bbb01 100644 --- a/rllib/utils/filter.py +++ b/rllib/utils/filter.py @@ -35,10 +35,11 @@ def as_serializable(self): class NoFilter(Filter): is_concurrent = True - def __init__(self, *args): - pass - def __call__(self, x, update=True): + # Process no further if already np.ndarray, dict, or tuple. + if isinstance(x, (np.ndarray, dict, tuple)): + return x + try: return np.asarray(x) except Exception: diff --git a/rllib/utils/spaces/flexdict.py b/rllib/utils/spaces/flexdict.py index de84ebe279cc9..98ab6203d82c6 100644 --- a/rllib/utils/spaces/flexdict.py +++ b/rllib/utils/spaces/flexdict.py @@ -1,49 +1,37 @@ import gym -from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils.annotations import override, PublicAPI @PublicAPI class FlexDict(gym.spaces.Dict): - """Gym Dictionary with arbitrary keys updatable after instantiation + """Gym Dict with arbitrary keys updatable after instantiation. + + Adds the __setitem__ method so new keys can be inserted after + instantiation. See also: documentation for gym.spaces.Dict Example: - space = FlexDict({}) - space['key'] = spaces.Box(4,) - See also: documentation for gym.spaces.Dict + >>> space = FlexDict({}) + >>> space["key"] = spaces.Box(4,) """ - def __init__(self, spaces=None, **spaces_kwargs): - err = "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)" - assert (spaces is None) or (not spaces_kwargs), err - - if spaces is None: - spaces = spaces_kwargs - - self.spaces = spaces - for space in spaces.values(): - self.assertSpace(space) - - # None for shape and dtype, since it'll require special handling - self.np_random = None - self.shape = None - self.dtype = None - self.seed() - - def assertSpace(self, space): + def assert_space(self, space): err = "Values of the dict should be instances of gym.Space" assert issubclass(type(space), gym.spaces.Space), err + @override(gym.spaces.Dict) def sample(self): return {k: space.sample() for k, space in self.spaces.items()} - def __getitem__(self, key): - return self.spaces[key] - def __setitem__(self, key, space): - self.assertSpace(space) + self.assert_space(space) self.spaces[key] = space + @override(gym.spaces.Dict) def __repr__(self): return "FlexDict(" + ", ".join( [str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")" + + @override(gym.spaces.Dict) + def __eq__(self, other): + return isinstance(other, FlexDict) and self.spaces == other.spaces diff --git a/rllib/utils/spaces/space_utils.py b/rllib/utils/spaces/space_utils.py index 9577610fde851..4ca507943a3e5 100644 --- a/rllib/utils/spaces/space_utils.py +++ b/rllib/utils/spaces/space_utils.py @@ -2,6 +2,7 @@ from gym.spaces import Tuple, Dict import numpy as np import tree # pip install dm_tree +from typing import Optional, Union def flatten_space(space): @@ -64,6 +65,71 @@ def _helper_struct(space_): return _helper_struct(space) +def get_dummy_batch_for_space(space: gym.Space, + batch_size: int = 32, + fill_value: Union[float, int, str] = 0.0, + time_size: Optional[int] = None, + time_major: bool = False, + ) -> np.ndarray: + """Returns batched dummy data (using `batch_size`) for the given `space`. + + Note: The returned batch will not pass a `space.contains(batch)` test + as an additional batch dimension has to be added as dim=0. 
+
+    Args:
+        space (gym.Space): The space to get a dummy batch for.
+        batch_size (int): The required batch size (B). Note that this can also
+            be 0 (only if `time_size` is None!), which will result in a
+            non-batched sample for the given space (no batch dim).
+        fill_value (Union[float, int, str]): The value to fill the batch with
+            or "random" for random values.
+        time_size (Optional[int]): If not None, add an optional time axis
+            of `time_size` size to the returned batch.
+        time_major (bool): If True AND `time_size` is not None, return batch
+            as shape [T x B x ...], otherwise as [B x T x ...]. If `time_size`
+            is None, ignore this setting and return [B x ...].
+
+    Returns:
+        The dummy batch of size `batch_size` matching the given space.
+    """
+    # Complex spaces. Perform recursive calls of this function.
+    if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
+        return tree.map_structure(
+            lambda s: get_dummy_batch_for_space(s, batch_size, fill_value),
+            get_base_struct_from_space(space),
+        )
+    # Primitive spaces: Box, Discrete, MultiDiscrete.
+    # Random values: Use gym's sample() method.
+    elif fill_value == "random":
+        if time_size is not None:
+            assert batch_size > 0 and time_size > 0
+            if time_major:
+                return np.array([
+                    [space.sample() for _ in range(batch_size)]
+                    for t in range(time_size)
+                ], dtype=space.dtype)
+            else:
+                return np.array([
+                    [space.sample() for t in range(time_size)]
+                    for _ in range(batch_size)
+                ], dtype=space.dtype)
+        else:
+            return np.array(
+                [space.sample() for _ in range(batch_size)] if batch_size > 0 else space.sample(),
+                dtype=space.dtype)
+    # Fill value given: Use np.full.
+    else:
+        if time_size is not None:
+            assert batch_size > 0 and time_size > 0
+            if time_major:
+                shape = [time_size, batch_size]
+            else:
+                shape = [batch_size, time_size]
+        else:
+            shape = [batch_size] if batch_size > 0 else []
+        return np.full(shape + list(space.shape), fill_value=fill_value, dtype=space.dtype)
+
+
 def flatten_to_single_ndarray(input_):
     """Returns a single np.ndarray given a list/tuple of np.ndarrays.
 
diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py
index 6bfad17fc9d36..a7bc4ca52cabc 100644
--- a/rllib/utils/tf_ops.py
+++ b/rllib/utils/tf_ops.py
@@ -180,8 +180,8 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
 
     def make_wrapper(fn):
         # Static-graph mode: Create placeholders and make a session call each
-        # time the wrapped function is called. Return this session call's
-        # outputs.
+        # time the wrapped function is called. Returns the output of this
+        # session call.
         if session_or_none is not None:
             args_placeholders = []
             kwargs_placeholders = {}
@@ -195,40 +195,65 @@ def call(*args, **kwargs):
                 else:
                     args_flat.append(a)
             args = args_flat
+
+            # We have not built any placeholders yet: Do this once here, then
+            # reuse the same placeholders each time we call this function
+            # again.
if symbolic_out[0] is None: with session_or_none.graph.as_default(): - for i, v in enumerate(args): - if dynamic_shape: - if len(v.shape) > 0: - shape = (None, ) + v.shape[1:] - else: - shape = () - else: - shape = v.shape - args_placeholders.append( - tf1.placeholder( - dtype=v.dtype, - shape=shape, - name="arg_{}".format(i))) - for k, v in kwargs.items(): + + def _create_placeholders(path, value): if dynamic_shape: - if len(v.shape) > 0: - shape = (None, ) + v.shape[1:] + if len(value.shape) > 0: + shape = (None, ) + value.shape[1:] else: shape = () else: - shape = v.shape - kwargs_placeholders[k] = \ - tf1.placeholder( - dtype=v.dtype, - shape=shape, - name="kwarg_{}".format(k)) - symbolic_out[0] = fn(*args_placeholders, - **kwargs_placeholders) + shape = value.shape + return tf1.placeholder( + dtype=value.dtype, + shape=shape, + name=".".join([str(p) for p in path]), + ) + + args_placeholders = tree.map_structure_with_path( + _create_placeholders, args) + #for i, v in enumerate(args): + # if dynamic_shape: + # if len(v.shape) > 0: + # shape = (None, ) + v.shape[1:] + # else: + # shape = () + # else: + # shape = v.shape + # args_placeholders.append( + # tf1.placeholder( + # dtype=v.dtype, + # shape=shape, + # name="arg_{}".format(i))) + + kwargs_placeholders = tree.map_structure_with_path( + _create_placeholders, kwargs) + + #for k, v in kwargs.items(): + # if dynamic_shape: + # if len(v.shape) > 0: + # shape = (None, ) + v.shape[1:] + # else: + # shape = () + # else: + # shape = v.shape + # kwargs_placeholders[k] = \ + # tf1.placeholder( + # dtype=v.dtype, + # shape=shape, + # name="kwarg_{}".format(k)) + symbolic_out[0] = fn( + *args_placeholders, **kwargs_placeholders) feed_dict = dict(zip(args_placeholders, args)) - feed_dict.update( - {kwargs_placeholders[k]: kwargs[k] - for k in kwargs.keys()}) + tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v), kwargs_placeholders, kwargs) + #{kwargs_placeholders[k]: kwargs[k] + # for k in kwargs.keys()}) ret = session_or_none.run(symbolic_out[0], feed_dict) return ret From 1b52cd84e8caf0e036887b25ceec864d2e0edefa Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 2 Aug 2021 15:36:41 -0400 Subject: [PATCH 05/45] wip. --- rllib/policy/sample_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index a574f48c7c982..9fb10b9b91bdc 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -104,7 +104,7 @@ def __init__(self, *args, **kwargs): # TODO: Drop support for lists as values. # Convert lists of int|float into numpy arrays make sure all data # has same length. - if isinstance(v, list) and isinstance(v[0], (int, float)): + if isinstance(v, list): self[k] = np.array(v) # Try to infer the "length" of the SampleBatch by finding the first From 674fb232e5f2b57b851c768a2dec1d8d8fb6e8a7 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 2 Aug 2021 15:39:57 -0400 Subject: [PATCH 06/45] wip. 
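As an aside on the previous patch ([PATCH 05/45]): dropping the int/float check means every list value handed to a SampleBatch is now normalized to a NumPy array. A minimal sketch of that intended behavior, not itself part of any diff in this series (import path as in current RLlib):

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    # Any list column - including bool lists or nested lists - becomes an ndarray.
    batch = SampleBatch({
        "obs": [[0.0, 1.0], [1.0, 0.0]],
        "dones": [False, True],
    })
    assert isinstance(batch["obs"], np.ndarray)
    assert batch["obs"].shape == (2, 2)
    assert batch.count == 2  # length inferred from the columns
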
--- rllib/tests/test_rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/tests/test_rollout.py b/rllib/tests/test_rollout.py index 53f6e8773389d..139d88b80df4e 100644 --- a/rllib/tests/test_rollout.py +++ b/rllib/tests/test_rollout.py @@ -37,7 +37,7 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False): ", \"timesteps_per_iteration\": 5,\"min_iter_time_s\": 0.1, " "\"model\": {\"fcnet_hiddens\": [10]}" "}' --stop='{\"training_iteration\": 1}'" + - " --env={}".format(env)) + " --env={} --no-ray-ui".format(env)) checkpoint_path = os.popen("ls {}/default/*/checkpoint_000001/" "checkpoint-1".format(tmp_dir)).read()[:-1] From 88c8e95bc6be787801da46bec3a168cd9892e1f1 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 2 Aug 2021 20:49:56 -0400 Subject: [PATCH 07/45] fix. --- rllib/policy/sample_batch.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 9fb10b9b91bdc..e284112cefe36 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -742,15 +742,14 @@ def _decompress_in_place(path, value): if path[0] not in columns: return curr = self - for i, p in enumerate(path): - if i == len(path) - 1: - # Bulk compressed. - if is_compressed(value): - curr[p] = unpack(value) - # Non bulk compressed. - elif len(value) > 0 and is_compressed(value): - curr[p] = np.array([unpack(o) for o in value]) + for p in path[:-1]: curr = curr[p] + # Bulk compressed. + if is_compressed(value): + curr[path[-1]] = unpack(value) + # Non bulk compressed. + elif len(value) > 0 and is_compressed(value[0]): + curr[path[-1]] = np.array([unpack(o) for o in value]) tree.map_structure_with_path(_decompress_in_place, self) From 68e81feb05534681572148369cc7a842920a29c9 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Wed, 4 Aug 2021 12:05:37 -0400 Subject: [PATCH 08/45] wip. --- rllib/policy/sample_batch.py | 18 ++++++-------- rllib/policy/tests/test_sample_batch.py | 33 +++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index d5bda92b80bca..99d8dfae39391 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -219,25 +219,21 @@ def concat(self, other: "SampleBatch") -> "SampleBatch": @PublicAPI def copy(self, shallow: bool = False) -> "SampleBatch": - """Creates a (deep) copy of this SampleBatch and returns it. + """Creates a deep or shallow copy of this SampleBatch and returns it. Args: shallow (bool): Whether the copying should be done shallowly. Returns: - SampleBatch: A (deep) copy of this SampleBatch object. + SampleBatch: A deep or shallow copy of this SampleBatch object. 
""" + copy_ = {k: v for k, v in self.items()} data = tree.map_structure( - lambda v: (np.array(v, copy=not shallow) - if isinstance(v, np.ndarray) else v), - self) - - copy_ = SampleBatch( - data, - _time_major=self.time_major, - _zero_padded=self.zero_padded, - _max_seq_len=self.max_seq_len, + lambda v: (np.array(v, copy=not shallow) if + isinstance(v, np.ndarray) else v), + copy_, ) + copy_ = SampleBatch(data) copy_.set_get_interceptor(self.get_interceptor) return copy_ diff --git a/rllib/policy/tests/test_sample_batch.py b/rllib/policy/tests/test_sample_batch.py index 6d95d6a46274d..cc158fab1a97c 100644 --- a/rllib/policy/tests/test_sample_batch.py +++ b/rllib/policy/tests/test_sample_batch.py @@ -222,6 +222,39 @@ def test_slicing(self): "state_in_0": [1.0, 3.0, 4.0], }) + def test_copy(self): + s = SampleBatch({ + "a": np.array([1, 2, 3, 2, 3, 4]), + "b": { + "c": np.array([4, 5, 6, 5, 6, 7]) + }, + "seq_lens": [2, 3, 1], + "state_in_0": [1.0, 3.0, 4.0], + }) + s_copy = s.copy(shallow=False) + s_copy["a"][0] = 100 + s_copy["b"]["c"][0] = 200 + s_copy["seq_lens"][0] = 3 + s_copy["seq_lens"][1] = 2 + s_copy["state_in_0"][0] = 400.0 + self.assertNotEqual(s["a"][0], s_copy["a"][0]) + self.assertNotEqual(s["b"]["c"][0], s_copy["b"]["c"][0]) + self.assertNotEqual(s["seq_lens"][0], s_copy["seq_lens"][0]) + self.assertNotEqual(s["seq_lens"][1], s_copy["seq_lens"][1]) + self.assertNotEqual(s["state_in_0"][0], s_copy["state_in_0"][0]) + + s_copy = s.copy(shallow=True) + s_copy["a"][0] = 100 + s_copy["b"]["c"][0] = 200 + s_copy["seq_lens"][0] = 3 + s_copy["seq_lens"][1] = 2 + s_copy["state_in_0"][0] = 400.0 + self.assertEqual(s["a"][0], s_copy["a"][0]) + self.assertEqual(s["b"]["c"][0], s_copy["b"]["c"][0]) + self.assertEqual(s["seq_lens"][0], s_copy["seq_lens"][0]) + self.assertEqual(s["seq_lens"][1], s_copy["seq_lens"][1]) + self.assertEqual(s["state_in_0"][0], s_copy["state_in_0"][0]) + if __name__ == "__main__": import pytest From 90b3735c3147b5bc43cc3217d6b3551f49740ec0 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 5 Aug 2021 21:27:10 -0400 Subject: [PATCH 09/45] wip. --- rllib/policy/sample_batch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 99d8dfae39391..d5e92505aed89 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -263,10 +263,12 @@ def rows(self) -> Dict[str, TensorType]: # Do we add seq_lens=[1] to each row? seq_lens = None if self.get("seq_lens") is None else np.array([1]) + self_as_dict = {k: v for k, v in self.items()} + for i in range(self.count): yield tree.map_structure_with_path( lambda p, v: v[i] if p[0] != "seq_lens" else seq_lens, - self, + self_as_dict, ) @PublicAPI From 1cc1c8767611b25f32e694b4e3e9204d133bd39f Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 6 Aug 2021 04:50:35 -0400 Subject: [PATCH 10/45] wip. --- rllib/policy/rnn_sequencing.py | 3 ++- rllib/policy/sample_batch.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index ad04166a2d7e4..22154d7eb78ec 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -431,6 +431,7 @@ def timeslice_along_seq_lens_with_overlap( # Zero-pad each slice if necessary. 
if zero_pad_max_seq_len > 0: for ts in timeslices: - ts.zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True) + ts.right_zero_pad( + max_seq_len=zero_pad_max_seq_len, exclude_states=True) return timeslices diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index d5e92505aed89..183925e39d9d9 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -596,7 +596,8 @@ def _zero_pad_in_place(path, value): curr[p] = f_pad curr = curr[p] - tree.map_structure_with_path(_zero_pad_in_place, self) + self_as_dict = {k: v for k, v in self.items()} + tree.map_structure_with_path(_zero_pad_in_place, self_as_dict) # Set flags to indicate, we are now zero-padded (and to what extend). self.zero_padded = True From 4ebcdad6ed4c42a83c59eab444ab59e8d94f0f59 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 6 Aug 2021 11:24:14 -0400 Subject: [PATCH 11/45] wip. --- rllib/policy/sample_batch.py | 38 ++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 183925e39d9d9..12ddab64de0a2 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -161,18 +161,20 @@ def concat_samples( """ if isinstance(samples[0], MultiAgentBatch): return MultiAgentBatch.concat_samples(samples) - seq_lens = [] + concatd_seq_lens = [] concat_samples = [] zero_padded = samples[0].zero_padded max_seq_len = samples[0].max_seq_len + time_major = samples[0].time_major for s in samples: if s.count > 0: assert s.zero_padded == zero_padded + assert s.time_major == time_major if zero_padded: assert s.max_seq_len == max_seq_len concat_samples.append(s) if s.get("seq_lens") is not None: - seq_lens.extend(s["seq_lens"]) + concatd_seq_lens.extend(s["seq_lens"]) # If we don't have any samples (no or only empty SampleBatches), # return an empty SampleBatch here. @@ -180,19 +182,29 @@ def concat_samples( return SampleBatch() # Collect the concat'd data. + concatd_data = {} + + def concat_key(*values): + return concat_aligned(values, time_major) + try: - concatd_data = tree.map_structure( - lambda *s: concat_aligned(s, concat_samples[0].time_major), - *concat_samples) - except TypeError: - raise TypeError(f"Cannot concat `samples` ({samples})! " - "Structures don't match.") + for k in concat_samples[0].keys(): + if k == "infos": + concatd_data[k] = concat_aligned( + [s[k] for s in concat_samples], time_major=time_major) + else: + concatd_data[k] = tree.map_structure( + concat_key, *[c[k] for c in concat_samples]) + except Exception as e: + raise ValueError(f"Cannot concat data under key '{k}', b/c " + "sub-structures under that key don't match. " + f"`samples`={samples}") # Return a new (concat'd) SampleBatch. return SampleBatch( concatd_data, - seq_lens=seq_lens, - _time_major=concat_samples[0].time_major, + seq_lens=concatd_seq_lens, + _time_major=time_major, _zero_padded=zero_padded, _max_seq_len=max_seq_len, ) @@ -387,17 +399,13 @@ def slice(self, start: int, end: int, state_start=None, """Returns a slice of the row data of this batch (w/o copying). Args: - start (int): Starting index. If < 0, will zero-pad. + start (int): Starting index. If < 0, will left-zero-pad. end (int): Ending index. Returns: SampleBatch: A new SampleBatch, which has a slice of this batch's data. 
""" - if log_once("SampleBatch.slice"): - deprecation_warning( - "SampleBatch.slice()", "SampleBatch[start:stop]", error=False) - if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0: if start < 0: data = { From 40bb3d3cd4088d37a0b6013e703c5ec34e0ee8ed Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 6 Aug 2021 12:32:33 -0400 Subject: [PATCH 12/45] wip. --- rllib/policy/sample_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 12ddab64de0a2..44601bf7b1756 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -195,7 +195,7 @@ def concat_key(*values): else: concatd_data[k] = tree.map_structure( concat_key, *[c[k] for c in concat_samples]) - except Exception as e: + except Exception: raise ValueError(f"Cannot concat data under key '{k}', b/c " "sub-structures under that key don't match. " f"`samples`={samples}") From 8aa30159672730762e71a527943981dfd79031a0 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 6 Aug 2021 15:06:01 -0400 Subject: [PATCH 13/45] wip. --- rllib/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/BUILD b/rllib/BUILD index 81f031d990083..4fe86622d581c 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -567,7 +567,7 @@ py_test( py_test( name = "test_trainer", tags = ["agents_dir"], - size = "medium", + size = "large", srcs = ["agents/tests/test_trainer.py"] ) From 9f787dd1cfeff7f63e1b5591f29452ee6f72ddf4 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 6 Aug 2021 20:52:58 -0400 Subject: [PATCH 14/45] wip. --- .../collectors/simple_list_collector.py | 197 +++++++++++------- rllib/evaluation/sampler.py | 7 +- rllib/examples/custom_observation_filters.py | 4 +- rllib/examples/env/random_env.py | 4 +- rllib/examples/preprocessing_disabled.py | 3 +- rllib/models/catalog.py | 12 +- rllib/models/preprocessors.py | 2 +- rllib/models/tf/complex_input_net.py | 12 +- rllib/policy/dynamic_tf_policy.py | 6 +- rllib/policy/policy.py | 10 +- rllib/policy/tf_policy.py | 8 +- rllib/utils/spaces/space_utils.py | 37 ++-- rllib/utils/tf_ops.py | 62 +++--- 13 files changed, 206 insertions(+), 158 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index d467858641127..f69a16325ecad 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -61,6 +61,7 @@ def __init__(self, view_reqs): for k, vr in view_reqs.items()) # The actual data buffers (lists holding each timestep's data). self.buffers = {} + self.buffer_structs = {} # The episode ID for the agent for which we collect data. self.episode_id = None # The simple timestep count for this agent. Gets increased by one @@ -94,17 +95,13 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, }) # Append data to existing buffers. 
- tree.map_structure_with_path(self._add_obs_helper, init_obs) + flattened = tree.flatten(init_obs) + for i, sub_obs in enumerate(flattened): + self.buffers[SampleBatch.OBS][i].append(sub_obs) self.episode_id = episode_id - self.buffers[SampleBatch.AGENT_INDEX].append(agent_index) - self.buffers["env_id"].append(env_id) - self.buffers["t"].append(t) - - def _add_obs_helper(self, path, value): - curr = self.buffers[SampleBatch.OBS] - for p in path[:-1]: - curr = curr[p] - curr[path[-1]].append(value) + self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index) + self.buffers["env_id"][0].append(env_id) + self.buffers["t"][0].append(t) def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ None: @@ -128,10 +125,12 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ for k, v in values.items(): if k not in self.buffers: self._build_buffers(single_row=values) - if k == SampleBatch.OBS: - tree.map_structure_with_path(self._add_obs_helper, v) - else: - self.buffers[k].append(v) + #if k == SampleBatch.OBS: + # tree.map_structure_with_path(self._add_obs_helper, v) + #else: + flattened = tree.flatten(v) + for i, sub_list in enumerate(self.buffers[k]): + sub_list.append(flattened[i]) self.agent_steps += 1 def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: @@ -172,7 +171,9 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: # Keep an np-array cache so we don't have to regenerate the # np-array for different view_cols using to the same data_col. if data_col not in np_data: - np_data[data_col] = to_float_np_array(self.buffers[data_col]) + np_data[data_col] = [ + to_float_np_array(d) for d in self.buffers[data_col] + ] # Range of indices on time-axis, e.g. "-50:-1". Together with # the `batch_repeat_value`, this determines the data produced. @@ -186,40 +187,48 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: # every n timesteps. if view_req.batch_repeat_value > 1: count = int( - math.ceil((len(np_data[data_col]) - self.shift_before) - / view_req.batch_repeat_value)) - data = np.asarray([ - np_data[data_col][self.shift_before + - (i * view_req.batch_repeat_value) + - view_req.shift_from + - obs_shift:self.shift_before + - (i * view_req.batch_repeat_value) + - view_req.shift_to + 1 + obs_shift] - for i in range(count) - ]) + math.ceil( + (len(np_data[data_col][0]) - self.shift_before) / + view_req.batch_repeat_value)) + data = [ + np.asarray([ + d[self.shift_before + + (i * view_req.batch_repeat_value) + + view_req.shift_from + + obs_shift:self.shift_before + + (i * view_req.batch_repeat_value) + + view_req.shift_to + 1 + obs_shift] + for i in range(count) + ]) for d in np_data[data_col] + ] # Batch repeat value = 1: Repeat the shift_from/to range at # each timestep. 
else: - d = np_data[data_col] + d0 = np_data[data_col][0] shift_win = view_req.shift_to - view_req.shift_from + 1 - data_size = d.itemsize * int(np.product(d.shape[1:])) + data_size = d0.itemsize * int(np.product(d0.shape[1:])) strides = [ - d.itemsize * int(np.product(d.shape[i + 1:])) - for i in range(1, len(d.shape)) + d0.itemsize * int(np.product(d0.shape[i + 1:])) + for i in range(1, len(d0.shape)) + ] + data = [ + np.lib.stride_tricks.as_strided( + d[self.shift_before - shift_win:], + [self.agent_steps, shift_win + ] + [d.shape[i] for i in range(1, len(d.shape))], + [data_size, data_size] + strides) + for d in np_data[data_col] ] - data = np.lib.stride_tricks.as_strided( - d[self.shift_before - shift_win:], - [self.agent_steps, shift_win - ] + [d.shape[i] for i in range(1, len(d.shape))], - [data_size, data_size] + strides) # Set of (probably non-consecutive) indices. # Example: # shift=[-3, 0] # buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...] elif isinstance(view_req.shift, np.ndarray): - data = np_data[data_col][self.shift_before + obs_shift + - view_req.shift] + data = [ + d[self.shift_before + obs_shift + view_req.shift] + for d in np_data[data_col] + ] # Single shift int value. Use the trajectory as-is, and if # `shift` != 0: shifted by that value. else: @@ -228,16 +237,19 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: # Batch repeat (only provide a value every n timesteps). if view_req.batch_repeat_value > 1: count = int( - math.ceil((len(np_data[data_col]) - self.shift_before) - / view_req.batch_repeat_value)) - data = np.asarray([ - np_data[data_col][self.shift_before + ( - i * view_req.batch_repeat_value) + shift] - for i in range(count) - ]) + math.ceil( + (len(np_data[data_col][0]) - self.shift_before) / + view_req.batch_repeat_value)) + data = [ + np.asarray([ + d[self.shift_before + + (i * view_req.batch_repeat_value) + shift] + for i in range(count) + ]) for d in np_data[data_col] + ] # Shift is exactly 0: Use trajectory as is. elif shift == 0: - data = np_data[data_col][self.shift_before:] + data = [d[self.shift_before:] for d in np_data[data_col]] # Shift is positive: We still need to 0-pad at the end. elif shift > 0: data = to_float_np_array( @@ -250,10 +262,17 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: # Shift is negative: Shift into the already existing and # 0-padded "before" area of our buffers. else: - data = np_data[data_col][self.shift_before + shift:shift] + data = [ + d[self.shift_before + shift:shift] + for d in np_data[data_col] + ] if len(data) > 0: - batch_data[view_col] = data + if data_col not in self.buffer_structs: + batch_data[view_col] = data[0] + else: + batch_data[view_col] = tree.unflatten_as( + self.buffer_structs[data_col], data) # Due to possible batch-repeats > 1, columns in the resulting batch # may not all have the same batch size. @@ -296,9 +315,10 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, "env_id", "t" ] else 0) - # Python primitive, tensor, or dict (e.g. INFOs). - self.buffers[col] = tree.map_structure( - lambda v: [v for _ in range(shift)], data) + # Store all data as flattened lists. 
+ self.buffers[col] = [[v for _ in range(shift)] + for v in tree.flatten(data)] + self.buffer_structs[col] = data class _PolicyCollector: @@ -316,7 +336,9 @@ def __init__(self, policy: Policy): policy (Policy): The policy object. """ - self.buffers: Dict[str, List] = collections.defaultdict(list) + #self.buffers: Dict[str, Any] = {} + #collections.defaultdict(list) + self.batches = [] self.policy = policy # The total timestep count for all agents that use this policy. # NOTE: This is not an env-step count (across n agents). AgentA and @@ -324,7 +346,7 @@ def __init__(self, policy: Policy): # doing n steps would increase the count by 2*n. self.agent_steps = 0 # Seq-lens list of already added agent batches. - self.seq_lens = [] if policy.is_recurrent() else None + #self.seq_lens = [] if policy.is_recurrent() else None def add_postprocessed_batch_for_training( self, batch: SampleBatch, @@ -339,22 +361,27 @@ def add_postprocessed_batch_for_training( view-column needs to be copied at all (not needed for training). """ - for view_col, data in batch.items(): - # 1) If col is not in view_requirements, we must have a direct - # child of the base Policy that doesn't do auto-view req creation. - # 2) Col is in view-reqs and needed for training. - view_req = view_requirements.get(view_col) - if view_req is None or view_req.used_for_training: - self.buffers[view_col].extend(data) + #for view_col, data in batch.items(): + # 1) If col is not in view_requirements, we must have a direct + # child of the base Policy that doesn't do auto-view req creation. + # 2) Col is in view-reqs and needed for training. + # view_req = view_requirements.get(view_col) + # if view_req is None or view_req.used_for_training: + # self.buffers[view_col].extend(data) # Add the agent's trajectory length to our count. self.agent_steps += batch.count + # And remove columns not needed for training. + for view_col, view_req in view_requirements.items(): + if view_col in batch and not view_req.used_for_training: + del batch[view_col] + self.batches.append(batch) # Adjust the seq-lens array depending on the incoming agent sequences. - if self.seq_lens is not None: - max_seq_len = self.policy.config["model"]["max_seq_len"] - count = batch.count - while count > 0: - self.seq_lens.append(min(count, max_seq_len)) - count -= max_seq_len + #if self.seq_lens is not None: + # max_seq_len = self.policy.config["model"]["max_seq_len"] + # count = batch.count + # while count > 0: + # self.seq_lens.append(min(count, max_seq_len)) + # count -= max_seq_len def build(self): """Builds a SampleBatch for this policy from the collected data. @@ -366,13 +393,17 @@ def build(self): this policy. """ # Create batch from our buffers. - batch = SampleBatch(self.buffers, seq_lens=self.seq_lens) + #batch = SampleBatch({ + # k: tree.unflatten_as(, v) for k, v in self.buffers.items() + #}, seq_lens=self.seq_lens) + batch = SampleBatch.concat_samples(self.batches) # Clear buffers for future samples. - self.buffers.clear() + #self.buffers.clear() + self.batches = [] # Reset agent steps to 0 and seq-lens to empty list. 
self.agent_steps = 0 - if self.seq_lens is not None: - self.seq_lens = [] + #if self.seq_lens is not None: + # self.seq_lens = [] return batch @@ -551,7 +582,12 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ Dict[str, TensorType]: policy = self.policy_map[policy_id] keys = self.forward_pass_agent_keys[policy_id] - buffers = {k: self.agent_collectors[k].buffers for k in keys} + + buffers = {} + for k in keys: + collector = self.agent_collectors[k] + buffers[k] = collector.buffers + buffer_structs = self.agent_collectors[keys[0]].buffer_structs input_dict = {} for view_col, view_req in policy.view_requirements.items(): @@ -574,9 +610,11 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ time_indices = view_req.shift + delta # Loop through agents and add up their data (batch). - data = [[] for _ in range(len(buffers[keys[0]][data_col]))] + data = None for k in keys: if data_col == SampleBatch.EPS_ID: + if data is None: + data = [[]] data[0].append(self.agent_collectors[k].episode_id) else: # Buffer for the data does not exist yet: Create dummy @@ -589,6 +627,11 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ data_col: fill_value }) + if data is None: + data = [ + [] for _ in range(len(buffers[keys[0]][data_col])) + ] + # `shift_from` and `shift_to` are defined: User wants a # view with some time-range. if isinstance(time_indices, tuple): @@ -600,15 +643,19 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ # `shift_to` != -1: "Normal" range. else: for d, b in zip(data, buffers[k][data_col]): - d.append(b[time_indices[ - 0]:time_indices[1] + 1]) + d.append( + b[time_indices[0]:time_indices[1] + 1]) # Single index. else: for d, b in zip(data, buffers[k][data_col]): d.append(b[time_indices]) - data = [np.array(d) for d in data] - input_dict[view_col] = tree.unflatten_as(, data) + np_data = [np.array(d) for d in data] + if data_col in buffer_structs: + input_dict[view_col] = tree.unflatten_as( + buffer_structs[data_col], np_data) + else: + input_dict[view_col] = np_data[0] self._reset_inference_calls(policy_id) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 4f278f388c752..9ce88d9e9bfb5 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -820,7 +820,8 @@ def _process_observations( if preprocessor is not None: prep_obs = preprocessor.transform(raw_obs) if log_once("prep_obs"): - logger.info("Preprocessed obs: {}".format(summarize(prep_obs))) + logger.info("Preprocessed obs: {}".format( + summarize(prep_obs))) filtered_obs: EnvObsType = _get_or_raise(worker.filters, policy_id)(prep_obs) if log_once("filtered_obs"): @@ -958,8 +959,8 @@ def _process_observations( # types: AgentID, EnvObsType for agent_id, raw_obs in resetted_obs.items(): policy_id: PolicyID = new_episode.policy_for(agent_id) - preproccessor = _get_or_raise( - worker.preprocessors, policy_id) + preproccessor = _get_or_raise(worker.preprocessors, + policy_id) prep_obs: EnvObsType = raw_obs if preproccessor is not None: diff --git a/rllib/examples/custom_observation_filters.py b/rllib/examples/custom_observation_filters.py index 24f4323066b77..1c9ee5bb1f81f 100644 --- a/rllib/examples/custom_observation_filters.py +++ b/rllib/examples/custom_observation_filters.py @@ -137,8 +137,6 @@ def __repr__(self): } results = tune.run( - args.run, - config=config, - stop={"training_iteration": args.stop_iters}) + args.run, config=config, stop={"training_iteration": args.stop_iters}) ray.shutdown() diff --git 
a/rllib/examples/env/random_env.py b/rllib/examples/env/random_env.py index ca650c4f2ccc8..b6b451fef7c33 100644 --- a/rllib/examples/env/random_env.py +++ b/rllib/examples/env/random_env.py @@ -2,7 +2,7 @@ from gym.spaces import Discrete, Tuple import numpy as np -from ray.rllib.examples.env.multi_agent import make_multiagent +from ray.rllib.examples.env.multi_agent import make_multi_agent class RandomEnv(gym.Env): @@ -62,4 +62,4 @@ def step(self, action): # Multi-agent version of the RandomEnv. -RandomMultiAgentEnv = make_multiagent(lambda c: RandomEnv(c)) +RandomMultiAgentEnv = make_multi_agent(lambda c: RandomEnv(c)) diff --git a/rllib/examples/preprocessing_disabled.py b/rllib/examples/preprocessing_disabled.py index 4cd96a7848d12..d492b412d4a80 100644 --- a/rllib/examples/preprocessing_disabled.py +++ b/rllib/examples/preprocessing_disabled.py @@ -84,7 +84,8 @@ def get_cli_args(): "a": Discrete(2), "b": Dict({ "ba": Discrete(3), - "bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)}), + "bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32) + }), "c": Tuple((MultiDiscrete([2, 3]), Discrete(2))), "d": Box(-1.0, 1.0, (2, ), dtype=np.int32), }), diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 0aa7dffbc8308..59de0ea4a19e2 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -670,7 +670,7 @@ def get_preprocessor(env: gym.Env, options) @staticmethod - @Deprecated + @Deprecated(error=False) def get_preprocessor_for_space(observation_space: gym.Space, options: dict = None) -> Preprocessor: """Returns a suitable preprocessor for the given observation space. @@ -709,7 +709,7 @@ def get_preprocessor_for_space(observation_space: gym.Space, return prep @staticmethod - @Deprecated + @Deprecated(error=False) def register_custom_preprocessor(preprocessor_name: str, preprocessor_class: type) -> None: """Register a custom preprocessor class by name. @@ -816,10 +816,10 @@ def _get_v2_model_class(input_space: gym.Space, # processes image components with default CNN stacks). space_to_check = input_space if not hasattr( input_space, "original_space") else input_space.original_space - if isinstance(input_space, - (Dict, Tuple)) or (isinstance(space_to_check, (Dict, Tuple)) and any( - isinstance(s, Box) and len(s.shape) >= 2 - for s in tree.flatten(space_to_check.spaces))): + if isinstance(input_space, (Dict, Tuple)) or (isinstance( + space_to_check, (Dict, Tuple)) and any( + isinstance(s, Box) and len(s.shape) >= 2 + for s in tree.flatten(space_to_check.spaces))): return ComplexNet # Single, flattenable/one-hot-able space -> Simple FCNet. 
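The catalog changes above, together with rllib/examples/preprocessing_disabled.py earlier in this patch, are what let complex Dict/Tuple observations reach the model unflattened when no preprocessor is configured. A rough usage sketch with `preprocessor_pref=None` as the switch; the trainer name is only illustrative, and it assumes RandomEnv accepts an `observation_space` key in its env_config:

    from gym.spaces import Box, Dict, Discrete
    from ray import tune
    from ray.rllib.examples.env.random_env import RandomEnv

    config = {
        "env": RandomEnv,
        "env_config": {
            # Nested observation space, similar to the one used in
            # preprocessing_disabled.py.
            "observation_space": Dict({
                "a": Discrete(2),
                "b": Box(-1.0, 1.0, (2, 3)),
            }),
        },
        # None = run without any preprocessor; the model (e.g. the complex
        # input nets touched in this patch) sees the raw Dict observations.
        "preprocessor_pref": None,
        "framework": "tf",
    }
    tune.run("PPO", config=config, stop={"training_iteration": 1})
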
diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py index 06ecbd6508fb4..33ef18b5378a3 100644 --- a/rllib/models/preprocessors.py +++ b/rllib/models/preprocessors.py @@ -177,7 +177,7 @@ def write(self, observation: TensorType, array: np.ndarray, array[offset:offset + self.size] = self.transform(observation) -@Deprecated +@Deprecated(error=False) class NoPreprocessor(Preprocessor): @override(Preprocessor) def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]: diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index d16adc4c2c71a..fefdfb28d2fd4 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -134,15 +134,15 @@ def forward(self, input_dict, state, seq_lens): outs.append(cnn_out) elif i in self.one_hot: #if component.dtype in [tf.int32, tf.int64, tf.uint8]: - outs.append( - one_hot(component, self.flattened_input_space[i])) + outs.append(one_hot(component, self.flattened_input_space[i])) #else: # outs.append(tf.cast(component, tf.float32)) else: - outs.append(tf.cast( - tf.reshape(component, [-1, self.flatten[i]]), - dtype=tf.float32, - )) + outs.append( + tf.cast( + tf.reshape(component, [-1, self.flatten[i]]), + dtype=tf.float32, + )) # Concat all outputs and the non-image inputs. out = tf.concat(outs, axis=1) # Push through (optional) FC-stack (this may be an empty stack). diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 27682dfba4798..b368b68c161cb 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -506,8 +506,8 @@ def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0): if batch_size >= len(self._loaded_single_cpu_batch): sliced_batch = self._loaded_single_cpu_batch else: - sliced_batch = self._loaded_single_cpu_batch.slice( - start=offset, end=offset + batch_size) + sliced_batch = self._loaded_single_cpu_batch[offset:offset + + batch_size] return self.learn_on_batch(sliced_batch) return self.multi_gpu_tower_stacks[buffer_index].optimize( @@ -542,7 +542,7 @@ def _get_input_dict_and_dummy_batch(self, view_requirements, # Skip action dist inputs placeholder (do later). elif view_col == SampleBatch.ACTION_DIST_INPUTS: continue - # This is a tower, input placeholders already exist. + # This is a tower, input placeholders already exist. elif view_col in existing_inputs: input_dict[view_col] = existing_inputs[view_col] # All others. diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index ccacc0af1767d..9ebdbec71b0fd 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -860,10 +860,16 @@ def _get_dummy_batch_from_view_requirements( else: # Range of indices on time-axis, e.g. "-50:-1". if view_req.shift_from is not None: - ret[view_col] = get_dummy_batch_for_space(view_req.space, batch_size, time_size=view_req.shift_to - view_req.shift_from + 1) + ret[view_col] = get_dummy_batch_for_space( + view_req.space, + batch_size, + time_size=view_req.shift_to - view_req.shift_from + 1) # Sequence of (probably non-consecutive) indices. elif isinstance(view_req.shift, (list, tuple)): - ret[view_col] = get_dummy_batch_for_space(view_req.space, batch_size, time_size=len(view_req.shift)) + ret[view_col] = get_dummy_batch_for_space( + view_req.space, + batch_size, + time_size=len(view_req.shift)) # Single shift int value. 
else: if isinstance(view_req.space, gym.spaces.Space): diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 62113924490cf..f1d5cb91ca8dd 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -1029,8 +1029,12 @@ def _get_loss_inputs_dict(self, train_batch: SampleBatch, shuffle: bool): # Build the feed dict from the batch. feed_dict = {} - for key, placeholder in self._loss_input_dict.items(): - feed_dict[placeholder] = train_batch[key] + for key, placeholders in self._loss_input_dict.items(): + tree.map_structure( + lambda ph, v: feed_dict.__setitem__(ph, v), + placeholders, + train_batch[key], + ) state_keys = [ "state_in_{}".format(i) for i in range(len(self._state_inputs)) diff --git a/rllib/utils/spaces/space_utils.py b/rllib/utils/spaces/space_utils.py index 4ca507943a3e5..0cc672970b85a 100644 --- a/rllib/utils/spaces/space_utils.py +++ b/rllib/utils/spaces/space_utils.py @@ -65,12 +65,13 @@ def _helper_struct(space_): return _helper_struct(space) -def get_dummy_batch_for_space(space: gym.Space, - batch_size: int = 32, - fill_value: Union[float, int, str] = 0.0, - time_size: Optional[int] = None, - time_major: bool = False, - ) -> np.ndarray: +def get_dummy_batch_for_space( + space: gym.Space, + batch_size: int = 32, + fill_value: Union[float, int, str] = 0.0, + time_size: Optional[int] = None, + time_major: bool = False, +) -> np.ndarray: """Returns batched dummy data (using `batch_size`) for the given `space`. Note: The returned batch will not pass a `space.contains(batch)` test @@ -104,18 +105,19 @@ def get_dummy_batch_for_space(space: gym.Space, if time_size is not None: assert batch_size > 0 and time_size > 0 if time_major: - return np.array([ - [space.sample() for _ in range(batch_size)] - for t in range(time_size) - ], dtype=space.dtype) + return np.array( + [[space.sample() for _ in range(batch_size)] + for t in range(time_size)], + dtype=space.dtype) else: - return np.array([ - [space.sample() for t in range(time_size)] - for _ in range(batch_size) - ], dtype=space.dtype) + return np.array( + [[space.sample() for t in range(time_size)] + for _ in range(batch_size)], + dtype=space.dtype) else: return np.array( - [space.sample() for _ in range(batch_size)] if batch_size > 0 else space.sample(), + [space.sample() for _ in range(batch_size)] + if batch_size > 0 else space.sample(), dtype=space.dtype) # Fill value given: Use np.full. 
else: @@ -127,7 +129,10 @@ def get_dummy_batch_for_space(space: gym.Space, shape = [batch_size, time_size] else: shape = [batch_size] if batch_size > 0 else [] - return np.full(shape + list(space.shape), fill_value=fill_value, dtype=space.dtype) + return np.full( + shape + list(space.shape), + fill_value=fill_value, + dtype=space.dtype) def flatten_to_single_ndarray(input_): diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index b8b8a23ce5787..0f76a60affd5a 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -56,7 +56,12 @@ def get_gpu_devices(): return [d.name for d in devices if "GPU" in d.device_type] -def get_placeholder(*, space=None, value=None, name=None, time_axis=False, flatten=True): +def get_placeholder(*, + space=None, + value=None, + name=None, + time_axis=False, + flatten=True): from ray.rllib.models.catalog import ModelCatalog if space is not None: @@ -122,7 +127,10 @@ def one_hot(x, space): return tf.one_hot(x, space.n, dtype=tf.float32) elif isinstance(space, MultiDiscrete): return tf.concat( - [tf.one_hot(x[:, i], n, dtype=tf.float32) for i, n in enumerate(space.nvec)], + [ + tf.one_hot(x[:, i], n, dtype=tf.float32) + for i, n in enumerate(space.nvec) + ], axis=-1) else: raise ValueError("Unsupported space for `one_hot`: {}".format(space)) @@ -186,6 +194,7 @@ def make_wrapper(fn): if session_or_none is not None: args_placeholders = [] kwargs_placeholders = {} + symbolic_out = [None] def call(*args, **kwargs): @@ -217,44 +226,21 @@ def _create_placeholders(path, value): name=".".join([str(p) for p in path]), ) - args_placeholders = tree.map_structure_with_path( + placeholders = tree.map_structure_with_path( _create_placeholders, args) - #for i, v in enumerate(args): - # if dynamic_shape: - # if len(v.shape) > 0: - # shape = (None, ) + v.shape[1:] - # else: - # shape = () - # else: - # shape = v.shape - # args_placeholders.append( - # tf1.placeholder( - # dtype=v.dtype, - # shape=shape, - # name="arg_{}".format(i))) - - kwargs_placeholders = tree.map_structure_with_path( - _create_placeholders, kwargs) + for ph in tree.flatten(placeholders): + args_placeholders.append(ph) - #for k, v in kwargs.items(): - # if dynamic_shape: - # if len(v.shape) > 0: - # shape = (None, ) + v.shape[1:] - # else: - # shape = () - # else: - # shape = v.shape - # kwargs_placeholders[k] = \ - # tf1.placeholder( - # dtype=v.dtype, - # shape=shape, - # name="kwarg_{}".format(k)) - symbolic_out[0] = fn( - *args_placeholders, **kwargs_placeholders) - feed_dict = dict(zip(args_placeholders, args)) - tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v), kwargs_placeholders, kwargs) - #{kwargs_placeholders[k]: kwargs[k] - # for k in kwargs.keys()}) + placeholders = tree.map_structure_with_path( + _create_placeholders, kwargs) + for k, ph in placeholders.items(): + kwargs_placeholders[k] = ph + + symbolic_out[0] = fn(*args_placeholders, + **kwargs_placeholders) + feed_dict = dict(zip(args_placeholders, tree.flatten(args))) + tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v), + kwargs_placeholders, kwargs) ret = session_or_none.run(symbolic_out[0], feed_dict) return ret From 3d719d00cabff94dbb3c40b07c722ae3e9d4844a Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 12 Aug 2021 22:45:22 +0200 Subject: [PATCH 15/45] wip and LINT. 
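One detail worth spelling out from the tf_ops changes in the previous patch: `one_hot` now also accepts MultiDiscrete spaces by one-hot encoding each sub-dimension and concatenating along the last axis. A small sanity-check sketch (TF2 eager mode is used here purely for illustration):

    import numpy as np
    import tensorflow as tf
    from gym.spaces import MultiDiscrete

    x = tf.constant([[1, 2], [0, 0]])  # batch of 2 MultiDiscrete([2, 3]) samples
    space = MultiDiscrete([2, 3])

    # Mirrors the MultiDiscrete branch of rllib.utils.tf_ops.one_hot:
    # one-hot each sub-dimension, then concat.
    out = tf.concat(
        [tf.one_hot(x[:, i], n, dtype=tf.float32) for i, n in enumerate(space.nvec)],
        axis=-1)

    assert out.shape == (2, 5)  # 2 + 3 one-hot slots per row
    np.testing.assert_allclose(out[0].numpy(), [0., 1., 0., 0., 1.])
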
--- .../collectors/simple_list_collector.py | 48 +++++-------------- rllib/evaluation/rollout_worker.py | 6 ++- rllib/models/catalog.py | 10 ++-- rllib/models/tf/complex_input_net.py | 3 -- rllib/models/torch/complex_input_net.py | 6 +-- rllib/policy/dynamic_tf_policy.py | 3 +- rllib/policy/policy.py | 3 +- rllib/policy/tf_policy.py | 3 +- rllib/tests/test_catalog.py | 1 - rllib/utils/tf_ops.py | 11 +++-- 10 files changed, 38 insertions(+), 56 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index f69a16325ecad..0c44bbd450a6f 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -125,12 +125,12 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ for k, v in values.items(): if k not in self.buffers: self._build_buffers(single_row=values) - #if k == SampleBatch.OBS: - # tree.map_structure_with_path(self._add_obs_helper, v) - #else: - flattened = tree.flatten(v) - for i, sub_list in enumerate(self.buffers[k]): - sub_list.append(flattened[i]) + if k != SampleBatch.INFOS: + flattened = tree.flatten(v) + for i, sub_list in enumerate(self.buffers[k]): + sub_list.append(flattened[i]) + else: + self.buffers[k][0].append(v) self.agent_steps += 1 def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: @@ -336,8 +336,6 @@ def __init__(self, policy: Policy): policy (Policy): The policy object. """ - #self.buffers: Dict[str, Any] = {} - #collections.defaultdict(list) self.batches = [] self.policy = policy # The total timestep count for all agents that use this policy. @@ -345,8 +343,6 @@ def __init__(self, policy: Policy): # agentB, both using this policy, acting in the same episode and both # doing n steps would increase the count by 2*n. self.agent_steps = 0 - # Seq-lens list of already added agent batches. - #self.seq_lens = [] if policy.is_recurrent() else None def add_postprocessed_batch_for_training( self, batch: SampleBatch, @@ -361,13 +357,6 @@ def add_postprocessed_batch_for_training( view-column needs to be copied at all (not needed for training). """ - #for view_col, data in batch.items(): - # 1) If col is not in view_requirements, we must have a direct - # child of the base Policy that doesn't do auto-view req creation. - # 2) Col is in view-reqs and needed for training. - # view_req = view_requirements.get(view_col) - # if view_req is None or view_req.used_for_training: - # self.buffers[view_col].extend(data) # Add the agent's trajectory length to our count. self.agent_steps += batch.count # And remove columns not needed for training. @@ -375,13 +364,6 @@ def add_postprocessed_batch_for_training( if view_col in batch and not view_req.used_for_training: del batch[view_col] self.batches.append(batch) - # Adjust the seq-lens array depending on the incoming agent sequences. - #if self.seq_lens is not None: - # max_seq_len = self.policy.config["model"]["max_seq_len"] - # count = batch.count - # while count > 0: - # self.seq_lens.append(min(count, max_seq_len)) - # count -= max_seq_len def build(self): """Builds a SampleBatch for this policy from the collected data. @@ -393,17 +375,11 @@ def build(self): this policy. """ # Create batch from our buffers. - #batch = SampleBatch({ - # k: tree.unflatten_as(, v) for k, v in self.buffers.items() - #}, seq_lens=self.seq_lens) batch = SampleBatch.concat_samples(self.batches) - # Clear buffers for future samples. 
- #self.buffers.clear() + # Clear batches for future samples. self.batches = [] # Reset agent steps to 0 and seq-lens to empty list. self.agent_steps = 0 - #if self.seq_lens is not None: - # self.seq_lens = [] return batch @@ -620,8 +596,10 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ # Buffer for the data does not exist yet: Create dummy # (zero) data. if data_col not in buffers[k]: - fill_value = get_dummy_batch_for_space(view_req.space, batch_size=0) \ - if isinstance(view_req.space, Space) else \ + fill_value = get_dummy_batch_for_space( + view_req.space, + batch_size=0, + ) if isinstance(view_req.space, Space) else \ view_req.space self.agent_collectors[k]._build_buffers({ data_col: fill_value @@ -635,8 +613,8 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ # `shift_from` and `shift_to` are defined: User wants a # view with some time-range. if isinstance(time_indices, tuple): - # `shift_to` == -1: Until the end (including(!) the last - # item). + # `shift_to` == -1: Until the end (including(!) the + # last item). if time_indices[1] == -1: for d, b in zip(data, buffers[k][data_col]): d.append(b[time_indices[0]:]) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 208c869771758..9d60856636927 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -387,7 +387,8 @@ def gen_rollouts(): self.count_steps_by: str = count_steps_by self.batch_mode: str = batch_mode self.compress_observations: bool = compress_observations - self.preprocessing_enabled: bool = False if preprocessor_pref is None else True + self.preprocessing_enabled: bool = False \ + if preprocessor_pref is None else True self.observation_filter = observation_filter self.last_batch: SampleBatchType = None self.global_vars: dict = None @@ -1361,7 +1362,8 @@ def _build_policy_map( preprocessor = ModelCatalog.get_preprocessor_for_space( obs_space, merged_conf.get("model")) self.preprocessors[name] = preprocessor - obs_space = preprocessor.observation_space + if preprocessor is not None: + obs_space = preprocessor.observation_space else: self.preprocessors[name] = None diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 59de0ea4a19e2..2821d8fe03a04 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -702,10 +702,14 @@ def get_preprocessor_for_space(observation_space: gym.Space, observation_space, options) else: cls = get_preprocessor(observation_space) - prep = cls(observation_space, options) + if cls is not None: + prep = cls(observation_space, options) + else: + prep = None - logger.debug("Created preprocessor {}: {} -> {}".format( - prep, observation_space, prep.shape)) + if prep is not None: + logger.debug("Created preprocessor {}: {} -> {}".format( + prep, observation_space, prep.shape)) return prep @staticmethod diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index fefdfb28d2fd4..31679e44ee4da 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -133,10 +133,7 @@ def forward(self, input_dict, state, seq_lens): cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component}) outs.append(cnn_out) elif i in self.one_hot: - #if component.dtype in [tf.int32, tf.int64, tf.uint8]: outs.append(one_hot(component, self.flattened_input_space[i])) - #else: - # outs.append(tf.cast(component, tf.float32)) else: outs.append( tf.cast( diff --git a/rllib/models/torch/complex_input_net.py 
b/rllib/models/torch/complex_input_net.py index db72d6c9e3b14..4fb71dabeed7a 100644 --- a/rllib/models/torch/complex_input_net.py +++ b/rllib/models/torch/complex_input_net.py @@ -138,11 +138,7 @@ def forward(self, input_dict, state, seq_lens): cnn_out, _ = self.cnns[i]({"obs": component}) outs.append(cnn_out) elif i in self.one_hot: - if component.dtype in [torch.int32, torch.int64, torch.uint8]: - outs.append( - one_hot(component, self.original_space.spaces[i])) - else: - outs.append(component) + outs.append(one_hot(component, self.original_space.spaces[i])) else: outs.append(torch.reshape(component, [-1, self.flatten[i]])) # Concat all outputs and the non-image inputs. diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index b368b68c161cb..bddbab2b290d0 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -551,7 +551,8 @@ def _get_input_dict_and_dummy_batch(self, view_requirements, if view_req.used_for_training: # Create a +time-axis placeholder if the shift is not an # int (range or list of ints). - flatten = view_col not in [SampleBatch.OBS, SampleBatch.NEXT_OBS] or \ + flatten = view_col not in [ + SampleBatch.OBS, SampleBatch.NEXT_OBS] or \ self.config["preprocessor_pref"] is not None input_dict[view_col] = get_placeholder( space=view_req.space, diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 9ebdbec71b0fd..f5096280594a0 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -852,7 +852,8 @@ def _get_dummy_batch_from_view_requirements( ret = {} for view_col, view_req in self.view_requirements.items(): if self.config["preprocessor_pref"] is not None and \ - isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)): + isinstance(view_req.space, + (gym.spaces.Dict, gym.spaces.Tuple)): _, shape = ModelCatalog.get_action_shape( view_req.space, framework=self.config["framework"]) ret[view_col] = \ diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index f1d5cb91ca8dd..f170e0cbbf4bd 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -887,7 +887,8 @@ def _build_compute_actions(self, if key in self._input_dict: # Handle complex/nested spaces as well. tree.map_structure( - lambda k, v: builder.add_feed_dict({k: v}), self._input_dict[key], value + lambda k, v: builder.add_feed_dict({k: v}), + self._input_dict[key], value, ) # For policies that inherit directly from TFPolicy. else: diff --git a/rllib/tests/test_catalog.py b/rllib/tests/test_catalog.py index b096bd9ae5852..7fe99e7db65f6 100644 --- a/rllib/tests/test_catalog.py +++ b/rllib/tests/test_catalog.py @@ -1,5 +1,4 @@ from functools import partial -import gym from gym.spaces import Box, Dict, Discrete import numpy as np import unittest diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index 0f76a60affd5a..cfabfca4636f4 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -70,7 +70,10 @@ def get_placeholder(*, return ModelCatalog.get_action_placeholder(space, None) else: return tree.map_structure_with_path( - lambda path, component: get_placeholder(space=component, name=name + "." + ".".join([str(p) for p in path])), + lambda path, component: get_placeholder( + space=component, + name=name + "." 
+ ".".join([str(p) for p in path]), + ), get_base_struct_from_space(space), ) return tf1.placeholder( @@ -206,9 +209,9 @@ def call(*args, **kwargs): args_flat.append(a) args = args_flat - # We have not built any placeholders yet: Do this once here, then - # reuse the same placeholders each time we call this function - # again. + # We have not built any placeholders yet: Do this once here, + # then reuse the same placeholders each time we call this + # function again. if symbolic_out[0] is None: with session_or_none.graph.as_default(): From 97182abab1892a7dabd66db1a5ccda514ce3b079 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 13 Aug 2021 10:12:05 +0200 Subject: [PATCH 16/45] fix. --- rllib/evaluation/collectors/simple_list_collector.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 0c44bbd450a6f..4fd27fa865368 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -315,10 +315,14 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, "env_id", "t" ] else 0) - # Store all data as flattened lists. - self.buffers[col] = [[v for _ in range(shift)] - for v in tree.flatten(data)] - self.buffer_structs[col] = data + + # Store all data as flattened lists, except INFOS. + if col == SampleBatch.INFOS: + self.buffers[col] = [[data]] + else: + self.buffers[col] = [[v for _ in range(shift)] + for v in tree.flatten(data)] + self.buffer_structs[col] = data class _PolicyCollector: From c664afa8d9273c09129795075ef8809e7ab34e77 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 13 Aug 2021 11:07:49 +0200 Subject: [PATCH 17/45] fix. --- .../evaluation/collectors/simple_list_collector.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 4fd27fa865368..dedc943a6599e 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -125,12 +125,12 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ for k, v in values.items(): if k not in self.buffers: self._build_buffers(single_row=values) - if k != SampleBatch.INFOS: + if k == SampleBatch.INFOS or k.startswith("state_out_"): + self.buffers[k][0].append(v) + else: flattened = tree.flatten(v) for i, sub_list in enumerate(self.buffers[k]): sub_list.append(flattened[i]) - else: - self.buffers[k][0].append(v) self.agent_steps += 1 def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: @@ -316,8 +316,11 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: "env_id", "t" ] else 0) - # Store all data as flattened lists, except INFOS. - if col == SampleBatch.INFOS: + # Store all data as flattened lists, except INFOS and state-out + # lists. Reason: These are monolithic items (infos is a dict that + # should not be further split, same for state-out items, which could + # be custom dicts as well). + if col == SampleBatch.INFOS or col.startswith("state_out_"): self.buffers[col] = [[data]] else: self.buffers[col] = [[v for _ in range(shift)] From e810da5e7aceefe6ae32f78c131be907a030f201 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sat, 14 Aug 2021 18:04:47 +0200 Subject: [PATCH 18/45] fixes. 
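
For recurrent policies, the collector change below chunks each agent trajectory
into `seq_lens` of at most `model.max_seq_len` timesteps. A stand-alone sketch
of that chunking (the `chunk_seq_lens` helper name is made up for illustration;
the real logic lives inline in `_AgentCollector.build()`):

```python
import numpy as np

def chunk_seq_lens(count: int, max_seq_len: int) -> np.ndarray:
    # Split a trajectory of `count` steps into chunks of at most `max_seq_len`.
    seq_lens = []
    while count > 0:
        seq_lens.append(min(count, max_seq_len))
        count -= max_seq_len
    return np.array(seq_lens, dtype=np.int32)

# E.g. a 201-step trajectory with max_seq_len=20 -> ten 20s plus one 1.
seq_lens = chunk_seq_lens(201, 20)
assert seq_lens.sum() == 201 and len(seq_lens) == 11
```
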
---
 .../collectors/simple_list_collector.py       | 16 ++++++++++++++--
 rllib/policy/sample_batch.py                  | 23 +++++++++++++++++++++--
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py
index dedc943a6599e..5fbdc59b9b041 100644
--- a/rllib/evaluation/collectors/simple_list_collector.py
+++ b/rllib/evaluation/collectors/simple_list_collector.py
@@ -49,7 +49,8 @@ class _AgentCollector:
 
     _next_unroll_id = 0  # disambiguates unrolls within a single episode
 
-    def __init__(self, view_reqs):
+    def __init__(self, view_reqs, policy):
+        self.policy = policy
         # Determine the size of the buffer we need for data before the actual
         # episode starts. This is used for 0-buffering of e.g. prev-actions,
         # or internal state inputs.
@@ -278,6 +279,17 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
             # may not all have the same batch size.
             batch = SampleBatch(batch_data)
 
+        # Adjust the seq-lens array depending on the incoming agent sequences.
+        if self.policy.is_recurrent():
+            seq_lens = []
+            max_seq_len = self.policy.config["model"]["max_seq_len"]
+            count = batch.count
+            while count > 0:
+                seq_lens.append(min(count, max_seq_len))
+                count -= max_seq_len
+            batch["seq_lens"] = np.array(seq_lens)
+            batch.max_seq_len = max_seq_len
+
         # Add EPS_ID and UNROLL_ID to batch.
         batch[SampleBatch.EPS_ID] = np.repeat(self.episode_id, batch.count)
         if SampleBatch.UNROLL_ID not in batch:
@@ -507,7 +519,7 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
         # Add initial obs to Trajectory.
         assert agent_key not in self.agent_collectors
         # TODO: determine exact shift-before based on the view-req shifts.
-        self.agent_collectors[agent_key] = _AgentCollector(view_reqs)
+        self.agent_collectors[agent_key] = _AgentCollector(view_reqs, policy)
         self.agent_collectors[agent_key].add_init_obs(
             episode_id=episode.episode_id,
             agent_index=episode._agent_index(agent_id),
diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py
index f8a738660dc01..958f5aea1123a 100644
--- a/rllib/policy/sample_batch.py
+++ b/rllib/policy/sample_batch.py
@@ -61,13 +61,32 @@ class SampleBatch(dict):
 
     @PublicAPI
     def __init__(self, *args, **kwargs):
-        """Constructs a sample batch (same params as dict constructor)."""
+        """Constructs a sample batch (same params as dict constructor).
+
+        Note: All *args and those **kwargs not listed below will be passed
+        as-is to the parent dict constructor.
+
+        Keyword Args:
+            _time_major (Optional[bool]): Whether data in this sample batch
+                is time-major. This is False by default and only relevant
+                if the data contains sequences.
+            _max_seq_len (Optional[int]): The max sequence chunk length
+                if the data contains sequences.
+            _zero_padded (Optional[bool]): Whether the data in this batch
+                contains sequences AND these sequences are right-zero-padded
+                according to the `_max_seq_len` setting.
+            _is_training (Optional[bool]): Whether this batch is used for
+                training. If False, batch may be used for e.g. action
+                computations (inference).
+        """
         # Possible seq_lens (TxB or BxT) setup.
         self.time_major = kwargs.pop("_time_major", None)
-
+        # Maximum seq len value.
         self.max_seq_len = kwargs.pop("_max_seq_len", None)
+        # Is already right-zero-padded?
         self.zero_padded = kwargs.pop("_zero_padded", False)
+        # Whether this batch is used for training (vs inference).
         self.is_training = kwargs.pop("_is_training", None)
 
         # Call super constructor. 
This will make the actual data accessible From 29803469055590075e789b84e0926a017c309097 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sun, 15 Aug 2021 20:02:34 +0200 Subject: [PATCH 19/45] fixes. --- rllib/env/multi_agent_env.py | 2 +- rllib/models/preprocessors.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index 3fc8b8dd91dc6..59b69098c633b 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -142,7 +142,7 @@ def make_multi_agent(env_name_or_creator): Returns: Type[MultiAgentEnv]: New MultiAgentEnv class to be used as env. The constructor takes a config dict with `num_agents` key - (default=1). The reset of the config dict will be passed on to the + (default=1). The rest of the config dict will be passed on to the underlying single-agent env's constructor. Examples: diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py index 33ef18b5378a3..9f51ccc604b98 100644 --- a/rllib/models/preprocessors.py +++ b/rllib/models/preprocessors.py @@ -214,9 +214,14 @@ def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]: for i in range(len(self._obs_space.spaces)): space = self._obs_space.spaces[i] logger.debug("Creating sub-preprocessor for {}".format(space)) - preprocessor = get_preprocessor(space)(space, self._options) + preprocessor_class = get_preprocessor(space) + if preprocessor_class is not None: + preprocessor = preprocessor_class(space, self._options) + size += preprocessor.size + else: + preprocessor = None + size += int(np.product(space.shape)) self.preprocessors.append(preprocessor) - size += preprocessor.size return (size, ) @override(Preprocessor) @@ -248,9 +253,14 @@ def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]: self.preprocessors = [] for space in self._obs_space.spaces.values(): logger.debug("Creating sub-preprocessor for {}".format(space)) - preprocessor = get_preprocessor(space)(space, self._options) + preprocessor_class = get_preprocessor(space) + if preprocessor_class is not None: + preprocessor = preprocessor_class(space, self._options) + size += preprocessor.size + else: + preprocessor = None + size += int(np.product(space.shape)) self.preprocessors.append(preprocessor) - size += preprocessor.size return (size, ) @override(Preprocessor) From 481ed042a11fa1894a84f50d9063c8e13b951466 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 16 Aug 2021 06:02:44 +0200 Subject: [PATCH 20/45] wip. --- rllib/evaluation/collectors/simple_list_collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 5fbdc59b9b041..63ca693a9950f 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -329,7 +329,7 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: ] else 0) # Store all data as flattened lists, except INFOS and state-out - # lists. Reason: These are monolithic items (infos is a dict that + # lists. These are monolithic items (infos is a dict that # should not be further split, same for state-out items, which could # be custom dicts as well). if col == SampleBatch.INFOS or col.startswith("state_out_"): From a72a7c03ae35f9d77e5389d50bcdf9b125fdeb03 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Wed, 18 Aug 2021 19:21:58 +0200 Subject: [PATCH 21/45] wip. 
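
Most of the loss functions touched below mask out the right-zero-padded
timesteps via `sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)`
before reducing. A framework-free illustration of that mask (NumPy stand-in for
the tf/torch `sequence_mask` helpers; the example values are made up):

```python
import numpy as np

seq_lens = np.array([4, 1, 3], dtype=np.int32)  # B=3 sequences
max_seq_len = int(seq_lens.max())               # T=4 (zero-padded length)

# Boolean [B, T] mask: True for valid steps, False for padding.
mask = np.arange(max_seq_len)[None, :] < seq_lens[:, None]
# [[ True,  True,  True,  True],
#  [ True, False, False, False],
#  [ True,  True,  True, False]]
flat_mask = mask.reshape(-1)  # like tf.reshape(mask, [-1]) in the losses

per_timestep_loss = np.arange(12, dtype=np.float32)
masked_mean = per_timestep_loss[flat_mask].mean()  # mean over valid steps only
```
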
--- rllib/agents/a3c/a3c_tf_policy.py | 4 +- rllib/agents/a3c/a3c_torch_policy.py | 5 +- rllib/agents/dqn/r2d2_tf_policy.py | 7 +- rllib/agents/dqn/r2d2_torch_policy.py | 6 +- rllib/agents/impala/vtrace_tf_policy.py | 10 +-- rllib/agents/impala/vtrace_torch_policy.py | 11 +-- rllib/agents/ppo/appo_tf_policy.py | 10 +-- rllib/agents/ppo/appo_torch_policy.py | 10 +-- rllib/agents/ppo/ppo_tf_policy.py | 4 +- rllib/agents/ppo/ppo_torch_policy.py | 4 +- rllib/agents/sac/rnnsac_torch_policy.py | 4 +- .../tests/test_trajectory_view_api.py | 6 +- rllib/examples/models/modelv3.py | 4 +- rllib/models/modelv2.py | 5 +- rllib/models/tf/attention_net.py | 2 +- rllib/models/tf/recurrent_net.py | 6 +- rllib/policy/dynamic_tf_policy.py | 14 ++-- rllib/policy/policy.py | 4 +- rllib/policy/rnn_sequencing.py | 19 +++-- rllib/policy/sample_batch.py | 83 ++++++++++--------- rllib/policy/tests/test_sample_batch.py | 45 +++++----- rllib/policy/tf_policy.py | 4 +- rllib/policy/torch_policy.py | 8 +- rllib/tests/test_lstm.py | 13 +-- 24 files changed, 154 insertions(+), 134 deletions(-) diff --git a/rllib/agents/a3c/a3c_tf_policy.py b/rllib/agents/a3c/a3c_tf_policy.py index 321cad2cf278e..cbc5bbbd797d6 100644 --- a/rllib/agents/a3c/a3c_tf_policy.py +++ b/rllib/agents/a3c/a3c_tf_policy.py @@ -75,8 +75,8 @@ def actor_critic_loss(policy: Policy, model: ModelV2, model_out, _ = model.from_batch(train_batch) action_dist = dist_class(model_out, model) if policy.is_recurrent(): - max_seq_len = tf.reduce_max(train_batch["seq_lens"]) - mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) + max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) + mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = tf.reshape(mask, [-1]) else: mask = tf.ones_like(train_batch[SampleBatch.REWARDS]) diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index 8850a0767d1a3..9a876e81d3b00 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -41,8 +41,9 @@ def actor_critic_loss(policy: Policy, model: ModelV2, values = model.value_function() if policy.is_recurrent(): - max_seq_len = torch.max(train_batch["seq_lens"]) - mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len) + max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) + mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], + max_seq_len) valid_mask = torch.reshape(mask_orig, [-1]) else: valid_mask = torch.ones_like(values, dtype=torch.bool) diff --git a/rllib/agents/dqn/r2d2_tf_policy.py b/rllib/agents/dqn/r2d2_tf_policy.py index 338e9c0b48585..a38bae28c6841 100644 --- a/rllib/agents/dqn/r2d2_tf_policy.py +++ b/rllib/agents/dqn/r2d2_tf_policy.py @@ -81,7 +81,7 @@ def r2d2_loss(policy: Policy, model, _, model, train_batch, state_batches=state_batches, - seq_lens=train_batch.get("seq_lens"), + seq_lens=train_batch.get(SampleBatch.SEQ_LENS), explore=False, is_training=True) @@ -91,7 +91,7 @@ def r2d2_loss(policy: Policy, model, _, policy.target_model, train_batch, state_batches=state_batches, - seq_lens=train_batch.get("seq_lens"), + seq_lens=train_batch.get(SampleBatch.SEQ_LENS), explore=False, is_training=True) @@ -140,7 +140,8 @@ def r2d2_loss(policy: Policy, model, _, config["gamma"] ** config["n_step"] * q_target_best_masked_tp1 # Seq-mask all loss-related terms. 
- seq_mask = tf.sequence_mask(train_batch["seq_lens"], T)[:, :-1] + seq_mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], + T)[:, :-1] # Mask away also the burn-in sequence at the beginning. burn_in = policy.config["burn_in"] # Making sure, this works for both static graph and eager. diff --git a/rllib/agents/dqn/r2d2_torch_policy.py b/rllib/agents/dqn/r2d2_torch_policy.py index d854b37da8685..8ec00a4a6654f 100644 --- a/rllib/agents/dqn/r2d2_torch_policy.py +++ b/rllib/agents/dqn/r2d2_torch_policy.py @@ -90,7 +90,7 @@ def r2d2_loss(policy: Policy, model, _, model, train_batch, state_batches=state_batches, - seq_lens=train_batch.get("seq_lens"), + seq_lens=train_batch.get(SampleBatch.SEQ_LENS), explore=False, is_training=True) @@ -100,7 +100,7 @@ def r2d2_loss(policy: Policy, model, _, target_model, train_batch, state_batches=state_batches, - seq_lens=train_batch.get("seq_lens"), + seq_lens=train_batch.get(SampleBatch.SEQ_LENS), explore=False, is_training=True) @@ -148,7 +148,7 @@ def r2d2_loss(policy: Policy, model, _, config["gamma"] ** config["n_step"] * q_target_best_masked_tp1 # Seq-mask all loss-related terms. - seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1] + seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1] # Mask away also the burn-in sequence at the beginning. burn_in = policy.config["burn_in"] if burn_in > 0 and burn_in < T: diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py index e231aa40ba150..6d17f8429d74f 100644 --- a/rllib/agents/impala/vtrace_tf_policy.py +++ b/rllib/agents/impala/vtrace_tf_policy.py @@ -167,8 +167,8 @@ def build_vtrace_loss(policy, model, dist_class, train_batch): output_hidden_shape = 1 def make_time_major(*args, **kw): - return _make_time_major(policy, train_batch.get("seq_lens"), *args, - **kw) + return _make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS), + *args, **kw) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] @@ -181,8 +181,8 @@ def make_time_major(*args, **kw): values = model.value_function() if policy.is_recurrent(): - max_seq_len = tf.reduce_max(train_batch["seq_lens"]) - mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) + max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) + mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = tf.reshape(mask, [-1]) else: mask = tf.ones_like(rewards) @@ -223,7 +223,7 @@ def make_time_major(*args, **kw): def stats(policy, train_batch): values_batched = _make_time_major( policy, - train_batch.get("seq_lens"), + train_batch.get(SampleBatch.SEQ_LENS), policy.model.value_function(), drop_last=policy.config["vtrace"]) diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index 6c7ca907e60d7..7934dfdaa6723 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -125,8 +125,8 @@ def build_vtrace_loss(policy, model, dist_class, train_batch): output_hidden_shape = 1 def _make_time_major(*args, **kw): - return make_time_major(policy, train_batch.get("seq_lens"), *args, - **kw) + return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS), + *args, **kw) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] @@ -145,8 +145,9 @@ def _make_time_major(*args, **kw): values = model.value_function() if policy.is_recurrent(): - max_seq_len = torch.max(train_batch["seq_lens"]) - mask_orig = 
sequence_mask(train_batch["seq_lens"], max_seq_len) + max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) + mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], + max_seq_len) mask = torch.reshape(mask_orig, [-1]) else: mask = torch.ones_like(rewards) @@ -186,7 +187,7 @@ def _make_time_major(*args, **kw): policy.loss = loss values_batched = make_time_major( policy, - train_batch.get("seq_lens"), + train_batch.get(SampleBatch.SEQ_LENS), values, drop_last=policy.config["vtrace"]) policy._vf_explained_var = explained_variance( diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index 7b2156a98b1f9..bce4a05c87fed 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -114,8 +114,8 @@ def appo_surrogate_loss( # TODO: (sven) deprecate this when trajectory view API gets activated. def make_time_major(*args, **kw): - return _make_time_major(policy, train_batch.get("seq_lens"), *args, - **kw) + return _make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS), + *args, **kw) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] @@ -131,8 +131,8 @@ def make_time_major(*args, **kw): policy.target_model_vars = policy.target_model.variables() if policy.is_recurrent(): - max_seq_len = tf.reduce_max(train_batch["seq_lens"]) - mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) + max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) + mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = tf.reshape(mask, [-1]) mask = make_time_major(mask, drop_last=policy.config["vtrace"]) @@ -282,7 +282,7 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: """ values_batched = _make_time_major( policy, - train_batch.get("seq_lens"), + train_batch.get(SampleBatch.SEQ_LENS), policy.model.value_function(), drop_last=policy.config["vtrace"]) diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py index 7bf1bc0cc108d..625e64226a221 100644 --- a/rllib/agents/ppo/appo_torch_policy.py +++ b/rllib/agents/ppo/appo_torch_policy.py @@ -69,9 +69,9 @@ def appo_surrogate_loss(policy: Policy, model: ModelV2, is_multidiscrete = False output_hidden_shape = 1 - def _make_time_major(*args, **kw): - return make_time_major(policy, train_batch.get("seq_lens"), *args, - **kw) + def _make_time_major(*args, **kwargs): + return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS), + *args, **kwargs) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] @@ -85,8 +85,8 @@ def _make_time_major(*args, **kw): values_time_major = _make_time_major(values) if policy.is_recurrent(): - max_seq_len = torch.max(train_batch["seq_lens"]) - mask = sequence_mask(train_batch["seq_lens"], max_seq_len) + max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) + mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = torch.reshape(mask, [-1]) mask = _make_time_major(mask, drop_last=policy.config["vtrace"]) num_valid = torch.sum(mask) diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index f3b6bcd0e5932..010f4cd0f37c5 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -60,10 +60,10 @@ def ppo_surrogate_loss( # Derive max_seq_len from the data itself, not from the seq_lens # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still # 0-padded up to T=5 (as it's the case for attention nets). 
- B = tf.shape(train_batch["seq_lens"])[0] + B = tf.shape(train_batch[SampleBatch.SEQ_LENS])[0] max_seq_len = tf.shape(logits)[0] // B - mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) + mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = tf.reshape(mask, [-1]) def reduce_mean_valid(t): diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index b1b225251e937..44f3b5b5f363c 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -47,10 +47,10 @@ def ppo_surrogate_loss( # RNN case: Mask away 0-padded chunks at end of time axis. if state: - B = len(train_batch["seq_lens"]) + B = len(train_batch[SampleBatch.SEQ_LENS]) max_seq_len = logits.shape[0] // B mask = sequence_mask( - train_batch["seq_lens"], + train_batch[SampleBatch.SEQ_LENS], max_seq_len, time_major=model.is_time_major()) mask = torch.reshape(mask, [-1]) diff --git a/rllib/agents/sac/rnnsac_torch_policy.py b/rllib/agents/sac/rnnsac_torch_policy.py index 3498df450549b..c0d223c0a4766 100644 --- a/rllib/agents/sac/rnnsac_torch_policy.py +++ b/rllib/agents/sac/rnnsac_torch_policy.py @@ -213,7 +213,7 @@ def actor_critic_loss( state_batches.append(train_batch["state_in_{}".format(i)]) i += 1 assert state_batches - seq_lens = train_batch.get("seq_lens") + seq_lens = train_batch.get(SampleBatch.SEQ_LENS) model_out_t, state_in_t = model({ "obs": train_batch[SampleBatch.CUR_OBS], @@ -343,7 +343,7 @@ def actor_critic_loss( # BURNIN # B = state_batches[0].shape[0] T = q_t_selected.shape[0] // B - seq_mask = sequence_mask(train_batch["seq_lens"], T) + seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T) # Mask away also the burn-in sequence at the beginning. burn_in = policy.config["burn_in"] if burn_in > 0 and burn_in < T: diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index 57392184f7270..e82be6eaca6f4 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -26,10 +26,10 @@ class MyCallbacks(DefaultCallbacks): @override(DefaultCallbacks) def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs): assert train_batch.count == 201 - assert sum(train_batch["seq_lens"]) == 201 + assert sum(train_batch[SampleBatch.SEQ_LENS]) == 201 for k, v in train_batch.items(): if k == "state_in_0": - assert len(v) == len(train_batch["seq_lens"]) + assert len(v) == len(train_batch[SampleBatch.SEQ_LENS]) else: assert len(v) == 201 current = None @@ -403,7 +403,7 @@ def analyze_rnn_batch(batch, max_seq_len): # Check after seq-len 0-padding. 
cursor = 0 - for i, seq_len in enumerate(batch["seq_lens"]): + for i, seq_len in enumerate(batch[SampleBatch.SEQ_LENS]): state_in_0 = batch["state_in_0"][i] state_in_1 = batch["state_in_1"][i] for j in range(seq_len): diff --git a/rllib/examples/models/modelv3.py b/rllib/examples/models/modelv3.py index 8d3d4f3468872..6ad38ab6989a8 100644 --- a/rllib/examples/models/modelv3.py +++ b/rllib/examples/models/modelv3.py @@ -36,11 +36,11 @@ def __init__(self, def call(self, sample_batch): dense_out = self.dense(sample_batch["obs"]) - B = tf.shape(sample_batch["seq_lens"])[0] + B = tf.shape(sample_batch[SampleBatch.SEQ_LENS])[0] lstm_in = tf.reshape(dense_out, [B, -1, dense_out.shape.as_list()[1]]) lstm_out, h, c = self.lstm( inputs=lstm_in, - mask=tf.sequence_mask(sample_batch["seq_lens"]), + mask=tf.sequence_mask(sample_batch[SampleBatch.SEQ_LENS]), initial_state=[ sample_batch["state_in_0"], sample_batch["state_in_1"] ], diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 5921c8b2c2662..ecc5d0cc6ac26 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -206,7 +206,7 @@ def __call__( restored = input_dict.copy(shallow=True) # Backward compatibility. if seq_lens is None: - seq_lens = input_dict.get("seq_lens") + seq_lens = input_dict.get(SampleBatch.SEQ_LENS) if not state: state = [] i = 0 @@ -260,7 +260,8 @@ def from_batch(self, train_batch: SampleBatch, while "state_in_{}".format(i) in input_dict: states.append(input_dict["state_in_{}".format(i)]) i += 1 - ret = self.__call__(input_dict, states, input_dict.get("seq_lens")) + ret = self.__call__(input_dict, states, + input_dict.get(SampleBatch.SEQ_LENS)) return ret def import_from_h5(self, h5_file: str) -> None: diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index 882489ebe386b..98e97561ead74 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -790,7 +790,7 @@ def __init__( def call(self, input_dict: SampleBatch) -> \ (TensorType, List[TensorType], Dict[str, TensorType]): - assert input_dict["seq_lens"] is not None + assert input_dict[SampleBatch.SEQ_LENS] is not None # Push obs through "unwrapped" net's `forward()` first. wrapped_out, _, _ = self.wrapped_keras_model(input_dict) diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index 29184ebb18336..8b9546021c0ac 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -365,7 +365,7 @@ def __init__( def call(self, input_dict: SampleBatch) -> \ (TensorType, List[TensorType], Dict[str, TensorType]): - assert input_dict.get("seq_lens") is not None + assert input_dict.get(SampleBatch.SEQ_LENS) is not None # Push obs through underlying (wrapped) model first. 
wrapped_out, _, _ = self.wrapped_keras_model(input_dict) @@ -387,11 +387,11 @@ def call(self, input_dict: SampleBatch) -> \ wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1) max_seq_len = tf.shape(wrapped_out)[0] // tf.shape( - input_dict["seq_lens"])[0] + input_dict[SampleBatch.SEQ_LENS])[0] wrapped_out_plus_time_dim = add_time_dimension( wrapped_out, max_seq_len=max_seq_len, framework="tf") model_out, value_out, h, c = self._rnn_model([ - wrapped_out_plus_time_dim, input_dict["seq_lens"], + wrapped_out_plus_time_dim, input_dict[SampleBatch.SEQ_LENS], input_dict["state_in_0"], input_dict["state_in_1"] ]) model_out_no_time_dim = tf.reshape( diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 8f19b3d761c5e..75cf4b60ebbe8 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -189,7 +189,7 @@ def __init__( ] # Placeholder for RNN time-chunk valid lengths. if self._state_inputs: - self._seq_lens = existing_inputs["seq_lens"] + self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS] # Create new input placeholders. else: self._state_inputs = [ @@ -404,7 +404,7 @@ def copy(self, ("state_in_{}".format(i), existing_inputs[len(self._loss_input_dict_no_rnn) + i])) if rnn_inputs: - rnn_inputs.append(("seq_lens", existing_inputs[-1])) + rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1])) input_dict = OrderedDict( [("is_exploring", self._is_exploring), ("timestep", self._timestep)] + @@ -613,8 +613,10 @@ def _initialize_loss_from_dummy_batch( dict(self._input_dict, **self._loss_input_dict)) if self._state_inputs: - train_batch["seq_lens"] = self._seq_lens - self._loss_input_dict.update({"seq_lens": train_batch["seq_lens"]}) + train_batch[SampleBatch.SEQ_LENS] = self._seq_lens + self._loss_input_dict.update({ + SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS] + }) self._loss_input_dict.update({k: v for k, v in train_batch.items()}) @@ -632,8 +634,8 @@ def _initialize_loss_from_dummy_batch( TFPolicy._initialize_loss(self, loss, [ (k, v) for k, v in train_batch.items() if k in all_accessed_keys - ] + ([("seq_lens", train_batch["seq_lens"])] - if "seq_lens" in train_batch else [])) + ] + ([(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])] + if SampleBatch.SEQ_LENS in train_batch else [])) if "is_training" in self._loss_input_dict: del self._loss_input_dict["is_training"] diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index aa6008acae3d3..d8db60b4173e0 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -780,13 +780,13 @@ def _initialize_loss_from_dummy_batch( i += 1 seq_len = sample_batch_size // B seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32) - postprocessed_batch["seq_lens"] = seq_lens + postprocessed_batch[SampleBatch.SEQ_LENS] = seq_lens # Switch on lazy to-tensor conversion on `postprocessed_batch`. train_batch = self._lazy_tensor_dict(postprocessed_batch) # Calling loss, so set `is_training` to True. train_batch.is_training = True if seq_lens is not None: - train_batch["seq_lens"] = seq_lens + train_batch[SampleBatch.SEQ_LENS] = seq_lens train_batch.count = self._dummy_batch.count # Call the loss function, if it exists. 
if self._loss is not None: diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 22154d7eb78ec..de5b2529df684 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -80,8 +80,8 @@ def pad_batch_to_sequences_of_same_size( if "state_in_0" in batch or "state_out_0" in batch: # Check, whether the state inputs have already been reduced to their # init values at the beginning of each max_seq_len chunk. - if batch.get("seq_lens") is not None and \ - len(batch["state_in_0"]) == len(batch["seq_lens"]): + if batch.get(SampleBatch.SEQ_LENS) is not None and \ + len(batch["state_in_0"]) == len(batch[SampleBatch.SEQ_LENS]): states_already_reduced_to_init = True # RNN (or single timestep state-in): Set the max dynamically. @@ -109,7 +109,8 @@ def pad_batch_to_sequences_of_same_size( if k.startswith("state_in_"): state_keys.append(k) elif not feature_keys and not k.startswith("state_out_") and \ - k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray): + k not in ["infos", SampleBatch.SEQ_LENS] and \ + isinstance(v, np.ndarray): feature_keys_.append(k) feature_sequences, initial_states, seq_lens = \ @@ -119,7 +120,7 @@ def pad_batch_to_sequences_of_same_size( episode_ids=batch.get(SampleBatch.EPS_ID), unroll_ids=batch.get(SampleBatch.UNROLL_ID), agent_indices=batch.get(SampleBatch.AGENT_INDEX), - seq_lens=batch.get("seq_lens"), + seq_lens=batch.get(SampleBatch.SEQ_LENS), max_seq_len=max_seq_len, dynamic_max=dynamic_max, states_already_reduced_to_init=states_already_reduced_to_init, @@ -129,7 +130,7 @@ def pad_batch_to_sequences_of_same_size( batch[k] = feature_sequences[i] for i, k in enumerate(state_keys): batch[k] = initial_states[i] - batch["seq_lens"] = np.array(seq_lens) + batch[SampleBatch.SEQ_LENS] = np.array(seq_lens) if log_once("rnn_ma_feed_dict"): logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format( @@ -330,7 +331,7 @@ def timeslice_along_seq_lens_with_overlap( Args: sample_batch (SampleBatch): The SampleBatch to timeslice. seq_lens (Optional[List[int]]): An optional list of seq_lens to slice - at. If None, use `sample_batch["seq_lens"]`. + at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`. zero_pad_max_seq_len (int): If >0, already zero-pad the resulting slices up to this length. NOTE: This max-len will include the additional timesteps gained via setting pre_overlap (see Example). @@ -360,7 +361,7 @@ def timeslice_along_seq_lens_with_overlap( # count (makes sure each slice has exactly length 10). """ if seq_lens is None: - seq_lens = sample_batch.get("seq_lens") + seq_lens = sample_batch.get(SampleBatch.SEQ_LENS) assert seq_lens is not None and len(seq_lens) > 0, \ "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!" # Generate n slices based on seq_lens. @@ -398,12 +399,12 @@ def timeslice_along_seq_lens_with_overlap( shape=(zero_length, ) + v.shape[1:], dtype=v.dtype), v[data_begin:end] ]) - for k, v in sample_batch.items() if k != "seq_lens" + for k, v in sample_batch.items() if k != SampleBatch.SEQ_LENS } else: data = { k: v[begin:end] - for k, v in sample_batch.items() if k != "seq_lens" + for k, v in sample_batch.items() if k != SampleBatch.SEQ_LENS } if zero_init_states_: diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 03e015ccf84b1..86a880b9ad1d4 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -81,13 +81,14 @@ def __init__(self, *args, **kwargs): self.get_interceptor = None # Clear out None seq-lens. 
- seq_lens_ = self.get("seq_lens") + seq_lens_ = self.get(SampleBatch.SEQ_LENS) if seq_lens_ is None or \ (isinstance(seq_lens_, list) and len(seq_lens_) == 0): - self.pop("seq_lens", None) + self.pop(SampleBatch.SEQ_LENS, None) # Numpyfy seq_lens if list. elif isinstance(seq_lens_, list): - self["seq_lens"] = seq_lens_ = np.array(seq_lens_, dtype=np.int32) + self[SampleBatch.SEQ_LENS] = seq_lens_ = \ + np.array(seq_lens_, dtype=np.int32) if self.max_seq_len is None and seq_lens_ is not None and \ not (tf and tf.is_tensor(seq_lens_)) and \ @@ -98,7 +99,7 @@ def __init__(self, *args, **kwargs): self.is_training = self.pop("is_training", False) lengths = [] - copy_ = {k: v for k, v in self.items() if k != "seq_lens"} + copy_ = {k: v for k, v in self.items() if k != SampleBatch.SEQ_LENS} for k, v in copy_.items(): assert isinstance(k, str), self @@ -118,10 +119,10 @@ def __init__(self, *args, **kwargs): if len_: lengths.append(len_) - if self.get("seq_lens") is not None and \ - not (tf and tf.is_tensor(self["seq_lens"])) and \ - len(self["seq_lens"]) > 0: - self.count = sum(self["seq_lens"]) + if self.get(SampleBatch.SEQ_LENS) is not None and \ + not (tf and tf.is_tensor(self[SampleBatch.SEQ_LENS])) and \ + len(self[SampleBatch.SEQ_LENS]) > 0: + self.count = sum(self[SampleBatch.SEQ_LENS]) else: self.count = lengths[0] if lengths else 0 @@ -173,8 +174,8 @@ def concat_samples( if zero_padded: assert s.max_seq_len == max_seq_len concat_samples.append(s) - if s.get("seq_lens") is not None: - concatd_seq_lens.extend(s["seq_lens"]) + if s.get(SampleBatch.SEQ_LENS) is not None: + concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS]) # If we don't have any samples (no or only empty SampleBatches), # return an empty SampleBatch here. @@ -273,13 +274,14 @@ def rows(self) -> Dict[str, TensorType]: """ # Do we add seq_lens=[1] to each row? - seq_lens = None if self.get("seq_lens") is None else np.array([1]) + seq_lens = None if self.get( + SampleBatch.SEQ_LENS) is None else np.array([1]) self_as_dict = {k: v for k, v in self.items()} for i in range(self.count): yield tree.map_structure_with_path( - lambda p, v: v[i] if p[0] != "seq_lens" else seq_lens, + lambda p, v: v[i] if p[0] != self.SEQ_LENS else seq_lens, self_as_dict, ) @@ -314,7 +316,7 @@ def shuffle(self) -> None: SampleBatch: This very (now shuffled) SampleBatch. Raises: - ValueError: If self["seq_lens"] is defined. + ValueError: If self[SampleBatch.SEQ_LENS] is defined. Examples: >>> batch = SampleBatch({"a": [1, 2, 3, 4]}) @@ -324,7 +326,7 @@ def shuffle(self) -> None: # Shuffling the data when we have `seq_lens` defined is probably # a bad idea! - if self.get("seq_lens") is not None: + if self.get(SampleBatch.SEQ_LENS) is not None: raise ValueError( "SampleBatch.shuffle not possible when your data has " "`seq_lens` defined!") @@ -406,7 +408,8 @@ def slice(self, start: int, end: int, state_start=None, SampleBatch: A new SampleBatch, which has a slice of this batch's data. 
""" - if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0: + if self.get(SampleBatch.SEQ_LENS) is not None and \ + len(self[SampleBatch.SEQ_LENS]) > 0: if start < 0: data = { k: np.concatenate([ @@ -414,14 +417,14 @@ def slice(self, start: int, end: int, state_start=None, shape=(-start, ) + v.shape[1:], dtype=v.dtype), v[0:end] ]) - for k, v in self.items() - if k != "seq_lens" and not k.startswith("state_in_") + for k, v in self.items() if k != SampleBatch.SEQ_LENS + and not k.startswith("state_in_") } else: data = { k: v[start:end] - for k, v in self.items() - if k != "seq_lens" and not k.startswith("state_in_") + for k, v in self.items() if k != SampleBatch.SEQ_LENS + and not k.startswith("state_in_") } if state_start is not None: assert state_end is not None @@ -431,7 +434,8 @@ def slice(self, start: int, end: int, state_start=None, data[state_key] = self[state_key][state_start:state_end] state_idx += 1 state_key = "state_in_{}".format(state_idx) - seq_lens = list(self["seq_lens"][state_start:state_end]) + seq_lens = list( + self[SampleBatch.SEQ_LENS][state_start:state_end]) # Adjust seq_lens if necessary. data_len = len(data[next(iter(data))]) if sum(seq_lens) != data_len: @@ -442,7 +446,7 @@ def slice(self, start: int, end: int, state_start=None, count = 0 state_start = None seq_lens = None - for i, seq_len in enumerate(self["seq_lens"]): + for i, seq_len in enumerate(self[SampleBatch.SEQ_LENS]): count += seq_len if count >= end: state_idx = 0 @@ -454,9 +458,10 @@ def slice(self, start: int, end: int, state_start=None, 1] state_idx += 1 state_key = "state_in_{}".format(state_idx) - seq_lens = list(self["seq_lens"][state_start:i]) + [ - seq_len - (count - end) - ] + seq_lens = list( + self[SampleBatch.SEQ_LENS][state_start:i]) + [ + seq_len - (count - end) + ] if start < 0: seq_lens[0] += -start diff = sum(seq_lens) - (end - start) @@ -553,7 +558,7 @@ def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True): SampleBatch: This very (now right-zero-padded) SampleBatch. Raises: - ValueError: If self.seq_lens is None (not defined). + ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined). Examples: >>> batch = SampleBatch({"a": [1, 2, 3], "seq_lens": [1, 2]}) @@ -568,7 +573,7 @@ def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True): "state_in_0": [1.0, 3.0], # <- all state-ins remain as-is "seq_lens": [1, 2]} """ - seq_lens = self.get("seq_lens") + seq_lens = self.get(SampleBatch.SEQ_LENS) if seq_lens is None: raise ValueError( "Cannot right-zero-pad SampleBatch if no `seq_lens` field " @@ -579,7 +584,7 @@ def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True): def _zero_pad_in_place(path, value): # Skip "state_in_..." columns and "seq_lens". if (exclude_states is True and path[0].startswith("state_in_")) \ - or path[0] == "seq_lens": + or path[0] == SampleBatch.SEQ_LENS: return # Generate zero-filled primer of len=max_seq_len. if value.dtype == np.object or value.dtype.type is np.str_: @@ -590,7 +595,7 @@ def _zero_pad_in_place(path, value): (length, ) + np.shape(value)[1:], dtype=value.dtype) # Fill primer with data. 
f_pad_base = f_base = 0 - for len_ in self["seq_lens"]: + for len_ in self[SampleBatch.SEQ_LENS]: f_pad[f_pad_base:f_pad_base + len_] = value[f_base:f_base + len_] f_pad_base += max_seq_len @@ -778,10 +783,10 @@ def set_get_interceptor(self, fn): def __repr__(self): keys = list(self.keys()) - if self.get("seq_lens") is None: + if self.get(SampleBatch.SEQ_LENS) is None: return f"SampleBatch({self.count}: {keys})" else: - keys.remove("seq_lens") + keys.remove(SampleBatch.SEQ_LENS) return f"SampleBatch({self.count} " \ f"(seqs={len(self['seq_lens'])}): {keys})" @@ -806,15 +811,16 @@ def _slice(self, slice_: slice): stop = slice_.stop or len(self) assert start >= 0 and stop >= 0 and slice_.step in [1, None] - if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0: + if self.get(SampleBatch.SEQ_LENS) is not None and \ + len(self[SampleBatch.SEQ_LENS]) > 0: # Build our slice-map, if not done already. if not self._slice_map: sum_ = 0 - for i, l in enumerate(self["seq_lens"]): + for i, l in enumerate(self[SampleBatch.SEQ_LENS]): for _ in range(l): self._slice_map.append((i, sum_)) sum_ += l - self._slice_map.append((len(self["seq_lens"]), sum_)) + self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_)) start_seq_len, start = self._slice_map[start] stop_seq_len, stop = self._slice_map[stop] @@ -823,7 +829,7 @@ def _slice(self, slice_: slice): stop = stop_seq_len * self.max_seq_len def map_(path, value): - if path[0] != "seq_lens" and not path[0].startswith( + if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith( "state_in_"): return value[start:stop] else: @@ -848,8 +854,9 @@ def map_(path, value): def _get_slice_indices(self, slice_size): data_slices = [] data_slices_states = [] - if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0: - assert np.all(self["seq_lens"] < slice_size), \ + if self.get(SampleBatch.SEQ_LENS) is not None and len( + self[SampleBatch.SEQ_LENS]) > 0: + assert np.all(self[SampleBatch.SEQ_LENS] < slice_size), \ "ERROR: `slice_size` must be larger than the max. seq-len " \ "in the batch!" 
start_pos = 0 @@ -857,8 +864,8 @@ def _get_slice_indices(self, slice_size): actual_slice_idx = 0 start_idx = 0 idx = 0 - while idx < len(self["seq_lens"]): - seq_len = self["seq_lens"][idx] + while idx < len(self[SampleBatch.SEQ_LENS]): + seq_len = self[SampleBatch.SEQ_LENS][idx] current_slize_size += seq_len actual_slice_idx += seq_len if not self.zero_padded else \ self.max_seq_len diff --git a/rllib/policy/tests/test_sample_batch.py b/rllib/policy/tests/test_sample_batch.py index cc158fab1a97c..5ebfd3c6dca1c 100644 --- a/rllib/policy/tests/test_sample_batch.py +++ b/rllib/policy/tests/test_sample_batch.py @@ -22,11 +22,12 @@ def test_len_and_size_bytes(self): "b": { "c": np.array([4, 5, 6]) }, - "seq_lens": [1, 2], + SampleBatch.SEQ_LENS: [1, 2], }) check(len(s1), 3) - check(s1.size_bytes(), - s1["a"].nbytes + s1["b"]["c"].nbytes + s1["seq_lens"].nbytes) + check( + s1.size_bytes(), s1["a"].nbytes + s1["b"]["c"].nbytes + + s1[SampleBatch.SEQ_LENS].nbytes) def test_dict_properties_of_sample_batches(self): base_dict = { @@ -60,7 +61,7 @@ def test_right_zero_padding(self): "b": { "c": np.array([4, 5, 6]) }, - "seq_lens": [1, 2], + SampleBatch.SEQ_LENS: [1, 2], }) s1.right_zero_pad(max_seq_len=5) check( @@ -69,7 +70,7 @@ def test_right_zero_padding(self): "b": { "c": [4, 0, 0, 0, 0, 5, 6, 0, 0, 0] }, - "seq_lens": [1, 2] + SampleBatch.SEQ_LENS: [1, 2] }) def test_concat(self): @@ -100,7 +101,7 @@ def test_rows(self): "b": { "c": np.array([[4, 4], [5, 5], [6, 6]]) }, - "seq_lens": np.array([1, 2]), + SampleBatch.SEQ_LENS: np.array([1, 2]), }) check( next(s1.rows()), @@ -109,7 +110,7 @@ def test_rows(self): "b": { "c": [4, 4] }, - "seq_lens": [1] + SampleBatch.SEQ_LENS: [1] }, ) @@ -188,7 +189,7 @@ def test_slicing(self): "b": { "c": np.array([4, 5, 6, 5, 6, 7]) }, - "seq_lens": [2, 3, 1], + SampleBatch.SEQ_LENS: [2, 3, 1], "state_in_0": [1.0, 3.0, 4.0], }) # We would expect a=[1, 2, 3] now, but due to the sequence @@ -199,7 +200,7 @@ def test_slicing(self): "b": { "c": [4, 5] }, - "seq_lens": [2], + SampleBatch.SEQ_LENS: [2], "state_in_0": [1.0], }) # Split exactly at a seq-len boundary. 
@@ -209,7 +210,7 @@ def test_slicing(self): "b": { "c": [4, 5, 6, 5, 6] }, - "seq_lens": [2, 3], + SampleBatch.SEQ_LENS: [2, 3], "state_in_0": [1.0, 3.0], }) check( @@ -218,7 +219,7 @@ def test_slicing(self): "b": { "c": [4, 5, 6, 5, 6, 7] }, - "seq_lens": [2, 3, 1], + SampleBatch.SEQ_LENS: [2, 3, 1], "state_in_0": [1.0, 3.0, 4.0], }) @@ -228,31 +229,35 @@ def test_copy(self): "b": { "c": np.array([4, 5, 6, 5, 6, 7]) }, - "seq_lens": [2, 3, 1], + SampleBatch.SEQ_LENS: [2, 3, 1], "state_in_0": [1.0, 3.0, 4.0], }) s_copy = s.copy(shallow=False) s_copy["a"][0] = 100 s_copy["b"]["c"][0] = 200 - s_copy["seq_lens"][0] = 3 - s_copy["seq_lens"][1] = 2 + s_copy[SampleBatch.SEQ_LENS][0] = 3 + s_copy[SampleBatch.SEQ_LENS][1] = 2 s_copy["state_in_0"][0] = 400.0 self.assertNotEqual(s["a"][0], s_copy["a"][0]) self.assertNotEqual(s["b"]["c"][0], s_copy["b"]["c"][0]) - self.assertNotEqual(s["seq_lens"][0], s_copy["seq_lens"][0]) - self.assertNotEqual(s["seq_lens"][1], s_copy["seq_lens"][1]) + self.assertNotEqual(s[SampleBatch.SEQ_LENS][0], + s_copy[SampleBatch.SEQ_LENS][0]) + self.assertNotEqual(s[SampleBatch.SEQ_LENS][1], + s_copy[SampleBatch.SEQ_LENS][1]) self.assertNotEqual(s["state_in_0"][0], s_copy["state_in_0"][0]) s_copy = s.copy(shallow=True) s_copy["a"][0] = 100 s_copy["b"]["c"][0] = 200 - s_copy["seq_lens"][0] = 3 - s_copy["seq_lens"][1] = 2 + s_copy[SampleBatch.SEQ_LENS][0] = 3 + s_copy[SampleBatch.SEQ_LENS][1] = 2 s_copy["state_in_0"][0] = 400.0 self.assertEqual(s["a"][0], s_copy["a"][0]) self.assertEqual(s["b"]["c"][0], s_copy["b"]["c"][0]) - self.assertEqual(s["seq_lens"][0], s_copy["seq_lens"][0]) - self.assertEqual(s["seq_lens"][1], s_copy["seq_lens"][1]) + self.assertEqual(s[SampleBatch.SEQ_LENS][0], + s_copy[SampleBatch.SEQ_LENS][0]) + self.assertEqual(s[SampleBatch.SEQ_LENS][1], + s_copy[SampleBatch.SEQ_LENS][1]) self.assertEqual(s["state_in_0"][0], s_copy["state_in_0"][0]) diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 2d9390b2ef3cd..2266fe4abf467 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -818,7 +818,7 @@ def _build_signature_def(self): tf1.saved_model.utils.build_tensor_info(self._obs_input) if self._seq_lens is not None: - input_signature["seq_lens"] = \ + input_signature[SampleBatch.SEQ_LENS] = \ tf1.saved_model.utils.build_tensor_info(self._seq_lens) if self._prev_action_input is not None: input_signature["prev_action"] = \ @@ -1033,7 +1033,7 @@ def _get_loss_inputs_dict(self, train_batch: SampleBatch, shuffle: bool): for key in state_keys: feed_dict[self._loss_input_dict[key]] = train_batch[key] if state_keys: - feed_dict[self._seq_lens] = train_batch["seq_lens"] + feed_dict[self._seq_lens] = train_batch[SampleBatch.SEQ_LENS] return feed_dict diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 7aca552fcb9e2..a61971203e035 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -859,7 +859,7 @@ def export_model(self, export_dir: str, # returned empty internal states list). 
if "state_in_0" not in self._dummy_batch: self._dummy_batch["state_in_0"] = \ - self._dummy_batch["seq_lens"] = np.array([1.0]) + self._dummy_batch[SampleBatch.SEQ_LENS] = np.array([1.0]) state_ins = [] i = 0 @@ -874,7 +874,7 @@ def export_model(self, export_dir: str, if not os.path.exists(export_dir): os.makedirs(export_dir) - seq_lens = self._dummy_batch["seq_lens"] + seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS] if onnx: file_name = os.path.join(export_dir, "model.onnx") torch.onnx.export( @@ -884,14 +884,14 @@ def export_model(self, export_dir: str, opset_version=onnx, do_constant_folding=True, input_names=list(dummy_inputs.keys()) + - ["state_ins", "seq_lens"], + ["state_ins", SampleBatch.SEQ_LENS], output_names=["output", "state_outs"], dynamic_axes={ k: { 0: "batch_size" } for k in list(dummy_inputs.keys()) + - ["state_ins", "seq_lens"] + ["state_ins", SampleBatch.SEQ_LENS] }) else: traced = torch.jit.trace(self.model, diff --git a/rllib/tests/test_lstm.py b/rllib/tests/test_lstm.py index 36c0c6a6e4b34..30d6012253bd0 100644 --- a/rllib/tests/test_lstm.py +++ b/rllib/tests/test_lstm.py @@ -8,6 +8,7 @@ from ray.rllib.examples.models.rnn_spy_model import RNNSpyModel from ray.rllib.models import ModelCatalog from ray.rllib.policy.rnn_sequencing import chop_into_sequences +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.test_utils import check from ray.tune.registry import register_env @@ -138,7 +139,7 @@ def test_simple_optimizer_sequencing(self): self.assertEqual( batch0["sequences"].tolist(), [[[0], [1], [2], [3]], [[4], [5], [6], [7]], [[8], [9], [0], [0]]]) - self.assertEqual(batch0["seq_lens"].tolist(), [4, 4, 2]) + self.assertEqual(batch0[SampleBatch.SEQ_LENS].tolist(), [4, 4, 2]) self.assertEqual(batch0["state_in"][0][0].tolist(), [0, 0, 0]) self.assertEqual(batch0["state_in"][1][0].tolist(), [0, 0, 0]) self.assertGreater(abs(np.sum(batch0["state_in"][0][1])), 0) @@ -158,7 +159,7 @@ def test_simple_optimizer_sequencing(self): [[0], [1], [2], [3]], [[4], [0], [0], [0]], ]) - self.assertEqual(batch1["seq_lens"].tolist(), [4, 1, 4, 1]) + self.assertEqual(batch1[SampleBatch.SEQ_LENS].tolist(), [4, 1, 4, 1]) self.assertEqual(batch1["state_in"][0][2].tolist(), [0, 0, 0]) self.assertEqual(batch1["state_in"][1][2].tolist(), [0, 0, 0]) self.assertGreater(abs(np.sum(batch1["state_in"][0][0])), 0) @@ -198,8 +199,8 @@ def test_minibatch_sequencing(self): ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_1")) if batch0["sequences"][0][0][0] > batch1["sequences"][0][0][0]: batch0, batch1 = batch1, batch0 # sort minibatches - self.assertEqual(batch0["seq_lens"].tolist(), [4, 4, 2]) - self.assertEqual(batch1["seq_lens"].tolist(), [2, 3, 4, 1]) + self.assertEqual(batch0[SampleBatch.SEQ_LENS].tolist(), [4, 4, 2]) + self.assertEqual(batch1[SampleBatch.SEQ_LENS].tolist(), [2, 3, 4, 1]) check(batch0["sequences"], [ [[0], [1], [2], [3]], [[4], [5], [6], [7]], @@ -220,8 +221,8 @@ def test_minibatch_sequencing(self): ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_3")) if batch2["sequences"][0][0][0] > batch3["sequences"][0][0][0]: batch2, batch3 = batch3, batch2 - self.assertEqual(batch2["seq_lens"].tolist(), [4, 4, 2]) - self.assertEqual(batch3["seq_lens"].tolist(), [4, 4, 2]) + self.assertEqual(batch2[SampleBatch.SEQ_LENS].tolist(), [4, 4, 2]) + self.assertEqual(batch3[SampleBatch.SEQ_LENS].tolist(), [4, 4, 2]) check(batch2["sequences"], [ [[0], [1], [2], [3]], [[4], [5], [6], [7]], From 5de2cae7c294d377e11c74d87d3a8ad0f7613d0b Mon Sep 17 00:00:00 
2001 From: sven1977 Date: Thu, 19 Aug 2021 17:56:15 +0200 Subject: [PATCH 22/45] wip. --- rllib/evaluation/collectors/simple_list_collector.py | 4 ++-- rllib/models/catalog.py | 5 +---- rllib/models/modelv2.py | 2 ++ rllib/models/preprocessors.py | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 63ca693a9950f..c3d873eff1149 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -330,8 +330,8 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: # Store all data as flattened lists, except INFOS and state-out # lists. These are monolithic items (infos is a dict that - # should not be further split, same for state-out items, which could - # be custom dicts as well). + # should not be further split, same for state-out items, which + # could be custom dicts as well). if col == SampleBatch.INFOS or col.startswith("state_out_"): self.buffers[col] = [[data]] else: diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 2821d8fe03a04..af86f06f1ac39 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -702,10 +702,7 @@ def get_preprocessor_for_space(observation_space: gym.Space, observation_space, options) else: cls = get_preprocessor(observation_space) - if cls is not None: - prep = cls(observation_space, options) - else: - prep = None + prep = cls(observation_space, options) if prep is not None: logger.debug("Created preprocessor {}: {} -> {}".format( diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index f5e0709a45e76..d97420b52730d 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -228,6 +228,8 @@ def __call__( restored["obs_flat"] = input_dict["obs"] except AttributeError: restored["obs_flat"] = input_dict["obs"] + # TODO: This is unnecessary for when no preprocessor is used. + # Obs are not flat then anymore. else: restored["obs_flat"] = input_dict["obs"] diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py index 9f51ccc604b98..4206ad90cdef7 100644 --- a/rllib/models/preprocessors.py +++ b/rllib/models/preprocessors.py @@ -342,7 +342,7 @@ def get_preprocessor(space: gym.Space) -> type: elif isinstance(space, Repeated): preprocessor = RepeatedValuesPreprocessor else: - preprocessor = None + preprocessor = NoPreprocessor return preprocessor From 4e8845034efe04b115a7c9a5a7ce323e75ca7bc1 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 19 Aug 2021 21:24:13 +0200 Subject: [PATCH 23/45] wip. 
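For context on the preprocessor lookup simplified in PATCH 22 above: get_preprocessor() now always hands back a class (NoPreprocessor as the fallback), so get_preprocessor_for_space() can instantiate unconditionally. A minimal, illustrative sketch of obtaining and applying a preprocessor for a structured space (the Dict space below and its expected flattened size are assumptions for illustration only):

    import gym
    import numpy as np
    from ray.rllib.models.catalog import ModelCatalog

    # An illustrative structured (Dict) observation space.
    space = gym.spaces.Dict({
        "id": gym.spaces.Discrete(3),
        "pos": gym.spaces.Box(-1.0, 1.0, (2,), dtype=np.float32),
    })

    # With the default preprocessor_pref, this returns a flattening
    # preprocessor; with preprocessor_pref=None (this PR), no preprocessor
    # is created at all and models receive the raw dict observation.
    prep = ModelCatalog.get_preprocessor_for_space(space)
    flat_obs = prep.transform(space.sample())
    assert flat_obs.shape == (5,)  # one-hot Discrete(3) + Box((2,)) -> 5 floats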
--- rllib/policy/tests/test_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/policy/tests/test_policy.py b/rllib/policy/tests/test_policy.py index 8e1055e72cdf1..15e5efa918d8a 100644 --- a/rllib/policy/tests/test_policy.py +++ b/rllib/policy/tests/test_policy.py @@ -8,7 +8,7 @@ class TestPolicy(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - ray.init() + ray.init(local_mode=True)#TODO @classmethod def tearDownClass(cls) -> None: From f44a6c3a5e347ad36af495c2fd392c12ff599478 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 10:01:03 +0200 Subject: [PATCH 24/45] fixes --- .../collectors/simple_list_collector.py | 49 ++++++++++++++----- rllib/policy/policy.py | 6 +-- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index c3d873eff1149..3fb651d659688 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -60,9 +60,24 @@ def __init__(self, view_reqs, policy): (1 if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0) for k, vr in view_reqs.items()) - # The actual data buffers (lists holding each timestep's data). - self.buffers = {} - self.buffer_structs = {} + + # The actual data buffers. Keys are column names, values are lists + # that contain the sub-components (e.g. for complex obs spaces) with + # each sub-component holding a list of per-timestep tensors. + # E.g.: obs-space = Dict(a=Discrete(2), b=Box((2,))) + # buffers["obs"] = [ + # [0, 1], # <- 1st sub-component of observation + # [np.array([.2, .3]), np.array([.0, -.2])] # <- 2nd sub-component + # ] + # NOTE: infos and state_out_... are not flattened due to them often + # using custom dict values whose structure may vary from timestep to + # timestep. + self.buffers: Dict[str, List[List[TensorType]]] = {} + # Maps column names to an example data item, which may be deeply + # nested. These are used such that we'll know how to unflatten + # the flattened data inside self.buffers when building the + # SampleBatch. + self.buffer_structs: Dict[str, Any] = {} # The episode ID for the agent for which we collect data. self.episode_id = None # The simple timestep count for this agent. Gets increased by one @@ -84,8 +99,10 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, init_obs (TensorType): The initial observation tensor (after `env.reset()`). """ - # Seems to be the first time, we call this method. Build our - # (list-based) buffers first. + # Store episode ID, which will be constant throughout this + # AgentCollector's lifecycle. + self.episode_id = episode_id + if SampleBatch.OBS not in self.buffers: self._build_buffers( single_row={ @@ -99,7 +116,6 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, flattened = tree.flatten(init_obs) for i, sub_obs in enumerate(flattened): self.buffers[SampleBatch.OBS][i].append(sub_obs) - self.episode_id = episode_id self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index) self.buffers["env_id"][0].append(env_id) self.buffers["t"][0].append(t) @@ -120,14 +136,19 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ # Make sure EPS_ID stays the same for this agent. Usually, it should # not be part of `values` anyways. if SampleBatch.EPS_ID in values: + # Make sure eps_id did not change. 
assert values[SampleBatch.EPS_ID] == self.episode_id + # We'll add the eps_id field later when we build the SampleBatch. del values[SampleBatch.EPS_ID] for k, v in values.items(): if k not in self.buffers: self._build_buffers(single_row=values) + # Do not flatten infos or state_out_... (their values may be + # structs that change from timestep to timestep). if k == SampleBatch.INFOS or k.startswith("state_out_"): self.buffers[k][0].append(v) + # Flatten all other columns. else: flattened = tree.flatten(v) for i, sub_list in enumerate(self.buffers[k]): @@ -253,13 +274,14 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: data = [d[self.shift_before:] for d in np_data[data_col]] # Shift is positive: We still need to 0-pad at the end. elif shift > 0: - data = to_float_np_array( - self.buffers[data_col][self.shift_before + shift:] + [ + data = [ + to_float_np_array(d[self.shift_before + shift:] + [ np.zeros( shape=view_req.space.shape, dtype=view_req.space.dtype) for _ in range(shift) - ]) + ]) for d in np_data[data_col] + ] # Shift is negative: Shift into the already existing and # 0-padded "before" area of our buffers. else: @@ -304,11 +326,13 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: # This trajectory is continuing -> Copy data at the end (in the size of # self.shift_before) to the beginning of buffers and erase everything # else. - if not self.buffers[SampleBatch.DONES][-1]: + if not self.buffers[SampleBatch.DONES][0][-1]: # Copy data to beginning of buffer and cut lists. if self.shift_before > 0: for k, data in self.buffers.items(): - self.buffers[k] = data[-self.shift_before:] + # Loop through + for i in range(len(data)): + self.buffers[k][i] = data[i][-self.shift_before:] self.agent_steps = 0 return batch @@ -337,6 +361,8 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: else: self.buffers[col] = [[v for _ in range(shift)] for v in tree.flatten(data)] + # Store an example data struct so we know, how to unflatten + # each data col. self.buffer_structs[col] = data @@ -582,6 +608,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ for k in keys: collector = self.agent_collectors[k] buffers[k] = collector.buffers + # Use one agent's buffer_structs (they should all be the same). buffer_structs = self.agent_collectors[keys[0]].buffer_structs input_dict = {} diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 0a87b8d79930d..e27e6fe29d959 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -863,19 +863,19 @@ def _get_dummy_batch_from_view_requirements( if view_req.shift_from is not None: ret[view_col] = get_dummy_batch_for_space( view_req.space, - batch_size, + batch_size=batch_size, time_size=view_req.shift_to - view_req.shift_from + 1) # Sequence of (probably non-consecutive) indices. elif isinstance(view_req.shift, (list, tuple)): ret[view_col] = get_dummy_batch_for_space( view_req.space, - batch_size, + batch_size=batch_size, time_size=len(view_req.shift)) # Single shift int value. 
else: if isinstance(view_req.space, gym.spaces.Space): ret[view_col] = get_dummy_batch_for_space( - view_req.space, batch_size, fill_value=0.0) + view_req.space, batch_size=batch_size, fill_value=0.0) else: ret[view_col] = [ view_req.space for _ in range(batch_size) From 3d7c37d0bdf33f23d1190b46e4947176d729a9d1 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 10:04:47 +0200 Subject: [PATCH 25/45] fixes --- rllib/policy/policy.py | 4 +++- rllib/policy/tests/test_policy.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index e27e6fe29d959..c17d23c74d1eb 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -875,7 +875,9 @@ def _get_dummy_batch_from_view_requirements( else: if isinstance(view_req.space, gym.spaces.Space): ret[view_col] = get_dummy_batch_for_space( - view_req.space, batch_size=batch_size, fill_value=0.0) + view_req.space, + batch_size=batch_size, + fill_value=0.0) else: ret[view_col] = [ view_req.space for _ in range(batch_size) diff --git a/rllib/policy/tests/test_policy.py b/rllib/policy/tests/test_policy.py index 15e5efa918d8a..8e1055e72cdf1 100644 --- a/rllib/policy/tests/test_policy.py +++ b/rllib/policy/tests/test_policy.py @@ -8,7 +8,7 @@ class TestPolicy(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - ray.init(local_mode=True)#TODO + ray.init() @classmethod def tearDownClass(cls) -> None: From d9f0af9945cbbab86244ce2133c6aa9b57e84bac Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 10:40:37 +0200 Subject: [PATCH 26/45] fixes. --- .../collectors/simple_list_collector.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 3fb651d659688..a5a2e605a7511 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -80,6 +80,8 @@ def __init__(self, view_reqs, policy): self.buffer_structs: Dict[str, Any] = {} # The episode ID for the agent for which we collect data. self.episode_id = None + # The unroll ID, unique across all rollouts (within a RolloutWorker). + self.unroll_id = None # The simple timestep count for this agent. Gets increased by one # each time a (non-initial!) observation is added. self.agent_steps = 0 @@ -99,9 +101,11 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, init_obs (TensorType): The initial observation tensor (after `env.reset()`). """ - # Store episode ID, which will be constant throughout this + # Store episode ID + unroll ID, which will be constant throughout this # AgentCollector's lifecycle. self.episode_id = episode_id + self.unroll_id = self._next_unroll_id + self._next_unroll_id += 1 if SampleBatch.OBS not in self.buffers: self._build_buffers( @@ -110,6 +114,8 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, SampleBatch.AGENT_INDEX: agent_index, "env_id": env_id, "t": t, + SampleBatch.EPS_ID: self.episode_id, + SampleBatch.UNROLL_ID: self.unroll_id, }) # Append data to existing buffers. 
@@ -119,6 +125,8 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index) self.buffers["env_id"][0].append(env_id) self.buffers["t"][0].append(t) + self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id) + self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id) def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ None: @@ -130,16 +138,20 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS. """ + # Next obs -> obs. assert SampleBatch.OBS not in values values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS] del values[SampleBatch.NEXT_OBS] - # Make sure EPS_ID stays the same for this agent. Usually, it should - # not be part of `values` anyways. + + # Make sure EPS_ID/UNROLL_ID stay the same for this agent. if SampleBatch.EPS_ID in values: - # Make sure eps_id did not change. assert values[SampleBatch.EPS_ID] == self.episode_id - # We'll add the eps_id field later when we build the SampleBatch. del values[SampleBatch.EPS_ID] + self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id) + if SampleBatch.UNROLL_ID in values: + assert values[SampleBatch.UNROLL_ID] == self.unroll_id + del values[SampleBatch.UNROLL_ID] + self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id) for k, v in values.items(): if k not in self.buffers: @@ -312,17 +324,6 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: batch["seq_lens"] = np.array(seq_lens) batch.max_seq_len = max_seq_len - # Add EPS_ID and UNROLL_ID to batch. - batch[SampleBatch.EPS_ID] = np.repeat(self.episode_id, batch.count) - if SampleBatch.UNROLL_ID not in batch: - # TODO: (sven) Once we have the additional - # model.preprocess_train_batch in place (attention net PR), we - # should not even need UNROLL_ID anymore: - # Add "if SampleBatch.UNROLL_ID in view_requirements:" here. - batch[SampleBatch.UNROLL_ID] = np.repeat( - _AgentCollector._next_unroll_id, batch.count) - _AgentCollector._next_unroll_id += 1 - # This trajectory is continuing -> Copy data at the end (in the size of # self.shift_before) to the beginning of buffers and erase everything # else. From 4e8da9ecd589ef52e849790f9df0dbb1ba6b4e49 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 10:42:05 +0200 Subject: [PATCH 27/45] fix. --- rllib/policy/sample_batch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 86a880b9ad1d4..fe85ad76807b7 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -39,6 +39,7 @@ class SampleBatch(dict): PREV_REWARDS = "prev_rewards" DONES = "dones" INFOS = "infos" + SEQ_LENS = "seq_lens" # Extra action fetches keys. ACTION_DIST_INPUTS = "action_dist_inputs" From 59fa8a4c976a3cce0c5a876701d26f12b5117269 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 10:48:04 +0200 Subject: [PATCH 28/45] Add "env_id" and "t" to SampleBatch as consts. 
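Several of the preceding commits replace magic strings ("seq_lens", and next "env_id"/"t") with SampleBatch class constants. A small usage sketch, assuming the constants added in this series; the column values are made up:

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch({
        SampleBatch.OBS: np.zeros((5, 4), dtype=np.float32),
        SampleBatch.REWARDS: np.ones(5, dtype=np.float32),
        SampleBatch.SEQ_LENS: np.array([2, 3]),
    })
    # The constants are just the raw column names.
    assert SampleBatch.SEQ_LENS == "seq_lens"
    assert np.all(batch[SampleBatch.SEQ_LENS] == batch["seq_lens"])
    assert batch.count == 5  # 5 timesteps, grouped into sequences of len 2 and 3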
--- .../collectors/simple_list_collector.py | 14 ++++++------- rllib/evaluation/sampler.py | 21 +++++++++++-------- rllib/policy/sample_batch.py | 3 +++ 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index c1df01d922b05..effb8be8c566f 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -85,14 +85,14 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, single_row={ SampleBatch.OBS: init_obs, SampleBatch.AGENT_INDEX: agent_index, - "env_id": env_id, - "t": t, + SampleBatch.ENV_ID: env_id, + SampleBatch.T: t, }) self.buffers[SampleBatch.OBS].append(init_obs) self.episode_id = episode_id self.buffers[SampleBatch.AGENT_INDEX].append(agent_index) - self.buffers["env_id"].append(env_id) - self.buffers["t"].append(t) + self.buffers[SampleBatch.ENV_ID].append(env_id) + self.buffers[SampleBatch.T].append(t) def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ None: @@ -279,7 +279,7 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: continue shift = self.shift_before - (1 if col in [ SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, - "env_id", "t" + SampleBatch.ENV_ID, SampleBatch.T ] else 0) # Python primitive, tensor, or dict (e.g. INFOs). self.buffers[col] = [data for _ in range(shift)] @@ -546,8 +546,8 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ # Create the batch of data from the different buffers. data_col = view_req.data_col or view_col delta = -1 if data_col in [ - SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID, - SampleBatch.AGENT_INDEX + SampleBatch.OBS, SampleBatch.ENV_ID, SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX, SampleBatch.T ] else 0 # Range of shifts, e.g. "-100:0". Note: This includes index 0! if view_req.shift_from is not None: diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index d4f2f50aa2b84..a8b6f10513528 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -833,19 +833,22 @@ def _process_observations( else: # Add actions, rewards, next-obs to collectors. values_dict = { - "t": episode.length - 1, - "env_id": env_id, - "agent_index": episode._agent_index(agent_id), + SampleBatchType.T: episode.length - 1, + SampleBatchType.ENV_ID: env_id, + SampleBatchType.AGENT_INDEX: episode._agent_index( + agent_id), # Action (slot 0) taken at timestep t. - "actions": episode.last_action_for(agent_id), + SampleBatchType.ACTIONS: episode.last_action_for(agent_id), # Reward received after taking a at timestep t. - "rewards": rewards[env_id].get(agent_id, 0.0), + SampleBatchType.REWARDS: rewards[env_id].get( + agent_id, 0.0), # After taking action=a, did we reach terminal? - "dones": (False if (no_done_at_end - or (hit_horizon and soft_horizon)) else - agent_done), + SampleBatchType.DONES: (False if + (no_done_at_end + or (hit_horizon and soft_horizon)) + else agent_done), # Next observation. - "new_obs": filtered_obs, + SampleBatchType.NEXT_OBS: filtered_obs, } # Add extra-action-fetches to collectors. pol = worker.policy_map[policy_id] diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index fe85ad76807b7..23e87c520a19e 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -40,6 +40,7 @@ class SampleBatch(dict): DONES = "dones" INFOS = "infos" SEQ_LENS = "seq_lens" + T = "t" # Extra action fetches keys. 
ACTION_DIST_INPUTS = "action_dist_inputs" @@ -48,6 +49,8 @@ class SampleBatch(dict): # Uniquely identifies an episode. EPS_ID = "eps_id" + # An env ID (e.g. the index for a vectorized sub-env). + ENV_ID = "env_id" # Uniquely identifies a sample batch. This is important to distinguish RNN # sequences from the same episode when multiple sample batches are From 78e147280b2f7dfe1c866e7bf8fd4304b4dcfe6d Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 11:26:07 +0200 Subject: [PATCH 29/45] Fix. --- python/ray/tests/test_advanced.py | 2 +- rllib/evaluation/sampler.py | 23 +++++++++++------------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index e9c85c3b73612..aa403c31f433e 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -15,7 +15,7 @@ import ray._private.profiling as profiling from ray._private.test_utils import (client_test_enabled, - RayTestTimeoutException, SignalActor) + RayTestTimeoutException) if client_test_enabled(): from ray.util.client import ray diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index a8b6f10513528..e2e682e6adc7f 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -24,6 +24,7 @@ from ray.rllib.models.preprocessors import Preprocessor from ray.rllib.offline import InputReader from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import summarize from ray.rllib.utils.deprecation import deprecation_warning @@ -833,22 +834,20 @@ def _process_observations( else: # Add actions, rewards, next-obs to collectors. values_dict = { - SampleBatchType.T: episode.length - 1, - SampleBatchType.ENV_ID: env_id, - SampleBatchType.AGENT_INDEX: episode._agent_index( - agent_id), + SampleBatch.T: episode.length - 1, + SampleBatch.ENV_ID: env_id, + SampleBatch.AGENT_INDEX: episode._agent_index(agent_id), # Action (slot 0) taken at timestep t. - SampleBatchType.ACTIONS: episode.last_action_for(agent_id), + SampleBatch.ACTIONS: episode.last_action_for(agent_id), # Reward received after taking a at timestep t. - SampleBatchType.REWARDS: rewards[env_id].get( - agent_id, 0.0), + SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0), # After taking action=a, did we reach terminal? - SampleBatchType.DONES: (False if - (no_done_at_end - or (hit_horizon and soft_horizon)) - else agent_done), + SampleBatch.DONES: (False + if (no_done_at_end + or (hit_horizon and soft_horizon)) + else agent_done), # Next observation. - SampleBatchType.NEXT_OBS: filtered_obs, + SampleBatch.NEXT_OBS: filtered_obs, } # Add extra-action-fetches to collectors. pol = worker.policy_map[policy_id] From 8fcb9cfc6d1ea3a256c51e36604b66c2a41701fe Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 13:22:23 +0200 Subject: [PATCH 30/45] wip. 
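The fix above swaps SampleBatchType (a type alias used for annotations) for SampleBatch, the class that actually carries the column-name constants used in the sampler's values dict. For illustration, a hypothetical per-timestep row as the sampler would hand it to a collector; all values and the `collector` object are made up:

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    values_dict = {
        SampleBatch.T: 7,            # env timestep
        SampleBatch.ENV_ID: 0,       # index of the vectorized sub-env
        SampleBatch.AGENT_INDEX: 0,
        SampleBatch.ACTIONS: 1,      # action taken at t
        SampleBatch.REWARDS: 1.0,    # reward received after taking that action
        SampleBatch.DONES: False,
        SampleBatch.NEXT_OBS: np.array([0.1, -0.2, 0.0, 0.3], dtype=np.float32),
    }
    # collector.add_action_reward_next_obs(values_dict)  # hypothetical collector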
--- rllib/evaluation/collectors/simple_list_collector.py | 10 +++++----- rllib/policy/dynamic_tf_policy.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 723c9d141fd7a..99bd6bd3b5baf 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -348,11 +348,6 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: for col, data in single_row.items(): if col in self.buffers: continue - shift = self.shift_before - (1 if col in [ - SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, - SampleBatch.ENV_ID, SampleBatch.T - ] else 0) - # Store all data as flattened lists, except INFOS and state-out # lists. These are monolithic items (infos is a dict that # should not be further split, same for state-out items, which @@ -360,6 +355,11 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: if col == SampleBatch.INFOS or col.startswith("state_out_"): self.buffers[col] = [[data]] else: + shift = self.shift_before - (1 if col in [ + SampleBatch.OBS, SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX, SampleBatch.ENV_ID, + SampleBatch.T, SampleBatch.UNROLL_ID + ] else 0) self.buffers[col] = [[v for _ in range(shift)] for v in tree.flatten(data)] # Store an example data struct so we know, how to unflatten diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index a0656e70a9e84..6a6863d7c0a3f 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -508,8 +508,8 @@ def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0): if batch_size >= len(self._loaded_single_cpu_batch): sliced_batch = self._loaded_single_cpu_batch else: - sliced_batch = self._loaded_single_cpu_batch[offset:offset + - batch_size] + sliced_batch = self._loaded_single_cpu_batch.slice(offset, offset + + batch_size) return self.learn_on_batch(sliced_batch) return self.multi_gpu_tower_stacks[buffer_index].optimize( From 392dd1e8d2a4abd5fd6b16fcb1d439d7ba1d065f Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 18:58:36 +0200 Subject: [PATCH 31/45] merge --- rllib/execution/replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py index 07c46e97a6b2b..4abb3a0ac2adf 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -402,7 +402,7 @@ def add_batch(self, batch: SampleBatchType) -> None: # If SampleBatch has prio-replay weights, average # over these to use as a weight for the entire # sequence. - if "weights" in time_slice: + if "weights" in time_slice and time_slice["weights"]: weight = np.mean(time_slice["weights"]) else: weight = None From c3d9d5a8b9dc32338f86ad312ec80297ae7c7edf Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 19:00:03 +0200 Subject: [PATCH 32/45] LINT. 
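The guard added to the replay buffer above (only average "weights" when the time slice actually carries a non-empty array) can be shown in isolation; the helper name below is made up:

    import numpy as np

    def sequence_weight(time_slice):
        # Average per-timestep prio-replay weights into a single weight for
        # the whole stored sequence; None means "use the default priority".
        w = time_slice.get("weights")
        if w is not None and len(w):
            return float(np.mean(w))
        return None

    assert sequence_weight({"weights": np.array([0.5, 1.5])}) == 1.0
    assert sequence_weight({"weights": np.array([])}) is None
    assert sequence_weight({}) is None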
--- rllib/evaluation/collectors/simple_list_collector.py | 4 ++-- rllib/policy/dynamic_tf_policy.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 99bd6bd3b5baf..fbe92ee3c3bab 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -357,8 +357,8 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: else: shift = self.shift_before - (1 if col in [ SampleBatch.OBS, SampleBatch.EPS_ID, - SampleBatch.AGENT_INDEX, SampleBatch.ENV_ID, - SampleBatch.T, SampleBatch.UNROLL_ID + SampleBatch.AGENT_INDEX, SampleBatch.ENV_ID, SampleBatch.T, + SampleBatch.UNROLL_ID ] else 0) self.buffers[col] = [[v for _ in range(shift)] for v in tree.flatten(data)] diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 6a6863d7c0a3f..d130a7232caef 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -508,8 +508,8 @@ def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0): if batch_size >= len(self._loaded_single_cpu_batch): sliced_batch = self._loaded_single_cpu_batch else: - sliced_batch = self._loaded_single_cpu_batch.slice(offset, offset + - batch_size) + sliced_batch = self._loaded_single_cpu_batch.slice( + offset, offset + batch_size) return self.learn_on_batch(sliced_batch) return self.multi_gpu_tower_stacks[buffer_index].optimize( From a4e97440c076c443363d889ae88f95874d3c03fa Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 19:39:22 +0200 Subject: [PATCH 33/45] LINT. --- rllib/evaluation/collectors/simple_list_collector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 32cccbe4d2311..d862f737cd7a3 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -650,7 +650,8 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ space = view_req.space fill_value = get_dummy_batch_for_space( - space, batch_size=0, + space, + batch_size=0, ) if isinstance(space, Space) else space self.agent_collectors[k]._build_buffers({ From 7959e64a0acb819600ba9f2647b86412cffe45cd Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 19:56:38 +0200 Subject: [PATCH 34/45] wip. 
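The lint fixes above touch get_dummy_batch_for_space, which this series uses to create zero-filled placeholder data for a gym space whenever a buffer does not exist yet. A minimal sketch (the import path is an assumption based on RLlib's space utilities):

    import gym
    import numpy as np
    from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space

    space = gym.spaces.Box(-1.0, 1.0, (3,), dtype=np.float32)
    dummy = get_dummy_batch_for_space(space, batch_size=4, fill_value=0.0)
    # A (batch_size,) + space.shape array, filled with 0.0.
    assert dummy.shape == (4, 3)
    assert np.all(dummy == 0.0)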
--- rllib/BUILD | 20 ++++++++++++++++++++ rllib/policy/sample_batch.py | 6 +++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 6286c40546808..dc59aa2789575 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2354,6 +2354,26 @@ py_test( srcs = ["examples/pettingzoo_env.py"], ) +py_test( + name = "examples/preprocessing_disabled_tf", + main = "examples/preprocessing_disabled.py", + tags = ["team:ml", "examples", "examples_P"], + size = "small", + srcs = ["examples/preprocessing_disabled.py"], + args = ["--stop-iters=2"] +) + +# Not supported for torch yet (complex-input model needs +# to accept dict spaces as well) +# py_test( +# name = "examples/preprocessing_disabled_torch", +# main = "examples/preprocessing_disabled.py", +# tags = ["team:ml", "examples", "examples_P"], +# size = "small", +# srcs = ["examples/preprocessing_disabled.py"], +# args = ["--framework=torch", "--stop-iters=2"] +# ) + py_test( name = "examples/remote_envs_with_inference_done_on_main_node_tf", main = "examples/remote_envs_with_inference_done_on_main_node.py", diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index d4319c0a6b6d3..4492c171acb24 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -502,10 +502,10 @@ def slice(self, start: int, end: int, state_start=None, ) else: return SampleBatch( - {k: v[start:end] - for k, v in self.items()}, + tree.map_structure(lambda value: value[start:end], self), _is_training=self.is_training, - _time_major=self.time_major) + _time_major=self.time_major, + ) @PublicAPI def timeslices(self, From cf3c9dca6befe45da38841caeafb0ecaf58d244f Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 20 Aug 2021 19:58:03 +0200 Subject: [PATCH 35/45] wip. --- rllib/examples/preprocessing_disabled.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rllib/examples/preprocessing_disabled.py b/rllib/examples/preprocessing_disabled.py index d492b412d4a80..43613ca13fd0c 100644 --- a/rllib/examples/preprocessing_disabled.py +++ b/rllib/examples/preprocessing_disabled.py @@ -73,6 +73,10 @@ def get_cli_args(): if __name__ == "__main__": args = get_cli_args() + assert args.framework == "tf",\ + "No-preprocessing only working for tf so far! Complex input " \ + "model must be changed to accept Dict spaces as well (besides " \ + "Tuples)." ray.init(local_mode=args.local_mode) From 90083b6eb4bc7688db9ee956955b6146a518f941 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sat, 21 Aug 2021 17:03:44 +0200 Subject: [PATCH 36/45] fix. --- .../collectors/simple_list_collector.py | 86 +++++++++---------- 1 file changed, 41 insertions(+), 45 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index d862f737cd7a3..0522f3623225d 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -348,18 +348,19 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: for col, data in single_row.items(): if col in self.buffers: continue + + shift = self.shift_before - (1 if col in [ + SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, + SampleBatch.ENV_ID, SampleBatch.T, SampleBatch.UNROLL_ID + ] else 0) + # Store all data as flattened lists, except INFOS and state-out # lists. These are monolithic items (infos is a dict that # should not be further split, same for state-out items, which # could be custom dicts as well). 
if col == SampleBatch.INFOS or col.startswith("state_out_"): - self.buffers[col] = [[data]] + self.buffers[col] = [[data for _ in range(shift)]] else: - shift = self.shift_before - (1 if col in [ - SampleBatch.OBS, SampleBatch.EPS_ID, - SampleBatch.AGENT_INDEX, SampleBatch.ENV_ID, SampleBatch.T, - SampleBatch.UNROLL_ID - ] else 0) self.buffers[col] = [[v for _ in range(shift)] for v in tree.flatten(data)] # Store an example data struct so we know, how to unflatten @@ -635,51 +636,46 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ # Loop through agents and add up their data (batch). data = None for k in keys: - if data_col == SampleBatch.EPS_ID: - if data is None: - data = [[]] - data[0].append(self.agent_collectors[k].episode_id) - else: - # Buffer for the data does not exist yet: Create dummy - # (zero) data. - if data_col not in buffers[k]: - if view_req.data_col is not None: - space = policy.view_requirements[ - view_req.data_col].space - else: - space = view_req.space + # Buffer for the data does not exist yet: Create dummy + # (zero) data. + if data_col not in buffers[k]: + if view_req.data_col is not None: + space = policy.view_requirements[ + view_req.data_col].space + else: + space = view_req.space + if isinstance(space, Space): fill_value = get_dummy_batch_for_space( space, batch_size=0, - ) if isinstance(space, Space) else space - - self.agent_collectors[k]._build_buffers({ - data_col: fill_value - }) - - if data is None: - data = [ - [] for _ in range(len(buffers[keys[0]][data_col])) - ] - - # `shift_from` and `shift_to` are defined: User wants a - # view with some time-range. - if isinstance(time_indices, tuple): - # `shift_to` == -1: Until the end (including(!) the - # last item). - if time_indices[1] == -1: - for d, b in zip(data, buffers[k][data_col]): - d.append(b[time_indices[0]:]) - # `shift_to` != -1: "Normal" range. - else: - for d, b in zip(data, buffers[k][data_col]): - d.append( - b[time_indices[0]:time_indices[1] + 1]) - # Single index. + ) + else: + fill_value = space + + self.agent_collectors[k]._build_buffers({ + data_col: fill_value + }) + + if data is None: + data = [[] for _ in range(len(buffers[keys[0]][data_col]))] + + # `shift_from` and `shift_to` are defined: User wants a + # view with some time-range. + if isinstance(time_indices, tuple): + # `shift_to` == -1: Until the end (including(!) the + # last item). + if time_indices[1] == -1: + for d, b in zip(data, buffers[k][data_col]): + d.append(b[time_indices[0]:]) + # `shift_to` != -1: "Normal" range. else: for d, b in zip(data, buffers[k][data_col]): - d.append(b[time_indices]) + d.append(b[time_indices[0]:time_indices[1] + 1]) + # Single index. + else: + for d, b in zip(data, buffers[k][data_col]): + d.append(b[time_indices]) np_data = [np.array(d) for d in data] if data_col in buffer_structs: From 4825e27a7e7b55882ecda06e451ba4db457b0cd3 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sat, 21 Aug 2021 17:31:09 +0200 Subject: [PATCH 37/45] wip. 
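The collector changes above store each structured column as one flat list per sub-component and remember an example struct (buffer_structs) for unflattening when the SampleBatch is built. A standalone sketch of that pattern with dm_tree; the observation values are made up:

    import numpy as np
    import tree  # pip install dm_tree

    # Two timesteps of a structured observation.
    obs_t0 = {"a": 0, "b": np.array([0.2, 0.3], dtype=np.float32)}
    obs_t1 = {"a": 1, "b": np.array([0.0, -0.2], dtype=np.float32)}

    # One buffer (list) per flattened sub-component, as in _build_buffers().
    buffers = [[] for _ in tree.flatten(obs_t0)]
    buffer_struct = obs_t0  # example struct remembered for later unflattening
    for row in (obs_t0, obs_t1):
        for buf, comp in zip(buffers, tree.flatten(row)):
            buf.append(comp)

    # When building the SampleBatch, each sub-list is stacked and the
    # original structure is restored from the remembered example struct.
    np_data = [np.array(buf) for buf in buffers]
    restored = tree.unflatten_as(buffer_struct, np_data)
    assert restored["a"].tolist() == [0, 1]
    assert restored["b"].shape == (2, 2)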
--- rllib/BUILD | 18 ++++++------ rllib/examples/preprocessing_disabled.py | 10 ------- rllib/models/catalog.py | 5 ++-- rllib/models/modelv2.py | 7 ++++- rllib/models/preprocessors.py | 3 +- rllib/models/torch/complex_input_net.py | 35 ++++++++++++++++-------- 6 files changed, 41 insertions(+), 37 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index dc59aa2789575..b03d0fd807be5 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2363,16 +2363,14 @@ py_test( args = ["--stop-iters=2"] ) -# Not supported for torch yet (complex-input model needs -# to accept dict spaces as well) -# py_test( -# name = "examples/preprocessing_disabled_torch", -# main = "examples/preprocessing_disabled.py", -# tags = ["team:ml", "examples", "examples_P"], -# size = "small", -# srcs = ["examples/preprocessing_disabled.py"], -# args = ["--framework=torch", "--stop-iters=2"] -# ) +py_test( + name = "examples/preprocessing_disabled_torch", + main = "examples/preprocessing_disabled.py", + tags = ["team:ml", "examples", "examples_P"], + size = "small", + srcs = ["examples/preprocessing_disabled.py"], + args = ["--framework=torch", "--stop-iters=2"] +) py_test( name = "examples/remote_envs_with_inference_done_on_main_node_tf", diff --git a/rllib/examples/preprocessing_disabled.py b/rllib/examples/preprocessing_disabled.py index 43613ca13fd0c..bc3a50a6cc8ea 100644 --- a/rllib/examples/preprocessing_disabled.py +++ b/rllib/examples/preprocessing_disabled.py @@ -21,12 +21,6 @@ def get_cli_args(): """Create CLI parser and return parsed arguments""" parser = argparse.ArgumentParser() - # example-specific args - parser.add_argument( - "--no-attention", - action="store_true", - help="Do NOT use attention. For comparison: The agent will not learn.") - # general args parser.add_argument( "--run", default="PPO", help="The RLlib-registered algorithm to use.") @@ -73,10 +67,6 @@ def get_cli_args(): if __name__ == "__main__": args = get_cli_args() - assert args.framework == "tf",\ - "No-preprocessing only working for tf so far! Complex input " \ - "model must be changed to accept Dict spaces as well (besides " \ - "Tuples)." ray.init(local_mode=args.local_mode) diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index af86f06f1ac39..a54b73f63d6c0 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -670,7 +670,7 @@ def get_preprocessor(env: gym.Env, options) @staticmethod - @Deprecated(error=False) + @DeveloperAPI def get_preprocessor_for_space(observation_space: gym.Space, options: dict = None) -> Preprocessor: """Returns a suitable preprocessor for the given observation space. @@ -882,7 +882,8 @@ def _validate_config(config: ModelConfigDict, framework: str) -> None: if config.get("custom_preprocessor") is not None: deprecation_warning( old="model.custom_preprocessor", - new="gym.ObservationWrapper around your env", + new="gym.ObservationWrapper around your env or handle complex " + "inputs inside your Model", error=False, ) diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 0a48de46ee4a8..68ab868b74d29 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -217,6 +217,8 @@ def __call__( else: restored = input_dict.copy() + # Input to this Model went through a Preprocessor. + # Generate extra keys: "obs_flat" vs "obs". 
if hasattr(self.obs_space, "original_space"): restored["obs"] = restore_original_dimensions( input_dict["obs"], self.obs_space, self.framework) @@ -228,8 +230,11 @@ def __call__( restored["obs_flat"] = input_dict["obs"] except AttributeError: restored["obs_flat"] = input_dict["obs"] + # No Preprocessor used: `config.preprocessor_pref`=None. # TODO: This is unnecessary for when no preprocessor is used. - # Obs are not flat then anymore. + # Obs are not flat then anymore. We keep this here for + # backward-compatibility until Preprocessors have been fully + # deprecated. else: restored["obs_flat"] = input_dict["obs"] diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py index 4206ad90cdef7..b2a375403c4e2 100644 --- a/rllib/models/preprocessors.py +++ b/rllib/models/preprocessors.py @@ -4,7 +4,7 @@ import gym from typing import Any, List -from ray.rllib.utils.annotations import Deprecated, override, PublicAPI +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.spaces.repeated import Repeated from ray.rllib.utils.typing import TensorType from ray.rllib.utils.images import resize @@ -177,7 +177,6 @@ def write(self, observation: TensorType, array: np.ndarray, array[offset:offset + self.size] = self.transform(observation) -@Deprecated(error=False) class NoPreprocessor(Preprocessor): @override(Preprocessor) def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]: diff --git a/rllib/models/torch/complex_input_net.py b/rllib/models/torch/complex_input_net.py index 4fb71dabeed7a..00ed3c21b938c 100644 --- a/rllib/models/torch/complex_input_net.py +++ b/rllib/models/torch/complex_input_net.py @@ -1,16 +1,19 @@ -from gym.spaces import Box, Discrete, Tuple +from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple import numpy as np +import tree # pip install dm_tree # TODO (sven): add IMPALA-style option. # from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet from ray.rllib.models.torch.misc import normc_initializer as \ torch_normc_initializer, SlimFC from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.utils import get_filter_config +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.spaces.space_utils import flatten_space from ray.rllib.utils.torch_ops import one_hot torch, nn = try_import_torch() @@ -32,16 +35,17 @@ class ComplexInputNetwork(TorchModelV2, nn.Module): def __init__(self, obs_space, action_space, num_outputs, model_config, name): - # TODO: (sven) Support Dicts as well. self.original_space = obs_space.original_space if \ hasattr(obs_space, "original_space") else obs_space - assert isinstance(self.original_space, (Tuple)), \ - "`obs_space.original_space` must be Tuple!" + assert isinstance(self.original_space, (Dict, Tuple)), \ + "`obs_space.original_space` must be [Dict|Tuple]!" nn.Module.__init__(self) TorchModelV2.__init__(self, self.original_space, action_space, num_outputs, model_config, name) + self.flattened_input_space = flatten_space(self.original_space) + # Atari type CNNs or IMPALA type CNNs (with residual layers)? 
# self.cnn_type = self.model_config["custom_model_config"].get( # "conv_type", "atari") @@ -51,7 +55,7 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, self.one_hot = {} self.flatten = {} concat_size = 0 - for i, component in enumerate(self.original_space): + for i, component in enumerate(self.flattened_input_space): # Image space. if len(component.shape) == 3: config = { @@ -81,11 +85,13 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, concat_size += cnn.num_outputs self.cnns[i] = cnn self.add_module("cnn_{}".format(i), cnn) - # Discrete inputs -> One-hot encode. + # Discrete|MultiDiscrete inputs -> One-hot encode. elif isinstance(component, Discrete): self.one_hot[i] = True concat_size += component.n - # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers). + elif isinstance(component, MultiDiscrete): + self.one_hot[i] = True + concat_size += sum(component.nvec) # Everything else (1D Box). else: self.flatten[i] = int(np.product(component.shape)) @@ -131,20 +137,25 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, @override(ModelV2) def forward(self, input_dict, state, seq_lens): + if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: + orig_obs = input_dict[SampleBatch.OBS] + else: + orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS], + self.obs_space, "tf") # Push image observations through our CNNs. outs = [] - for i, component in enumerate(input_dict["obs"]): + for i, component in enumerate(tree.flatten(orig_obs)): if i in self.cnns: - cnn_out, _ = self.cnns[i]({"obs": component}) + cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component}) outs.append(cnn_out) elif i in self.one_hot: - outs.append(one_hot(component, self.original_space.spaces[i])) + outs.append(one_hot(component, self.flattened_input_space[i])) else: outs.append(torch.reshape(component, [-1, self.flatten[i]])) # Concat all outputs and the non-image inputs. out = torch.cat(outs, dim=1) # Push through (optional) FC-stack (this may be an empty stack). - out, _ = self.post_fc_stack({"obs": out}, [], None) + out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None) # No logits/value branches. if self.logits_layer is None: From 6c0ad15e85c70b6e6dd1ca66e19e3196a3c1661e Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 23 Aug 2021 10:10:50 +0200 Subject: [PATCH 38/45] wip. --- rllib/agents/dqn/dqn_tf_policy.py | 2 ++ rllib/policy/policy.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py index 345d30d3e4c99..ee83d2ec72fb6 100644 --- a/rllib/agents/dqn/dqn_tf_policy.py +++ b/rllib/agents/dqn/dqn_tf_policy.py @@ -416,6 +416,8 @@ def postprocess_nstep_and_prio(policy: Policy, batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS], batch[SampleBatch.DONES]) + # Create dummy prio-weights (1.0) in case we don't have any in + # the batch. 
if PRIO_WEIGHTS not in batch: batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS]) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 52af74c2c7442..4666f25959770 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -804,7 +804,8 @@ def _initialize_loss_from_dummy_batch( self._dummy_batch.added_keys for key in all_accessed_keys: if key not in self.view_requirements: - self.view_requirements[key] = ViewRequirement() + self.view_requirements[key] = ViewRequirement( + used_for_compute_actions=False) if self._loss: # Tag those only needed for post-processing (with some # exceptions). From 3b4f64496b6d62e73ccc71379bfa8a902fde59bd Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 23 Aug 2021 14:45:16 +0200 Subject: [PATCH 39/45] wip. --- rllib/execution/replay_buffer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py index 4abb3a0ac2adf..aa65188ed0aa3 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -402,7 +402,8 @@ def add_batch(self, batch: SampleBatchType) -> None: # If SampleBatch has prio-replay weights, average # over these to use as a weight for the entire # sequence. - if "weights" in time_slice and time_slice["weights"]: + if "weights" in time_slice and \ + len(time_slice["weights"]): weight = np.mean(time_slice["weights"]) else: weight = None From 60aa649067762731882209bff738313492d7f654 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 27 Aug 2021 08:34:14 +0200 Subject: [PATCH 40/45] wip --- .../collectors/simple_list_collector.py | 26 ++++++++++++++----- rllib/evaluation/tests/test_rollout_worker.py | 11 +++----- rllib/examples/two_step_game.py | 2 +- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 1a98fd14e65bb..110a652e59b87 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -104,8 +104,9 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, # Store episode ID + unroll ID, which will be constant throughout this # AgentCollector's lifecycle. self.episode_id = episode_id - self.unroll_id = self._next_unroll_id - self._next_unroll_id += 1 + if self.unroll_id is None: + self.unroll_id = self._next_unroll_id + self._next_unroll_id += 1 if SampleBatch.OBS not in self.buffers: self._build_buffers( @@ -137,6 +138,9 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ row) to be added to buffer. Must contain keys: SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS. """ + if self.unroll_id is None: + self.unroll_id = self._next_unroll_id + self._next_unroll_id += 1 # Next obs -> obs. assert SampleBatch.OBS not in values @@ -156,15 +160,20 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ for k, v in values.items(): if k not in self.buffers: self._build_buffers(single_row=values) - # Do not flatten infos or state_out_... (their values may be - # structs that change from timestep to timestep). - if k == SampleBatch.INFOS or k.startswith("state_out_"): + # Do not flatten infos, state_out_ and actions. + # Infos/state-outs may be structs that change from timestep to + # timestep. Actions - on the other hand - are already flattened + # in the sampler. 
+ if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith("state_out_"): self.buffers[k][0].append(v) # Flatten all other columns. else: flattened = tree.flatten(v) for i, sub_list in enumerate(self.buffers[k]): - sub_list.append(flattened[i]) + try:#TODO + sub_list.append(flattened[i]) + except Exception as e: + raise e self.agent_steps += 1 def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: @@ -336,6 +345,9 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: self.buffers[k][i] = data[i][-self.shift_before:] self.agent_steps = 0 + # Reset our unroll_id. + self.unroll_id = None + return batch def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: @@ -358,7 +370,7 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: # lists. These are monolithic items (infos is a dict that # should not be further split, same for state-out items, which # could be custom dicts as well). - if col == SampleBatch.INFOS or col.startswith("state_out_"): + if col in [SampleBatch.INFOS, SampleBatch.ACTIONS] or col.startswith("state_out_"): self.buffers[col] = [[data for _ in range(shift)]] else: self.buffers[col] = [[v for _ in range(shift)] diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py index 913fce2fe632a..93d82f7dfcab7 100644 --- a/rllib/evaluation/tests/test_rollout_worker.py +++ b/rllib/evaluation/tests/test_rollout_worker.py @@ -113,21 +113,18 @@ def to_prev(vec): ev.stop() def test_batch_ids(self): + fragment_len = 20 ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), policy_spec=MockPolicy, - rollout_fragment_length=1) + rollout_fragment_length=fragment_len) batch1 = ev.sample() batch2 = ev.sample() - self.assertEqual(len(set(batch1["unroll_id"])), 1) - self.assertEqual(len(set(batch2["unroll_id"])), 1) - self.assertEqual( - len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2) + check(batch1[SampleBatch.UNROLL_ID], [0] * fragment_len) + check(batch2[SampleBatch.UNROLL_ID], [1] * fragment_len) ev.stop() def test_global_vars_update(self): - # Allow for Unittest run. 
- ray.init(num_cpus=5, ignore_reinit_error=True) for fw in framework_iterator(frameworks=("tf2", "tf")): agent = A2CTrainer( env="CartPole-v0", diff --git a/rllib/examples/two_step_game.py b/rllib/examples/two_step_game.py index e3c83bbde432f..97c17e65d415d 100644 --- a/rllib/examples/two_step_game.py +++ b/rllib/examples/two_step_game.py @@ -56,7 +56,7 @@ if __name__ == "__main__": args = parser.parse_args() - ray.init(num_cpus=args.num_cpus or None) + ray.init(num_cpus=args.num_cpus or None, local_mode=True)#TODO grouping = { "group_1": [0, 1], From 6d822dbc355af5710f562a3c98f3692c459fbeef Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 27 Aug 2021 09:14:46 +0200 Subject: [PATCH 41/45] wip --- rllib/agents/sac/tests/test_sac.py | 2 +- .../evaluation/collectors/simple_list_collector.py | 8 ++++---- rllib/evaluation/tests/test_rollout_worker.py | 14 +++++++++++--- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py index b84f9ebef2172..9b096545e6ba0 100644 --- a/rllib/agents/sac/tests/test_sac.py +++ b/rllib/agents/sac/tests/test_sac.py @@ -59,7 +59,7 @@ class TestSAC(unittest.TestCase): def setUpClass(cls) -> None: np.random.seed(42) torch.manual_seed(42) - ray.init() + ray.init(local_mode=True)#TODO @classmethod def tearDownClass(cls) -> None: diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 110a652e59b87..d5eee6a559b54 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -105,8 +105,8 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int, # AgentCollector's lifecycle. self.episode_id = episode_id if self.unroll_id is None: - self.unroll_id = self._next_unroll_id - self._next_unroll_id += 1 + self.unroll_id = _AgentCollector._next_unroll_id + _AgentCollector._next_unroll_id += 1 if SampleBatch.OBS not in self.buffers: self._build_buffers( @@ -139,8 +139,8 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS. """ if self.unroll_id is None: - self.unroll_id = self._next_unroll_id - self._next_unroll_id += 1 + self.unroll_id = _AgentCollector._next_unroll_id + _AgentCollector._next_unroll_id += 1 # Next obs -> obs. assert SampleBatch.OBS not in values diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py index 93d82f7dfcab7..efa42c5bf8483 100644 --- a/rllib/evaluation/tests/test_rollout_worker.py +++ b/rllib/evaluation/tests/test_rollout_worker.py @@ -113,18 +113,26 @@ def to_prev(vec): ev.stop() def test_batch_ids(self): - fragment_len = 20 + fragment_len = 100 ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), policy_spec=MockPolicy, rollout_fragment_length=fragment_len) batch1 = ev.sample() batch2 = ev.sample() - check(batch1[SampleBatch.UNROLL_ID], [0] * fragment_len) - check(batch2[SampleBatch.UNROLL_ID], [1] * fragment_len) + unroll_ids_1 = set(batch1["unroll_id"]) + unroll_ids_2 = set(batch2["unroll_id"]) + # Assert no overlap of unroll IDs between sample() calls. + self.assertTrue(not any(uid in unroll_ids_2 for uid in unroll_ids_1)) + # CartPole episodes should be short initially: Expect more than one + # unroll ID in each batch. + self.assertTrue(len(unroll_ids_1) > 1) + self.assertTrue(len(unroll_ids_2) > 1) ev.stop() def test_global_vars_update(self): + # Allow for Unittest run. 
+ ray.init(num_cpus=5, ignore_reinit_error=True) for fw in framework_iterator(frameworks=("tf2", "tf")): agent = A2CTrainer( env="CartPole-v0", From 29695e07f05a05870645e199b80bafd7676f0d7c Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 27 Aug 2021 09:40:14 +0200 Subject: [PATCH 42/45] wip --- rllib/BUILD | 4 ++-- rllib/agents/sac/tests/test_sac.py | 2 +- rllib/evaluation/collectors/simple_list_collector.py | 11 +++++------ rllib/examples/two_step_game.py | 2 +- rllib/models/tf/complex_input_net.py | 6 +++++- rllib/models/torch/complex_input_net.py | 6 +++++- 6 files changed, 19 insertions(+), 12 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index b03d0fd807be5..f4f78e5202942 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2358,7 +2358,7 @@ py_test( name = "examples/preprocessing_disabled_tf", main = "examples/preprocessing_disabled.py", tags = ["team:ml", "examples", "examples_P"], - size = "small", + size = "medium", srcs = ["examples/preprocessing_disabled.py"], args = ["--stop-iters=2"] ) @@ -2367,7 +2367,7 @@ py_test( name = "examples/preprocessing_disabled_torch", main = "examples/preprocessing_disabled.py", tags = ["team:ml", "examples", "examples_P"], - size = "small", + size = "medium", srcs = ["examples/preprocessing_disabled.py"], args = ["--framework=torch", "--stop-iters=2"] ) diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py index 9b096545e6ba0..b84f9ebef2172 100644 --- a/rllib/agents/sac/tests/test_sac.py +++ b/rllib/agents/sac/tests/test_sac.py @@ -59,7 +59,7 @@ class TestSAC(unittest.TestCase): def setUpClass(cls) -> None: np.random.seed(42) torch.manual_seed(42) - ray.init(local_mode=True)#TODO + ray.init() @classmethod def tearDownClass(cls) -> None: diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index d5eee6a559b54..2f687916a2b45 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -164,16 +164,14 @@ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ # Infos/state-outs may be structs that change from timestep to # timestep. Actions - on the other hand - are already flattened # in the sampler. - if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith("state_out_"): + if k in [SampleBatch.INFOS, SampleBatch.ACTIONS + ] or k.startswith("state_out_"): self.buffers[k][0].append(v) # Flatten all other columns. else: flattened = tree.flatten(v) for i, sub_list in enumerate(self.buffers[k]): - try:#TODO - sub_list.append(flattened[i]) - except Exception as e: - raise e + sub_list.append(flattened[i]) self.agent_steps += 1 def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: @@ -370,7 +368,8 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: # lists. These are monolithic items (infos is a dict that # should not be further split, same for state-out items, which # could be custom dicts as well). 
- if col in [SampleBatch.INFOS, SampleBatch.ACTIONS] or col.startswith("state_out_"): + if col in [SampleBatch.INFOS, SampleBatch.ACTIONS + ] or col.startswith("state_out_"): self.buffers[col] = [[data for _ in range(shift)]] else: self.buffers[col] = [[v for _ in range(shift)] diff --git a/rllib/examples/two_step_game.py b/rllib/examples/two_step_game.py index 97c17e65d415d..e3c83bbde432f 100644 --- a/rllib/examples/two_step_game.py +++ b/rllib/examples/two_step_game.py @@ -56,7 +56,7 @@ if __name__ == "__main__": args = parser.parse_args() - ray.init(num_cpus=args.num_cpus or None, local_mode=True)#TODO + ray.init(num_cpus=args.num_cpus or None) grouping = { "group_1": [0, 1], diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index 31679e44ee4da..45dc01ab1f1e8 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -133,7 +133,11 @@ def forward(self, input_dict, state, seq_lens): cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component}) outs.append(cnn_out) elif i in self.one_hot: - outs.append(one_hot(component, self.flattened_input_space[i])) + if component.dtype in [tf.int32, tf.int64, tf.uint8]: + outs.append( + one_hot(component, self.flattened_input_space[i])) + else: + outs.append(component) else: outs.append( tf.cast( diff --git a/rllib/models/torch/complex_input_net.py b/rllib/models/torch/complex_input_net.py index 00ed3c21b938c..b795e4d5485c3 100644 --- a/rllib/models/torch/complex_input_net.py +++ b/rllib/models/torch/complex_input_net.py @@ -149,7 +149,11 @@ def forward(self, input_dict, state, seq_lens): cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component}) outs.append(cnn_out) elif i in self.one_hot: - outs.append(one_hot(component, self.flattened_input_space[i])) + if component.dtype in [torch.int32, torch.int64, torch.uint8]: + outs.append( + one_hot(component, self.flattened_input_space[i])) + else: + outs.append(component) else: outs.append(torch.reshape(component, [-1, self.flatten[i]])) # Concat all outputs and the non-image inputs. From fb50d81471d7ce350744d94680361f6bfd26e0c1 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 27 Aug 2021 10:39:10 +0200 Subject: [PATCH 43/45] wip --- rllib/agents/trainer.py | 6 ++++++ rllib/models/catalog.py | 5 +++++ rllib/models/modelv2.py | 19 ++++++++++--------- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 0d67c6aac3fd8..b377d3896877d 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -1476,6 +1476,12 @@ def _validate_config(config: PartialTrainerConfigDict, config["input_evaluation"])) # Check model config. + # If no preprocessing, propagate into model's config as well + # (so model will know, whether inputs are preprocessed or not). + if config["preprocessor_pref"] is None: + model_config["_no_preprocessor"] = True + + # Prev_a/r settings. prev_a_r = model_config.get("lstm_use_prev_action_reward", DEPRECATED_VALUE) if prev_a_r != DEPRECATED_VALUE: diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index a54b73f63d6c0..d099683ef1e1e 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -44,6 +44,11 @@ # 2) fully connected and CNN default networks as well as # auto-wrapped LSTM- and attention nets. "_use_default_native_models": False, + # Experimental flag. + # If True, user specified no preprocessor to be created + # (via config.preprocessor_pref=None). If True, observations will arrive + # in model as they are returned by the env. 
+    "_no_preprocessing": False,
 
     # === Built-in options ===
     # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py
index 68ab868b74d29..fcc7ae75509ae 100644
--- a/rllib/models/modelv2.py
+++ b/rllib/models/modelv2.py
@@ -217,9 +217,17 @@ def __call__(
         else:
             restored = input_dict.copy()
 
+        # No Preprocessor used: `config.preprocessor_pref`=None.
+        # TODO: Producing the extra `obs_flat` key is unnecessary when no
+        # preprocessor is used (obs are no longer flat in that case). We
+        # keep it here for backward compatibility until Preprocessors have
+        # been fully deprecated.
+        if self.model_config.get("_no_preprocessing"):
+            restored["obs_flat"] = input_dict["obs"]
         # Input to this Model went through a Preprocessor.
-        # Generate extra keys: "obs_flat" vs "obs".
-        if hasattr(self.obs_space, "original_space"):
+        # Generate extra keys: "obs_flat" (vs "obs", which will hold the
+        # original obs).
+        else:
             restored["obs"] = restore_original_dimensions(
                 input_dict["obs"], self.obs_space, self.framework)
             try:
@@ -230,13 +238,6 @@
                 restored["obs_flat"] = input_dict["obs"]
             except AttributeError:
                 restored["obs_flat"] = input_dict["obs"]
-        # No Preprocessor used: `config.preprocessor_pref`=None.
-        # TODO: This is unnecessary for when no preprocessor is used.
-        # Obs are not flat then anymore. We keep this here for
-        # backward-compatibility until Preprocessors have been fully
-        # deprecated.
-        else:
-            restored["obs_flat"] = input_dict["obs"]
 
         with self.context():
             res = self.forward(restored, state or [], seq_lens)

From fa8f54f7b15425a78d9bc37c38319118a9267bf2 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Fri, 27 Aug 2021 14:39:25 +0200
Subject: [PATCH 44/45] wip

---
 rllib/examples/custom_metrics_and_callbacks.py | 2 +-
 rllib/execution/rollout_ops.py | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/rllib/examples/custom_metrics_and_callbacks.py b/rllib/examples/custom_metrics_and_callbacks.py
index c86033cf05ed4..67dc6b7644cef 100644
--- a/rllib/examples/custom_metrics_and_callbacks.py
+++ b/rllib/examples/custom_metrics_and_callbacks.py
@@ -56,7 +56,7 @@ def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        env_index: int, **kwargs):
         # Make sure this episode is really done.
         assert episode.batch_builder.policy_collectors[
-            "default_policy"].buffers["dones"][-1], \
+            "default_policy"].batches[-1]["dones"][-1], \
             "ERROR: `on_episode_end()` should only be called " \
             "after episode is done!"
         pole_angle = np.mean(episode.user_data["pole_angles"])
diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py
index 364a814c8c996..e408ca5416d3d 100644
--- a/rllib/execution/rollout_ops.py
+++ b/rllib/execution/rollout_ops.py
@@ -164,12 +164,11 @@ def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
         _check_sample_batch_type(batch)
         self.buffer.append(batch)
 
+        # Count by env_steps (one env step contains N agent steps).
         if self.count_steps_by == "env_steps":
             self.count += batch.count
+        # Count by individual agent steps.
         else:
-            assert isinstance(batch, MultiAgentBatch), \
-                "`count_steps_by=agent_steps` only allowed in multi-agent " \
-                "environments!"
             self.count += batch.agent_steps()
 
         if self.count >= self.min_batch_size:

From d294729a856ab05a5865fb3680f7ff99dd48b173 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Fri, 27 Aug 2021 14:40:04 +0200
Subject: [PATCH 45/45] wip

---
 rllib/execution/rollout_ops.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py
index e408ca5416d3d..364a814c8c996 100644
--- a/rllib/execution/rollout_ops.py
+++ b/rllib/execution/rollout_ops.py
@@ -164,11 +164,12 @@ def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
         _check_sample_batch_type(batch)
         self.buffer.append(batch)
 
-        # Count by env_steps (one env step contains N agent steps).
         if self.count_steps_by == "env_steps":
             self.count += batch.count
-        # Count by individual agent steps.
         else:
+            assert isinstance(batch, MultiAgentBatch), \
+                "`count_steps_by=agent_steps` only allowed in multi-agent " \
+                "environments!"
             self.count += batch.agent_steps()
 
         if self.count >= self.min_batch_size:
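As context for the `preprocessor_pref=None` feature these patches build out, here is a minimal usage sketch. It assumes only the config keys shown in the diffs above; the toy env, hyperparameters, and script layout are illustrative and not part of any commit. When `preprocessor_pref` is None, the Trainer sets the model config flag `_no_preprocessing` automatically, so the model receives the env's observations unflattened:

import gym
import numpy as np
import ray
from ray import tune


class DictObsEnv(gym.Env):
    """Toy env with a Dict observation space. With preprocessors disabled,
    RLlib's default complex-input models see this dict as-is."""

    def __init__(self, config=None):
        self.observation_space = gym.spaces.Dict({
            "pos": gym.spaces.Box(-1.0, 1.0, (2, ), np.float32),
            "task": gym.spaces.Discrete(3),
        })
        self.action_space = gym.spaces.Discrete(2)
        self.steps = 0

    def reset(self):
        self.steps = 0
        return self.observation_space.sample()

    def step(self, action):
        self.steps += 1
        # Random obs, constant reward, 10-step episodes.
        return self.observation_space.sample(), 1.0, self.steps >= 10, {}


if __name__ == "__main__":
    ray.init()
    tune.run(
        "PPO",
        stop={"training_iteration": 1},
        config={
            "env": DictObsEnv,
            # New in this patch series: skip observation preprocessing
            # entirely; the model receives raw (unflattened) observations.
            "preprocessor_pref": None,
            "framework": "tf",
            "num_workers": 0,
        },
    )
    ray.shutdown()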