Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
kurtamohler committed Feb 6, 2025
1 parent 4c06ce2 commit f22e0b3
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 61 deletions.
102 changes: 80 additions & 22 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from functools import partial
from sys import platform

import numpy as np

import pytest

import tensordict.tensordict
Expand Down Expand Up @@ -2288,7 +2290,7 @@ class TestHash(TransformBase):
def test_transform_no_env(self, datatype):
if datatype == "tensor":
obs = torch.tensor(10)
hash_fn = hash
hash_fn = lambda x: torch.tensor(hash(x))
elif datatype == "str":
obs = "abcdefg"
hash_fn = Hash.reproducible_hash
Expand All @@ -2302,6 +2304,7 @@ def test_transform_no_env(self, datatype):
)

def fn0(x):
# return tuple([tuple(Hash.reproducible_hash(x_).tolist()) for x_ in x])
return torch.stack([Hash.reproducible_hash(x_) for x_ in x])

hash_fn = fn0
Expand Down Expand Up @@ -2334,7 +2337,7 @@ def test_single_trans_env_check(self, datatype):
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
hash_fn=lambda x: torch.tensor(hash(x)),
)
base_env = CountingEnv()
elif datatype == "str":
Expand All @@ -2353,7 +2356,7 @@ def make_env():
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
hash_fn=lambda x: torch.tensor(hash(x)),
)
base_env = CountingEnv()

Expand All @@ -2376,7 +2379,7 @@ def make_env():
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
hash_fn=lambda x: torch.tensor(hash(x)),
)
base_env = CountingEnv()
elif datatype == "str":
Expand All @@ -2402,7 +2405,7 @@ def test_trans_serial_env_check(self, datatype):
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
hash_fn=lambda x: torch.tensor([hash(x[0]), hash(x[1])]),
)
base_env = CountingEnv
elif datatype == "str":
Expand All @@ -2422,7 +2425,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
hash_fn=lambda x: torch.tensor([hash(x[0]), hash(x[1])]),
)
base_env = CountingEnv
elif datatype == "str":
Expand Down Expand Up @@ -2457,7 +2460,7 @@ def test_transform_compose(self, datatype):
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
hash_fn=lambda x: torch.tensor(hash(x)),
)
t = Compose(t)
td_hashed = t(td)
Expand All @@ -2469,7 +2472,7 @@ def test_transform_model(self):
t = Hash(
in_keys=[("next", "observation"), ("observation",)],
out_keys=[("next", "hashing"), ("hashing",)],
hash_fn=hash,
hash_fn=lambda x: torch.tensor(hash(x)),
)
model = nn.Sequential(t, nn.Identity())
td = TensorDict(
Expand All @@ -2486,7 +2489,7 @@ def test_transform_env(self):
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
hash_fn=lambda x: torch.tensor(hash(x)),
)
env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
assert env.observation_spec["hashing"]
Expand All @@ -2499,7 +2502,7 @@ def test_transform_rb(self, rbclass):
t = Hash(
in_keys=[("next", "observation"), ("observation",)],
out_keys=[("next", "hashing"), ("hashing",)],
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
hash_fn=lambda x: torch.tensor([hash(x[0]), hash(x[1])]),
)
rb = rbclass(storage=LazyTensorStorage(10))
rb.append_transform(t)
Expand All @@ -2519,18 +2522,73 @@ def test_transform_rb(self, rbclass):
assert "observation" in td.keys()
assert ("next", "observation") in td.keys(True)

def test_transform_inverse(self):
return
env = CountingEnv()
with pytest.raises(TypeError):
env = env.append_transform(
Hash(
in_keys=[],
out_keys=[],
in_keys_inv=["action"],
out_keys_inv=["action_hash"],
)
)
@pytest.mark.parametrize("repertoire_gen", [lambda: None, lambda: {}])
def test_transform_inverse(self, repertoire_gen):
repertoire = repertoire_gen()
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
in_keys_inv=["observation"],
out_keys_inv=["hashing"],
repertoire=repertoire,
)
inputs = [
TensorDict({"observation": "test string"}),
TensorDict({"observation": torch.randn(10)}),
TensorDict({"observation": "another string"}),
TensorDict({"observation": torch.randn(3, 2, 1, 8)}),
]
outputs = [t(input.clone()).exclude("observation") for input in inputs]

# Run the inputs through again, just to make sure that using the same
# inputs doesn't overwrite the repertoire.
for input in inputs:
t(input.clone())

assert len(t._repertoire) == 4

inv_inputs = [t.inv(output.clone()) for output in outputs]

for input, inv_input in zip(inputs, inv_inputs):
if torch.is_tensor(input["observation"]):
assert (input["observation"] == inv_input["observation"]).all()
else:
assert input["observation"] == inv_input["observation"]

@pytest.mark.parametrize("repertoire_gen", [lambda: None, lambda: {}])
def test_repertoire(self, repertoire_gen):
repertoire = repertoire_gen()
t = Hash(in_keys=["observation"], out_keys=["hashing"], repertoire=repertoire)
inputs = [
"string",
["a", "b"],
torch.randn(3, 4, 1),
torch.randn(()),
torch.randn(0),
1234,
[1, 2, 3, 4],
]
outputs = []

for input in inputs:
td = TensorDict({"observation": input})
outputs.append(t(td.clone()).clone()["hashing"])

for output, input in zip(outputs, inputs):
if repertoire is not None:
stored_input = repertoire[t.hash_to_repertoire_key(output)]
assert stored_input is t.get_input_from_hash(output)

