From 945197b895f88e475c072d45848622fdad6d0e1f Mon Sep 17 00:00:00 2001
From: Kurt Mohler
Date: Wed, 5 Feb 2025 16:36:49 -0800
Subject: [PATCH] Update

[ghstack-poisoned]
---
Notes (ignored by `git am`):
- Adds an `include_hash_inv` keyword to `ChessEnv`. When it is passed together
  with `include_hash=True`, `_ChessMeta.__call__` builds the `Hash` transform
  with inverse key lists mirroring the forward ones ("san"/"fen"/"pgn" paired
  with "san_hash"/"fen_hash"/"pgn_hash"); otherwise the inverse key lists are
  passed as `None`.
- Passing `include_hash_inv=True` without `include_hash=True` raises
  `ValueError`. The message is assembled from two adjacent string literals, so
  the first literal must end with a space to keep "if" and "'include_hash=True'"
  apart.
- Both flags are read from `kwargs` only, so they are honoured only when given
  as keyword arguments.
- `test_env_hash_inv` steps an env using only the hash keys (with "fen"/"pgn"
  excluded from the input) and compares the resulting next state against a
  plain env that is given the explicit "fen"/"pgn" string.

 test/test_env.py             | 68 ++++++++++++++++++++++++++++++++++++
 torchrl/envs/custom/chess.py | 40 +++++++++++++++------
 2 files changed, 97 insertions(+), 11 deletions(-)

diff --git a/test/test_env.py b/test/test_env.py
index f45aa7c4668..dc28a809115 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -3709,6 +3709,74 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
         if include_san:
             assert "san_hash" in env.observation_spec.keys()
 
+    # Test that `include_hash_inv=True` allows us to specify the board state
+    # with just the "fen_hash" or "pgn_hash", not "fen" or "pgn", when taking a
+    # step in the env.
+    @pytest.mark.parametrize(
+        "include_fen,include_pgn",
+        [[True, False], [False, True]],
+    )
+    @pytest.mark.parametrize("stateful", [True, False])
+    def test_env_hash_inv(self, include_fen, include_pgn, stateful):
+        env = ChessEnv(
+            include_fen=include_fen,
+            include_pgn=include_pgn,
+            include_hash=True,
+            include_hash_inv=True,
+            stateful=stateful,
+        )
+        env.check_env_specs()
+
+        def exclude_fen_and_pgn(td):
+            td = td.exclude("fen")
+            td = td.exclude("pgn")
+            return td
+
+        td0 = env.reset()
+
+        if include_fen:
+            env_check_fen = ChessEnv(
+                include_fen=True,
+                stateful=stateful,
+            )
+
+        if include_pgn:
+            env_check_pgn = ChessEnv(
+                include_pgn=True,
+                stateful=stateful,
+            )
+
+        for _ in range(8):
+            td1 = env.rand_step(exclude_fen_and_pgn(td0.clone()))
+
+            # Confirm that fen/pgn was not used to determine the board state
+            assert "fen" not in td1.keys()
+            assert "pgn" not in td1.keys()
+
+            if include_fen:
+                assert (td1["fen_hash"] == td0["fen_hash"]).all()
+                assert "fen" in td1["next"]
+
+                # Check that if we start in the same board state and perform the
+                # same action in an env that does not use hashes, we obtain the
+                # same next board state. This confirms that we really can
+                # successfully specify the board state with a hash.
+                td0_check = td1.clone().exclude("next").update({"fen": td0["fen"]})
+                assert (
+                    env_check_fen.step(td0_check)["next", "fen"] == td1["next", "fen"]
+                )
+
+            if include_pgn:
+                assert (td1["pgn_hash"] == td0["pgn_hash"]).all()
+                assert "pgn" in td1["next"]
+
+                td0_check = td1.clone().exclude("next").update({"pgn": td0["pgn"]})
+                assert (
+                    env_check_pgn.step(td0_check)["next", "pgn"] == td1["next", "pgn"]
+                )
+
+            td0 = td1["next"]
+
     @pytest.mark.skipif(not _has_tv, reason="torchvision not found.")
     @pytest.mark.skipif(not _has_cairosvg, reason="cairosvg not found.")
     @pytest.mark.parametrize("stateful", [False, True])
diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py
index d5b744cfc84..b949af6ca0a 100644
--- a/torchrl/envs/custom/chess.py
+++ b/torchrl/envs/custom/chess.py
@@ -22,21 +22,38 @@ class _ChessMeta(_EnvPostInit):
 
     def __call__(cls, *args, **kwargs):
         instance = super().__call__(*args, **kwargs)
-        if kwargs.get("include_hash"):
+        include_hash = kwargs.get("include_hash")
+        include_hash_inv = kwargs.get("include_hash_inv")
+        if include_hash:
             from torchrl.envs import Hash
 
             in_keys = []
             out_keys = []
-            if instance.include_san:
-                in_keys.append("san")
-                out_keys.append("san_hash")
-            if instance.include_fen:
-                in_keys.append("fen")
-                out_keys.append("fen_hash")
-            if instance.include_pgn:
-                in_keys.append("pgn")
-                out_keys.append("pgn_hash")
-            instance = instance.append_transform(Hash(in_keys, out_keys))
+            in_keys_inv = [] if include_hash_inv else None
+            out_keys_inv = [] if include_hash_inv else None
+
+            def maybe_add_keys(condition, in_key, out_key):
+                if condition:
+                    in_keys.append(in_key)
+                    out_keys.append(out_key)
+                    if include_hash_inv:
+                        in_keys_inv.append(in_key)
+                        out_keys_inv.append(out_key)
+
+            maybe_add_keys(instance.include_san, "san", "san_hash")
+            maybe_add_keys(instance.include_fen, "fen", "fen_hash")
+            maybe_add_keys(instance.include_pgn, "pgn", "pgn_hash")
+
+            instance = instance.append_transform(
+                Hash(in_keys, out_keys, in_keys_inv, out_keys_inv)
+            )
+        elif include_hash_inv:
+            raise ValueError(
+                (
+                    "'include_hash_inv=True' can only be set if "
+                    f"'include_hash=True', but got 'include_hash={include_hash}'."
+                )
+            )
         if kwargs.get("mask_actions", True):
             from torchrl.envs import ActionMask
 
@@ -265,6 +282,7 @@ def __init__(
         include_pgn: bool = False,
         include_legal_moves: bool = False,
         include_hash: bool = False,
+        include_hash_inv: bool = False,
         mask_actions: bool = True,
         pixels: bool = False,
     ):