Skip to content

Commit

Permalink
[BugFix] Apply inverse transform to input of TransformedEnv._reset
Browse files Browse the repository at this point in the history
ghstack-source-id: 5f7c1fbd19b716f2b1602c34cf2ae1362f7bc7f6
Pull Request resolved: #2787
  • Loading branch information
kurtamohler committed Feb 13, 2025
1 parent ab76027 commit 1ed5d29
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
29 changes: 29 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4058,6 +4058,35 @@ def test_chess_tokenized(self):
assert "fen" in ftd["next"]
env.check_env_specs()

@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.parametrize("include_san", [False, True])
def test_env_reset_with_hash(self, stateful, include_san):
    """Reset a ``ChessEnv`` from a FEN string, then again from its FEN hash.

    For every position, the reset driven only by the ``"fen_hash"`` entry
    must reproduce the tensordict obtained from the reset driven by the
    ``"fen"`` entry.
    """
    chess_env = ChessEnv(
        stateful=stateful,
        include_san=include_san,
        include_fen=True,
        include_hash=True,
        include_hash_inv=True,
    )
    # Maps each FEN to the number of legal moves expected in that position.
    expected_moves_by_fen = {
        "5R1k/8/8/8/6R1/8/8/5K2 b - - 0 1": 1,
        "8/8/2kq4/4K3/1R3Q2/8/8/8 w - - 0 1": 2,
        "6R1/8/8/4rq2/3pPk2/5n2/8/2B1R2K b - e3 0 1": 2,
    }
    for fen, num_legal_moves in expected_moves_by_fen.items():
        # First restore the position from its FEN string.
        td = chess_env.reset(TensorDict({"fen": fen}))
        assert td["fen"] == fen
        legal_move_count = td["action_mask"].sum()
        assert legal_move_count == num_legal_moves

        # Go back to the starting position in between, so that the
        # hash-based reset below has to actually change the state.
        td_initial = chess_env.reset()
        assert td_initial["action_mask"].sum() == 20

        # Now restore the same position from the hash alone; the result
        # has to match what the FEN-based reset returned.
        hash_only = td.select("fen_hash")
        td_check = chess_env.reset(hash_only)
        assert (td_check == td).all()


class TestCustomEnvs:
def test_tictactoe_env(self):
Expand Down
4 changes: 4 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,10 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
tensordict = tensordict.select(
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
)
# Inputs might be transformed, so need to apply inverse transform
# before passing to the env reset function.
with _set_missing_tolerance(self.transform, True):
tensordict = self.transform.inv(tensordict)
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
if tensordict is None:
# make sure all transforms see a source tensordict
Expand Down

0 comments on commit 1ed5d29

Please sign in to comment.