From f22e0b396c8606f40d02704b6ee6dafbea59b827 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 5 Feb 2025 16:36:49 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- test/test_transforms.py | 102 ++++++++++++++++++++------ torchrl/envs/transforms/transforms.py | 100 +++++++++++++++---------- 2 files changed, 141 insertions(+), 61 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 038f284ed19..ff910434c5d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -16,6 +16,8 @@ from functools import partial from sys import platform +import numpy as np + import pytest import tensordict.tensordict @@ -2288,7 +2290,7 @@ class TestHash(TransformBase): def test_transform_no_env(self, datatype): if datatype == "tensor": obs = torch.tensor(10) - hash_fn = hash + hash_fn = lambda x: torch.tensor(hash(x)) elif datatype == "str": obs = "abcdefg" hash_fn = Hash.reproducible_hash @@ -2302,6 +2304,7 @@ def test_transform_no_env(self, datatype): ) def fn0(x): + # return tuple([tuple(Hash.reproducible_hash(x_).tolist()) for x_ in x]) return torch.stack([Hash.reproducible_hash(x_) for x_ in x]) hash_fn = fn0 @@ -2334,7 +2337,7 @@ def test_single_trans_env_check(self, datatype): t = Hash( in_keys=["observation"], out_keys=["hashing"], - hash_fn=hash, + hash_fn=lambda x: torch.tensor(hash(x)), ) base_env = CountingEnv() elif datatype == "str": @@ -2353,7 +2356,7 @@ def make_env(): t = Hash( in_keys=["observation"], out_keys=["hashing"], - hash_fn=hash, + hash_fn=lambda x: torch.tensor(hash(x)), ) base_env = CountingEnv() @@ -2376,7 +2379,7 @@ def make_env(): t = Hash( in_keys=["observation"], out_keys=["hashing"], - hash_fn=hash, + hash_fn=lambda x: torch.tensor(hash(x)), ) base_env = CountingEnv() elif datatype == "str": @@ -2402,7 +2405,7 @@ def test_trans_serial_env_check(self, datatype): t = Hash( in_keys=["observation"], out_keys=["hashing"], - hash_fn=lambda x: [hash(x[0]), hash(x[1])], + hash_fn=lambda x: torch.tensor([hash(x[0]), hash(x[1])]), ) base_env = CountingEnv elif datatype == "str": @@ -2422,7 +2425,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype): t = Hash( in_keys=["observation"], out_keys=["hashing"], - hash_fn=lambda x: [hash(x[0]), hash(x[1])], + hash_fn=lambda x: torch.tensor([hash(x[0]), hash(x[1])]), ) base_env = CountingEnv elif datatype == "str": @@ -2457,7 +2460,7 @@ def test_transform_compose(self, datatype): t = Hash( in_keys=["observation"], out_keys=["hashing"], - hash_fn=hash, + hash_fn=lambda x: torch.tensor(hash(x)), ) t = Compose(t) td_hashed = t(td) @@ -2469,7 +2472,7 @@ def test_transform_model(self): t = Hash( in_keys=[("next", "observation"), ("observation",)], out_keys=[("next", "hashing"), ("hashing",)], - hash_fn=hash, + hash_fn=lambda x: torch.tensor(hash(x)), ) model = nn.Sequential(t, nn.Identity()) td = TensorDict( @@ -2486,7 +2489,7 @@ def test_transform_env(self): t = Hash( in_keys=["observation"], out_keys=["hashing"], - hash_fn=hash, + hash_fn=lambda x: torch.tensor(hash(x)), ) env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t) assert env.observation_spec["hashing"] @@ -2499,7 +2502,7 @@ def test_transform_rb(self, rbclass): t = Hash( in_keys=[("next", "observation"), ("observation",)], out_keys=[("next", "hashing"), ("hashing",)], - hash_fn=lambda x: [hash(x[0]), hash(x[1])], + hash_fn=lambda x: torch.tensor([hash(x[0]), hash(x[1])]), ) rb = rbclass(storage=LazyTensorStorage(10)) rb.append_transform(t) @@ -2519,18 +2522,73 @@ def test_transform_rb(self, rbclass): assert "observation" in td.keys() assert ("next", "observation") in td.keys(True) - def test_transform_inverse(self): - return - env = CountingEnv() - with pytest.raises(TypeError): - env = env.append_transform( - Hash( - in_keys=[], - out_keys=[], - in_keys_inv=["action"], - out_keys_inv=["action_hash"], - ) - ) + @pytest.mark.parametrize("repertoire_gen", [lambda: None, lambda: {}]) + def test_transform_inverse(self, repertoire_gen): + repertoire = repertoire_gen() + t = Hash( + in_keys=["observation"], + out_keys=["hashing"], + in_keys_inv=["observation"], + out_keys_inv=["hashing"], + repertoire=repertoire, + ) + inputs = [ + TensorDict({"observation": "test string"}), + TensorDict({"observation": torch.randn(10)}), + TensorDict({"observation": "another string"}), + TensorDict({"observation": torch.randn(3, 2, 1, 8)}), + ] + outputs = [t(input.clone()).exclude("observation") for input in inputs] + + # Run the inputs through again, just to make sure that using the same + # inputs doesn't overwrite the repertoire. + for input in inputs: + t(input.clone()) + + assert len(t._repertoire) == 4 + + inv_inputs = [t.inv(output.clone()) for output in outputs] + + for input, inv_input in zip(inputs, inv_inputs): + if torch.is_tensor(input["observation"]): + assert (input["observation"] == inv_input["observation"]).all() + else: + assert input["observation"] == inv_input["observation"] + + @pytest.mark.parametrize("repertoire_gen", [lambda: None, lambda: {}]) + def test_repertoire(self, repertoire_gen): + repertoire = repertoire_gen() + t = Hash(in_keys=["observation"], out_keys=["hashing"], repertoire=repertoire) + inputs = [ + "string", + ["a", "b"], + torch.randn(3, 4, 1), + torch.randn(()), + torch.randn(0), + 1234, + [1, 2, 3, 4], + ] + outputs = [] + + for input in inputs: + td = TensorDict({"observation": input}) + outputs.append(t(td.clone()).clone()["hashing"]) + + for output, input in zip(outputs, inputs): + if repertoire is not None: + stored_input = repertoire[t.hash_to_repertoire_key(output)] + assert stored_input is t.get_input_from_hash(output) + + if torch.is_tensor(stored_input): + assert (stored_input == torch.as_tensor(input)).all() + elif isinstance(stored_input, np.ndarray): + assert (stored_input == np.asarray(input)).all() + + else: + assert stored_input == input + else: + with pytest.raises(RuntimeError): + stored_input = t.get_input_from_hash(output) @pytest.mark.skipif( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 9d36248f624..883c3030ba5 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4826,25 +4826,25 @@ class Hash(UnaryTransform): in_keys (sequence of NestedKey): the keys of the values to hash. out_keys (sequence of NestedKey): the keys of the resulting hashes. in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call. - - .. note:: If an inverse map is required, a repertoire ``Dict[Tuple[int], Any]`` of hash to value should be - passed alongside the list of keys to let the ``Hash`` transform know how to recover a value from a - given hash. This repertoire isn't copied, so it can be modified in the same workspace after the - transform instantiation and these modifications will be reflected in the map. Missing hashes will be - mapped to ``None``. - out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call. Keyword Args: - hash_fn (Callable, optional): the hash function to use. If ``seed`` is given, - the hash function must accept it as its second argument. Default is - ``Hash.reproducible_hash``. + hash_fn (Callable, optional): the hash function to use. The function + signature must be + ``(input: Any, seed: Any | None) -> torch.Tensor``. + ``seed`` is only used if this transform is initialized with the + ``seed`` argument. Default is ``Hash.reproducible_hash``. seed (optional): seed to use for the hash function, if it requires one. use_raw_nontensor (bool, optional): if ``False``, data is extracted from :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs are given directly to ``fn``, which must support those inputs. Default is ``False``. + repertoire (Dict[Tuple[int], Any], optional): If given, this dict stores + the inverse mappings from hashes to inputs. This repertoire isn't + copied, so it can be modified in the same workspace after the + transform instantiation and these modifications will be reflected in + the map. Missing hashes will be mapped to ``None``. Default: ``None`` >>> from torchrl.envs import GymEnv, UnaryTransform, Hash >>> env = GymEnv("Pendulum-v1") @@ -4925,57 +4925,79 @@ def __init__( self, in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], + in_keys_inv: Sequence[NestedKey] = None, + out_keys_inv: Sequence[NestedKey] = None, *, hash_fn: Callable = None, seed: Any | None = None, use_raw_nontensor: bool = False, + repertoire: Tuple[Tuple[int], Any] = None, ): if hash_fn is None: hash_fn = Hash.reproducible_hash + if repertoire is None and in_keys_inv is not None and len(in_keys_inv) > 0: + self._repertoire = {} + else: + self._repertoire = repertoire + self._seed = seed self._hash_fn = hash_fn super().__init__( in_keys=in_keys, out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, fn=self.call_hash_fn, + inv_fn=self.get_input_from_hash, use_raw_nontensor=use_raw_nontensor, ) - def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: - inputs = tensordict.select(*self.in_keys_inv).detach().cpu() - tensordict = super()._inv_call(tensordict) - - def register_outcome(td): - # We need to treat each hash independently - if td.ndim: - if td.ndim > 1: - td_r = td.reshape(-1) - elif td.ndim == 1: - td_r = td - result = torch.stack([register_outcome(_td) for _td in td_r.unbind(0)]) - if td_r is not td: - return result.reshape(td.shape) - return result - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - inp = inputs.get(in_key) - inp = tuple(inp.tolist()) - outp = self._repertoire.get(inp) - td[out_key] = outp - return td - - return register_outcome(tensordict) - def state_dict(self, *args, destination=None, prefix="", keep_vars=False): - if self.in_keys_inv is not None: - return {"_repertoire": self._repertoire} - return {} + return {"_repertoire": self._repertoire} + + @classmethod + def hash_to_repertoire_key(cls, hash_tensor): + if isinstance(hash_tensor, torch.Tensor): + if hash_tensor.dim() == 0: + return hash_tensor.tolist() + return tuple(cls.hash_to_repertoire_key(t) for t in hash_tensor.tolist()) + elif isinstance(hash_tensor, list): + return tuple(cls.hash_to_repertoire_key(t) for t in hash_tensor) + else: + return hash_tensor + + def get_input_from_hash(self, hash_tensor): + """Look up the input that was given for a particular hash output. + + This feature is only available if, during initialization, either the + :arg:`repertoire` argument was given or both the :arg:`in_keys_inv` and + :arg:`out_keys_inv` arguments were given. + + Args: + hash_tensor (Tensor): The hash output. + + Returns: + Any: The input that the hash was generated from. + """ + if self._repertoire is None: + raise RuntimeError( + "An inverse transform was queried but the repertoire is None." + ) + return self._repertoire[self.hash_to_repertoire_key(hash_tensor)] def call_hash_fn(self, value): if self._seed is None: - return self._hash_fn(value) + hash_tensor = self._hash_fn(value) else: - return self._hash_fn(value, self._seed) + hash_tensor = self._hash_fn(value, self._seed) + if not torch.is_tensor(hash_tensor): + raise ValueError( + f"Hash function must return a tensor, but got {type(hash_tensor)}" + ) + if self._repertoire is not None: + self._repertoire[self.hash_to_repertoire_key(hash_tensor)] = copy(value) + return hash_tensor @classmethod def reproducible_hash(cls, string, seed=None):