Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
kurtamohler committed Feb 6, 2025
1 parent f22e0b3 commit 945197b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 11 deletions.
68 changes: 68 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3709,6 +3709,74 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
if include_san:
assert "san_hash" in env.observation_spec.keys()

# Test that `include_hash_inv=True` allows us to specify the board state
# with just the "fen_hash" or "pgn_hash", not "fen" or "pgn", when taking a
# step in the env.
@pytest.mark.parametrize(
"include_fen,include_pgn",
[[True, False], [False, True]],
)
@pytest.mark.parametrize("stateful", [True, False])
def test_env_hash_inv(self, include_fen, include_pgn, stateful):
env = ChessEnv(
include_fen=include_fen,
include_pgn=include_pgn,
include_hash=True,
include_hash_inv=True,
stateful=stateful,
)
env.check_env_specs()

def exclude_fen_and_pgn(td):
td = td.exclude("fen")
td = td.exclude("pgn")
return td

td0 = env.reset()

if include_fen:
env_check_fen = ChessEnv(
include_fen=True,
stateful=stateful,
)

if include_pgn:
env_check_pgn = ChessEnv(
include_pgn=True,
stateful=stateful,
)

for _ in range(8):
td1 = env.rand_step(exclude_fen_and_pgn(td0.clone()))

# Confirm that fen/pgn was not used to determine the board state
assert "fen" not in td1.keys()
assert "pgn" not in td1.keys()

if include_fen:
assert (td1["fen_hash"] == td0["fen_hash"]).all()
assert "fen" in td1["next"]

# Check that if we start in the same board state and perform the
# same action in an env that does not use hashes, we obtain the
# same next board state. This confirms that we really can
# successfully specify the board state with a hash.
td0_check = td1.clone().exclude("next").update({"fen": td0["fen"]})
assert (
env_check_fen.step(td0_check)["next", "fen"] == td1["next", "fen"]
)

if include_pgn:
assert (td1["pgn_hash"] == td0["pgn_hash"]).all()
assert "pgn" in td1["next"]

td0_check = td1.clone().exclude("next").update({"pgn": td0["pgn"]})
assert (
env_check_pgn.step(td0_check)["next", "pgn"] == td1["next", "pgn"]
)

td0 = td1["next"]

@pytest.mark.skipif(not _has_tv, reason="torchvision not found.")
@pytest.mark.skipif(not _has_cairosvg, reason="cairosvg not found.")
@pytest.mark.parametrize("stateful", [False, True])
Expand Down
40 changes: 29 additions & 11 deletions torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,38 @@
class _ChessMeta(_EnvPostInit):
def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
if kwargs.get("include_hash"):
include_hash = kwargs.get("include_hash")
include_hash_inv = kwargs.get("include_hash_inv")
if include_hash:
from torchrl.envs import Hash

in_keys = []
out_keys = []
if instance.include_san:
in_keys.append("san")
out_keys.append("san_hash")
if instance.include_fen:
in_keys.append("fen")
out_keys.append("fen_hash")
if instance.include_pgn:
in_keys.append("pgn")
out_keys.append("pgn_hash")
instance = instance.append_transform(Hash(in_keys, out_keys))
in_keys_inv = [] if include_hash_inv else None
out_keys_inv = [] if include_hash_inv else None

def maybe_add_keys(condition, in_key, out_key):
if condition:
in_keys.append(in_key)
out_keys.append(out_key)
if include_hash_inv:
in_keys_inv.append(in_key)
out_keys_inv.append(out_key)

maybe_add_keys(instance.include_san, "san", "san_hash")
maybe_add_keys(instance.include_fen, "fen", "fen_hash")
maybe_add_keys(instance.include_pgn, "pgn", "pgn_hash")

instance = instance.append_transform(
Hash(in_keys, out_keys, in_keys_inv, out_keys_inv)
)
elif include_hash_inv:
raise ValueError(
(
"'include_hash_inv=True' can only be set if"
f"'include_hash=True', but got 'include_hash={include_hash}'."
)
)
if kwargs.get("mask_actions", True):
from torchrl.envs import ActionMask

Expand Down Expand Up @@ -265,6 +282,7 @@ def __init__(
include_pgn: bool = False,
include_legal_moves: bool = False,
include_hash: bool = False,
include_hash_inv: bool = False,
mask_actions: bool = True,
pixels: bool = False,
):
Expand Down

0 comments on commit 945197b

Please sign in to comment.