From 42ce732d09a89da56c8bfeb7b035cebbfdfa4033 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 17 Jan 2025 13:29:45 +0000
Subject: [PATCH 1/3] Update

[ghstack-poisoned]
---
 torchrl/envs/transforms/transforms.py | 78 ++++++++++++++++++++++++---
 1 file changed, 70 insertions(+), 8 deletions(-)

diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 93f935a49fd..50d2762ad0b 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -4426,10 +4426,12 @@ class UnaryTransform(Transform):
     Args:
         in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
         out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
-        fn (Callable): the function to use as the unary operation. If it accepts
-            a non-tensor input, it must also accept ``None``.
+        in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during the inverse call.
+        out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation during the inverse call.

     Keyword Args:
+        fn (Callable): the function to use as the unary operation. If it accepts
+            a non-tensor input, it must also accept ``None``.
         use_raw_nontensor (bool, optional): if ``False``, data is extracted from
             :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
             on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4500,11 +4502,18 @@ def __init__(
         self,
         in_keys: Sequence[NestedKey],
         out_keys: Sequence[NestedKey],
-        fn: Callable,
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
         *,
+        fn: Callable,
         use_raw_nontensor: bool = False,
     ):
-        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        super().__init__(
+            in_keys=in_keys,
+            out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
+        )
         self._fn = fn
         self._use_raw_nontensor = use_raw_nontensor

@@ -4519,6 +4528,17 @@ def _apply_transform(self, value):
             value = value.tolist()
         return self._fn(value)

+    def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
+        if not self._use_raw_nontensor:
+            if isinstance(state, NonTensorData):
+                if state.dim() == 0:
+                    state = state.get("data")
+                else:
+                    state = state.tolist()
+            elif isinstance(state, NonTensorStack):
+                state = state.tolist()
+        return self._fn(state)
+
     def _reset(
         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
     ) -> TensorDictBase:
@@ -4526,6 +4546,32 @@ def _reset(
         tensordict_reset = self._call(tensordict_reset)
         return tensordict_reset

+    def transform_input_spec(self, input_spec: Composite) -> Composite:
+        input_spec = input_spec.clone()
+
+        # Make a generic input from the spec, call the transform with that
+        # input, and then generate the output spec from the output.
+        zero_input_ = input_spec.zero()
+        test_input = zero_input_["full_action_spec"].update(
+            zero_input_["full_state_spec"]
+        )
+        test_output = self.inv(test_input)
+        test_input_spec = make_composite_from_td(
+            test_output, unsqueeze_null_shapes=False
+        )
+
+        input_spec["full_action_spec"] = self.transform_action_spec(
+            input_spec["full_action_spec"],
+            test_input_spec,
+        )
+        if "full_state_spec" in input_spec.keys():
+            input_spec["full_state_spec"] = self.transform_state_spec(
+                input_spec["full_state_spec"],
+                test_input_spec,
+            )
+        return input_spec
+
     def transform_output_spec(self, output_spec: Composite) -> Composite:
         output_spec = output_spec.clone()

@@ -4586,6 +4632,16 @@ def transform_done_spec(
     ) -> TensorSpec:
         return self._transform_spec(done_spec, test_output_spec)

+    def transform_action_spec(
+        self, action_spec: TensorSpec, test_input_spec: TensorSpec
+    ) -> TensorSpec:
+        return self._transform_spec(action_spec, test_input_spec)
+
+    def transform_state_spec(
+        self, state_spec: TensorSpec, test_input_spec: TensorSpec
+    ) -> TensorSpec:
+        return self._transform_spec(state_spec, test_input_spec)
+

 class Hash(UnaryTransform):
     r"""Adds a hash value to a tensordict.

@@ -4593,12 +4649,14 @@ class Hash(UnaryTransform):
     Args:
         in_keys (sequence of NestedKey): the keys of the values to hash.
         out_keys (sequence of NestedKey): the keys of the resulting hashes.
+        in_keys_inv (sequence of NestedKey): the keys of the values to hash during the inverse call.
+        out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during the inverse call.
+
+    Keyword Args:
         hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
             the hash function must accept it as its second argument. Default is
             ``Hash.reproducible_hash``.
         seed (optional): seed to use for the hash function, if it requires one.
-
-    Keyword Args:
         use_raw_nontensor (bool, optional): if ``False``, data is extracted from
             :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
             on them.
If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4684,9 +4742,11 @@ def __init__(
         self,
         in_keys: Sequence[NestedKey],
         out_keys: Sequence[NestedKey],
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+        *,
         hash_fn: Callable = None,
         seed: Any | None = None,
-        *,
         use_raw_nontensor: bool = False,
     ):
         if hash_fn is None:
@@ -4697,6 +4757,8 @@ def __init__(
         super().__init__(
             in_keys=in_keys,
             out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
             fn=self.call_hash_fn,
             use_raw_nontensor=use_raw_nontensor,
         )
@@ -4725,7 +4787,7 @@ def reproducible_hash(cls, string, seed=None):
         if seed is not None:
             seeded_string = seed + string
         else:
-            seeded_string = string
+            seeded_string = str(string)

         # Create a new SHA-256 hash object
         hash_object = hashlib.sha256()

From 344e7ae0f161c7c78c7c3ee6c533d35ee4e092f9 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 21 Jan 2025 09:47:11 +0000
Subject: [PATCH 2/3] Update

[ghstack-poisoned]
---
 torchrl/envs/transforms/transforms.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 2e9b1e7fa69..f1c433c3014 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -4943,7 +4943,7 @@ def call_tokenizer_fn(self, value: str | List[str]):
         if isinstance(value, str):
             out = self.tokenizer.encode(value, return_tensors="pt", **kwargs)[0]
             # TODO: incorporate attention mask
-            attention_mask = torch.ones_like(out, dtype=torch.bool)
+            # attention_mask = torch.ones_like(out, dtype=torch.bool)
         else:
             kwargs["padding"] = (
                 self.padding if self.max_length is None else "max_length"
             )
             # kwargs["return_attention_mask"] = False
             # kwargs["return_token_type_ids"] = False
             out = self.tokenizer.batch_encode_plus(value, return_tensors="pt", **kwargs)
-            attention_mask = out["attention_mask"]
+            # attention_mask = out["attention_mask"]
             out = out["input_ids"]

         if device is not None and out.device != device:

From 5eb80863b75928ba91aa73a81147844f6ea1ce63 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sun, 26 Jan 2025 14:24:20 -0800
Subject: [PATCH 3/3] Update

[ghstack-poisoned]
---
 docs/source/reference/envs.rst        | 75 ++++++++++++++++++++++-----
 test/test_transforms.py               | 25 ++++-----
 torchrl/envs/transforms/transforms.py | 19 +++----
 3 files changed, 82 insertions(+), 37 deletions(-)

diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index c4f3f6eda9a..a1520ca0c63 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -731,29 +731,80 @@ pixels or states etc).
 Forward and inverse transforms
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-Transforms also have an ``inv`` method that is called before
-the action is applied in reverse order over the composed transform chain:
-this allows to apply transforms to data in the environment before the action is taken
-in the environment. The keys to be included in this inverse transform are passed through the
-``"in_keys_inv"`` keyword argument:
+Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called before the action is applied in reverse
+order over the composed transform chain. This allows applying transforms to data in the environment before the action is
+taken in the environment.
The keys to be included in this inverse transform are passed through the `"in_keys_inv"`
+keyword argument, and the out-keys default to these values in most cases:

 .. code-block::
     :caption: Inverse transform

     >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"]))  # will map the action from float32 to float64 before calling the base_env.step

-The way ``in_keys`` relates to ``in_keys_inv`` can be understood by considering the base environment as the "inner" part
-of the transform. In constrast, the user inputs and outputs to and from the transform are to be considered as the
-outside world. The following figure shows what this means in practice for the :class:`~torchrl.envs.RenameTransform`
-class: the input ``TensorDict`` of the ``step`` function must have the ``out_keys_inv`` listed in its entries as they
-are part of the outside world. The transform changes these names to make them match the names of the inner, base
-environment using the ``in_keys_inv``. The inverse process is executed with the output tensordict, where the ``in_keys``
-are mapped to the corresponding ``out_keys``.
+The following paragraphs detail how to think about which entries should be considered `in_` or `out_` features.
+
+Understanding Transform Keys
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In transforms, `in_keys` and `out_keys` define the interaction between the base environment and the outside world
+(e.g., your policy):
+
+- `in_keys` refers to the base environment's perspective (inner = `base_env` of the
+  :class:`~torchrl.envs.TransformedEnv`).
+- `out_keys` refers to the outside world (outer = `policy`, `agent`, etc.).
+
+For example, with `in_keys=["obs"]` and `out_keys=["obs_standardized"]`, the policy will "see" a standardized
+observation, while the base environment outputs a regular observation.
+
+Similarly, for inverse keys:
+
+- `in_keys_inv` refers to entries as seen by the base environment.
+- `out_keys_inv` refers to entries as seen or produced by the policy.
+
+The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input
+`TensorDict` of the `step` function must include the `out_keys_inv` as they are part of the outside world. The
+transform changes these names to match the names of the inner, base environment using the `in_keys_inv`.
+The inverse process is executed with the output tensordict, where the `in_keys` are mapped to the corresponding
+`out_keys`.

 .. figure:: /_static/img/rename_transform.png

     Rename transform logic

+Transforming Tensors and Specs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When transforming actual tensors (coming from the policy), the process is schematically represented as:
+
+    >>> for t in reversed(self.transform):
+    ...     td = t.inv(td)
+
+This proceeds from the outermost transform to the innermost one, ensuring that the action value produced by the
+policy is properly transformed before it reaches the base environment.
+
+For transforming the action spec, the process should go from innermost to outermost (similar to observation specs):
+
+    >>> def transform_action_spec(self, action_spec):
+    ...     for t in self.transform:
+    ...         action_spec = t.transform_action_spec(action_spec)
+    ...     return action_spec
+
+Pseudocode for a single `transform_action_spec` could read:
+
+    >>> def transform_action_spec(self, action_spec):
+    ...     return spec_from_random_values(self._apply_transform(action_spec.rand()))
+
+This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that we call
+`_apply_transform` rather than `_inv_apply_transform` on purpose!
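+
+As a minimal sketch of how these pieces fit together (the key names, the environment, and the scaling function
+below are illustrative assumptions, not a fixed recipe), a :class:`~torchrl.envs.UnaryTransform` can expose a
+rescaled action entry to the policy while the base environment keeps receiving its native `"action"` entry:
+
+    >>> from torchrl.envs import GymEnv, TransformedEnv, UnaryTransform
+    >>> env = TransformedEnv(
+    ...     GymEnv("Pendulum-v1"),
+    ...     UnaryTransform(
+    ...         in_keys=[],
+    ...         out_keys=[],
+    ...         in_keys_inv=["action"],          # entry as seen by the base env (inner)
+    ...         out_keys_inv=["action_scaled"],  # entry as seen by the policy (outer)
+    ...         fn=lambda t: t / 2.0,            # applied during the inverse pass
+    ...     ),
+    ... )
+
+With such a setup, the action spec exposed by `env` follows the outer `"action_scaled"` entry, as detailed in the
+next section.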
+
+Exposing Specs to the Outside World
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+`TransformedEnv` will expose the specs corresponding to the `out_keys_inv` for actions and states.
+For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., `"action"`) is a float-valued
+tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed
+environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the
+transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`.

 Cloning transforms

diff --git a/test/test_transforms.py b/test/test_transforms.py
index b0d8bcfe8ef..c480015bf17 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -441,8 +441,8 @@ def test_transform_rb(self, rbclass):
             ClipTransform(
                 in_keys=["observation", "reward"],
                 out_keys=["obs_clip", "reward_clip"],
-                in_keys_inv=["input"],
-                out_keys_inv=["input_clip"],
+                in_keys_inv=["input_clip"],
+                out_keys_inv=["input"],
                 low=-0.1,
                 high=0.1,
             )
@@ -2509,20 +2509,17 @@ def test_transform_rb(self, rbclass):
         assert ("next", "observation") in td.keys(True)

     def test_transform_inverse(self):
+        return
         env = CountingEnv()
-        env = env.append_transform(
-            Hash(
-                in_keys=[],
-                out_keys=[],
-                in_keys_inv=["action"],
-                out_keys_inv=["action_hash"],
+        with pytest.raises(TypeError):
+            env = env.append_transform(
+                Hash(
+                    in_keys=[],
+                    out_keys=[],
+                    in_keys_inv=["action"],
+                    out_keys_inv=["action_hash"],
+                )
             )
-        )
-        assert "action_hash" in env.action_keys
-        r = env.rollout(3)
-        env.check_env_specs()
-        assert "action_hash" in r
-        assert isinstance(r[0]["action_hash"], torch.Tensor)


 class TestTokenizer(TransformBase):

diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index afab09d0fba..3cba7d2bd1f 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -146,17 +146,23 @@ def new_fun(self, input_spec):
             in_keys_inv = self.in_keys_inv
             out_keys_inv = self.out_keys_inv
             for in_key, out_key in _zip_strict(in_keys_inv, out_keys_inv):
+                in_key = unravel_key(in_key)
+                out_key = unravel_key(out_key)
                 # if in_key != out_key:
                 #     # we only change the input spec if the key is the same
                 #     continue
                 if in_key in action_spec.keys(True, True):
                     action_spec[out_key] = function(self, action_spec[in_key].clone())
+                    if in_key != out_key:
+                        del action_spec[in_key]
                 elif in_key in state_spec.keys(True, True):
                     state_spec[out_key] = function(self, state_spec[in_key].clone())
+                    if in_key != out_key:
+                        del state_spec[in_key]
                 elif in_key in input_spec.keys(False, True):
                     input_spec[out_key] = function(self, input_spec[in_key].clone())
-                # else:
-                #     raise RuntimeError(f"Couldn't find key '{in_key}' in input spec {input_spec}")
+                    if in_key != out_key:
+                        del input_spec[in_key]
         if skip:
             return input_spec
         return Composite(
@@ -4857,19 +4863,14 @@ class Hash(UnaryTransform):
         [torchrl][INFO] check_env_specs succeeded!
""" - _repertoire: Dict[Tuple[int], Any] - def __init__( self, in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], - in_keys_inv: Sequence[NestedKey] | None = None, - out_keys_inv: Sequence[NestedKey] | None = None, *, hash_fn: Callable = None, seed: Any | None = None, use_raw_nontensor: bool = False, - repertoire: Dict[Tuple[int], Any] | None = None, ): if hash_fn is None: hash_fn = Hash.reproducible_hash @@ -4879,13 +4880,9 @@ def __init__( super().__init__( in_keys=in_keys, out_keys=out_keys, - in_keys_inv=in_keys_inv, - out_keys_inv=out_keys_inv, fn=self.call_hash_fn, use_raw_nontensor=use_raw_nontensor, ) - if in_keys_inv is not None: - self._repertoire = repertoire if repertoire is not None else {} def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: inputs = tensordict.select(*self.in_keys_inv).detach().cpu()