if torch.is_tensor(stored_input):
assert (stored_input == torch.as_tensor(input)).all()
elif isinstance(stored_input, np.ndarray):
assert (stored_input == np.asarray(input)).all()

else:
assert stored_input == input
else:
with pytest.raises(RuntimeError):
stored_input = t.get_input_from_hash(output)


@pytest.mark.skipif(
Expand Down
100 changes: 61 additions & 39 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4826,25 +4826,25 @@ class Hash(UnaryTransform):
in_keys (sequence of NestedKey): the keys of the values to hash.
out_keys (sequence of NestedKey): the keys of the resulting hashes.
in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call.
.. note:: If an inverse map is required, a repertoire ``Dict[Tuple[int], Any]`` of hash to value should be
passed alongside the list of keys to let the ``Hash`` transform know how to recover a value from a
given hash. This repertoire isn't copied, so it can be modified in the same workspace after the
transform instantiation and these modifications will be reflected in the map. Missing hashes will be
mapped to ``None``.
out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call.
Keyword Args:
hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
the hash function must accept it as its second argument. Default is
``Hash.reproducible_hash``.
hash_fn (Callable, optional): the hash function to use. The function
signature must be
``(input: Any, seed: Any | None) -> torch.Tensor``.
``seed`` is only used if this transform is initialized with the
``seed`` argument. Default is ``Hash.reproducible_hash``.
seed (optional): seed to use for the hash function, if it requires one.
use_raw_nontensor (bool, optional): if ``False``, data is extracted from
:class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
inputs are given directly to ``fn``, which must support those
inputs. Default is ``False``.
repertoire (Dict[Tuple[int], Any], optional): If given, this dict stores
the inverse mappings from hashes to inputs. This repertoire isn't
copied, so it can be modified in the same workspace after the
transform instantiation and these modifications will be reflected in
the map. Missing hashes will be mapped to ``None``. Default: ``None``
>>> from torchrl.envs import GymEnv, UnaryTransform, Hash
>>> env = GymEnv("Pendulum-v1")
Expand Down Expand Up @@ -4925,57 +4925,79 @@ def __init__(
self,
in_keys: Sequence[NestedKey],
out_keys: Sequence[NestedKey],
in_keys_inv: Sequence[NestedKey] = None,
out_keys_inv: Sequence[NestedKey] = None,
*,
hash_fn: Callable = None,
seed: Any | None = None,
use_raw_nontensor: bool = False,
repertoire: Tuple[Tuple[int], Any] = None,
):
if hash_fn is None:
hash_fn = Hash.reproducible_hash

if repertoire is None and in_keys_inv is not None and len(in_keys_inv) > 0:
self._repertoire = {}
else:
self._repertoire = repertoire

self._seed = seed
self._hash_fn = hash_fn
super().__init__(
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
fn=self.call_hash_fn,
inv_fn=self.get_input_from_hash,
use_raw_nontensor=use_raw_nontensor,
)

def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
inputs = tensordict.select(*self.in_keys_inv).detach().cpu()
tensordict = super()._inv_call(tensordict)

def register_outcome(td):
# We need to treat each hash independently
if td.ndim:
if td.ndim > 1:
td_r = td.reshape(-1)
elif td.ndim == 1:
td_r = td
result = torch.stack([register_outcome(_td) for _td in td_r.unbind(0)])
if td_r is not td:
return result.reshape(td.shape)
return result
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
inp = inputs.get(in_key)
inp = tuple(inp.tolist())
outp = self._repertoire.get(inp)
td[out_key] = outp
return td

return register_outcome(tensordict)

def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
if self.in_keys_inv is not None:
return {"_repertoire": self._repertoire}
return {}
return {"_repertoire": self._repertoire}

@classmethod
def hash_to_repertoire_key(cls, hash_tensor):
if isinstance(hash_tensor, torch.Tensor):
if hash_tensor.dim() == 0:
return hash_tensor.tolist()
return tuple(cls.hash_to_repertoire_key(t) for t in hash_tensor.tolist())
elif isinstance(hash_tensor, list):
return tuple(cls.hash_to_repertoire_key(t) for t in hash_tensor)
else:
return hash_tensor

def get_input_from_hash(self, hash_tensor):
"""Look up the input that was given for a particular hash output.
This feature is only available if, during initialization, either the
:arg:`repertoire` argument was given or both the :arg:`in_keys_inv` and
:arg:`out_keys_inv` arguments were given.
Args:
hash_tensor (Tensor): The hash output.
Returns:
Any: The input that the hash was generated from.
"""
if self._repertoire is None:
raise RuntimeError(
"An inverse transform was queried but the repertoire is None."
)
return self._repertoire[self.hash_to_repertoire_key(hash_tensor)]

def call_hash_fn(self, value):
if self._seed is None:
return self._hash_fn(value)
hash_tensor = self._hash_fn(value)
else:
return self._hash_fn(value, self._seed)
hash_tensor = self._hash_fn(value, self._seed)
if not torch.is_tensor(hash_tensor):
raise ValueError(
f"Hash function must return a tensor, but got {type(hash_tensor)}"
)
if self._repertoire is not None:
self._repertoire[self.hash_to_repertoire_key(hash_tensor)] = copy(value)
return hash_tensor

@classmethod
def reproducible_hash(cls, string, seed=None):
Expand Down

0 comments on commit f22e0b3

Please sign in to comment.