Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify rendering to default to the dm-control values #92

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 15 additions & 35 deletions shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@ def __init__(
self,
env: composer.Environment | control.Environment | dm_env.Environment,
render_mode: str | None = None,
render_height: int = 84,
render_width: int = 84,
camera_id: int | str = 0,
render_scene_callback: (Callable[[MujocoEnginePhysics, MjvScene], None])
| None = None,
render_kwargs: dict[str, Any] | None = None,
):
"""Initialises the environment with a render mode along with render information.
Expand All @@ -75,20 +70,12 @@ def __init__(
Args:
env (Optional[composer.Environment | control.Environment | dm_env.Environment]): DM Control env to wrap
render_mode (Optional[str]): rendering mode (options: "human", "rgb_array", "depth_array", "multi_camera")
render_height (Optional[int]): height for rendering frame in pixels
render_width (Optional[int]): width for rendering frame in pixels
camera_id (Optional[int | str]): Optional camera name or index. Defaults to -1, the free
camera, which is always defined. A non-negative integer or string
corresponds to a fixed camera, which must be defined in the model XML.
If `camera_id` is a string then the camera must also be named.
render_scene_callback (Optional[(Callable[[MujocoEnginePhysics, mujoco.MjvScene], None])]): Called after
the scene has been created and before it is rendered. Can be used to add more geoms to the scene.
render_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for rendering. Note: kwargs are not used
for human rendering, which uses simpler Gymnasium MuJoCo rendering.
render_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for rendering.
For the width, height and camera id use "width", "height" and "camera_id" respectively.
See the dm_control implementation for the list of possible kwargs, https://github.com/deepmind/dm_control/blob/330c91f41a21eacadcf8316f0a071327e3f5c017/dm_control/mujoco/engine.py#L178
Note: kwargs are not used for human rendering, which uses simpler Gymnasium MuJoCo rendering.
"""
EzPickle.__init__(
self, env, render_mode, render_height, render_width, camera_id
)
EzPickle.__init__(self, env, render_mode, render_kwargs)
self._env: Any = env
self.env_type = self._find_env_type(env)
self.metadata["render_fps"] = self._env.control_timestep() * 1000
Expand All @@ -98,9 +85,6 @@ def __init__(

assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.render_height, self.render_width = render_height, render_width
self.camera_id = camera_id
self.render_scene_callback = render_scene_callback

if render_kwargs is None:
render_kwargs = {}
Expand Down Expand Up @@ -153,45 +137,41 @@ def render(self) -> np.ndarray | None:
"""Renders the dm-control env."""
if self.render_mode == "rgb_array":
return self._env.physics.render(
height=self.render_height,
width=self.render_width,
camera_id=self.camera_id,
scene_callback=self.render_scene_callback,
**self.render_kwargs,
)
elif self.render_mode == "depth_array":
return self._env.physics.render(
height=self.render_height,
width=self.render_width,
camera_id=self.camera_id,
depth=True,
scene_callback=self.render_scene_callback,
**self.render_kwargs,
)
elif self.render_mode == "multi_camera":
physics = self._env.physics
num_cameras = physics.model.ncam
num_columns = int(math.ceil(math.sqrt(num_cameras)))
num_rows = int(math.ceil(float(num_cameras) / num_columns))

# 240 and 320 are the default values in dm-control
height = self.render_kwargs.get("height", 240)
width = self.render_kwargs.get("width", 320)
frame = np.zeros(
(num_rows * self.render_height, num_columns * self.render_width, 3),
(num_rows * height, num_columns * width, 3),
dtype=np.uint8,
)
assert (
"camera_id" not in self.render_kwargs
), "The camera_id is specified in `multi_camera` render so don't include it in the render_kwargs"
for col in range(num_columns):
for row in range(num_rows):
camera_id = row * num_columns + col
if camera_id >= num_cameras:
break
subframe = physics.render(
height=self.render_height,
width=self.render_width,
camera_id=camera_id,
scene_callback=self.render_scene_callback,
**self.render_kwargs,
)
frame[
row * self.render_height : (row + 1) * self.render_height,
col * self.render_width : (col + 1) * self.render_width,
row * height : (row + 1) * height,
col * width : (col + 1) * width,
] = subframe
return frame

Expand Down
22 changes: 14 additions & 8 deletions tests/test_dm_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ def test_pickle(env_id):
env_2.close()


@pytest.mark.parametrize("camera_id", [0, 1])
@pytest.mark.parametrize("camera_id", [-1, 0, 1])
def test_rendering_camera_id(camera_id):
"""Test that dm-control rendering works."""
env = gym.make(
DM_CONTROL_ENV_IDS[0],
render_mode="rgb_array",
camera_id=camera_id,
render_kwargs=dict(camera_id=camera_id),
)
env.reset()
frames = []
Expand All @@ -157,8 +157,10 @@ def test_rendering_multiple_cameras(height, width):
env = gym.make(
DM_CONTROL_ENV_IDS[0],
render_mode="multi_camera",
render_height=height,
render_width=width,
render_kwargs=dict(
height=height,
width=width,
),
)
env.reset()
frames = []
Expand All @@ -175,8 +177,10 @@ def test_rendering_depth(height, width):
env = gym.make(
DM_CONTROL_ENV_IDS[0],
render_mode="depth_array",
render_height=height,
render_width=width,
render_kwargs=dict(
height=height,
width=width,
),
)
env.reset()
frames = []
Expand All @@ -193,8 +197,10 @@ def test_render_height_widths(height, width):
env = gym.make(
DM_CONTROL_ENV_IDS[0],
render_mode="rgb_array",
render_height=height,
render_width=width,
render_kwargs=dict(
height=height,
width=width,
),
)
env.reset()
frame = env.render()
Expand Down