Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 27, 2025
2 parents 4102861 + 4791529 commit 4c6f563
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .github/unittest/linux_libs/scripts_chess/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ dependencies:
- scipy
- hydra-core
- chess
- transformers
- cairosvg
8 changes: 4 additions & 4 deletions .github/unittest/linux_libs/scripts_chess/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ git submodule sync && git submodule update --init --recursive
printf "Installing PyTorch with cu121"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
else
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
else
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
fi
else
printf "Failed to install pytorch"
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux_libs/scripts_minari/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ dependencies:
- pyyaml
- scipy
- hydra-core
- minari[gcs,hdf5]
- minari[gcs,hdf5,hf]
- gymnasium<1.0.0
12 changes: 11 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@
mp_ctx = "fork"

_has_chess = importlib.util.find_spec("chess") is not None

_has_tv = importlib.util.find_spec("torchvision") is not None
_has_cairosvg = importlib.util.find_spec("cairosvg") is not None
## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between
## the serial and parallel batched envs
# def _make_atari_env(atari_env):
Expand Down Expand Up @@ -3471,6 +3472,15 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
if include_san:
assert "san_hash" in env.observation_spec.keys()

@pytest.mark.skipif(not _has_tv, reason="torchvision not found.")
@pytest.mark.skipif(not _has_cairosvg, reason="cairosvg not found.")
@pytest.mark.parametrize("stateful", [False, True])
def test_chess_rendering(self, stateful):
env = ChessEnv(stateful=stateful, include_fen=True, pixels=True)
env.check_env_specs()
r = env.rollout(3)
assert "pixels" in r

def test_pgn_bijectivity(self):
np.random.seed(0)
pgn = ChessEnv._PGN_RESTART
Expand Down
17 changes: 11 additions & 6 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import importlib.util
import urllib.error

from gym.core import ObsType

_has_isaac = importlib.util.find_spec("isaacgym") is not None

if _has_isaac:
Expand All @@ -25,7 +23,7 @@
from contextlib import nullcontext
from pathlib import Path
from sys import platform
from typing import Optional, Tuple, Union
from typing import Optional, Union
from unittest import mock

import numpy as np
Expand Down Expand Up @@ -638,7 +636,8 @@ def test_torchrl_to_gym(self, backend, numpy):

@implement_for("gym", None, "0.26")
def test_gym_dict_action_space(self):
pytest.skip("tested for gym > 0.26 - no backward issue")
torchrl_logger.info("tested for gym > 0.26 - no backward issue")
return

@implement_for("gym", "0.26", None)
def test_gym_dict_action_space(self): # noqa: F811
Expand All @@ -653,14 +652,17 @@ def __init__(self):
self.observation_space = gym.spaces.Box(-1, 1)

def step(self, action):
assert isinstance(action, dict)
assert isinstance(action["a0"], np.ndarray)
assert isinstance(action["a1"], np.ndarray)
return (0.5, 0.0, False, False, {})

def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Tuple[ObsType, dict]:
):
return (0.0, {})

env = CompositeActionEnv()
Expand All @@ -686,14 +688,17 @@ def __init__(self):
self.observation_space = gym.spaces.Box(-1, 1)

def step(self, action):
assert isinstance(action, dict)
assert isinstance(action["a0"], np.ndarray)
assert isinstance(action["a1"], np.ndarray)
return (0.5, 0.0, False, False, {})

def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Tuple[ObsType, dict]:
):
return (0.0, {})

env = CompositeActionEnv()
Expand Down
9 changes: 6 additions & 3 deletions torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Dict

import torch
from PIL import Image
from tensordict import TensorDict, TensorDictBase
from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded

Expand Down Expand Up @@ -315,7 +314,9 @@ def __init__(
raise ImportError(
"Please install torchvision to use this environment with pixel rendering."
)
self.full_observation_spec["pixels"] = Unbounded(shape=())
self.full_observation_spec["pixels"] = Unbounded(
shape=(3, 390, 390), dtype=torch.uint8
)

self.full_action_spec = Composite(
action=Categorical(n=len(self.san_moves), shape=(), dtype=torch.int64)
Expand Down Expand Up @@ -428,6 +429,8 @@ def _torchvision(cls):
@classmethod
def _get_tensor_image(cls, board):
try:
from PIL import Image

svg = board._repr_svg_()
# Convert SVG to PNG using cairosvg
png_data = io.BytesIO()
Expand All @@ -438,7 +441,7 @@ def _get_tensor_image(cls, board):
img = cls._torchvision.transforms.functional.pil_to_tensor(img)
except ImportError:
raise ImportError(
"Chess rendering requires cairosvg and torchvision to be installed."
"Chess rendering requires cairosvg, PIL and torchvision to be installed."
)
return img

Expand Down

0 comments on commit 4c6f563

Please sign in to comment.