Skip to content

Commit

Permalink
Modify rendering to default to the dm-control values (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Jun 15, 2023
1 parent 464d6be commit 69bcfbf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 43 deletions.
50 changes: 15 additions & 35 deletions shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@ def __init__(
self,
env: composer.Environment | control.Environment | dm_env.Environment,
render_mode: str | None = None,
render_height: int = 84,
render_width: int = 84,
camera_id: int | str = 0,
render_scene_callback: (Callable[[MujocoEnginePhysics, MjvScene], None])
| None = None,
render_kwargs: dict[str, Any] | None = None,
):
"""Initialises the environment with a render mode along with render information.
Expand All @@ -75,20 +70,12 @@ def __init__(
Args:
env (Optional[composer.Environment | control.Environment | dm_env.Environment]): DM Control env to wrap
render_mode (Optional[str]): rendering mode (options: "human", "rgb_array", "depth_array", "multi_camera")
render_height (Optional[int]): height for rendering frame in pixels
render_width (Optional[int]): width for rendering frame in pixels
camera_id (Optional[int | str]): Optional camera name or index. Defaults to -1, the free
camera, which is always defined. A non-negative integer or string
corresponds to a fixed camera, which must be defined in the model XML.
If `camera_id` is a string then the camera must also be named.
render_scene_callback (Optional[(Callable[[MujocoEnginePhysics, mujoco.MjvScene], None])]): Called after
the scene has been created and before it is rendered. Can be used to add more geoms to the scene.
render_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for rendering. Note: kwargs are not used
for human rendering, which uses simpler Gymnasium MuJoCo rendering.
render_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for rendering.
For the width, height and camera id use "width", "height" and "camera_id" respectively.
See the dm_control implementation for the list of possible kwargs, https://github.com/deepmind/dm_control/blob/330c91f41a21eacadcf8316f0a071327e3f5c017/dm_control/mujoco/engine.py#L178
Note: kwargs are not used for human rendering, which uses simpler Gymnasium MuJoCo rendering.
"""
EzPickle.__init__(
self, env, render_mode, render_height, render_width, camera_id
)
EzPickle.__init__(self, env, render_mode, render_kwargs)
self._env: Any = env
self.env_type = self._find_env_type(env)
self.metadata["render_fps"] = self._env.control_timestep() * 1000
Expand All @@ -98,9 +85,6 @@ def __init__(

assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.render_height, self.render_width = render_height, render_width
self.camera_id = camera_id
self.render_scene_callback = render_scene_callback

if render_kwargs is None:
render_kwargs = {}
Expand Down Expand Up @@ -153,45 +137,41 @@ def render(self) -> np.ndarray | None:
"""Renders the dm-control env."""
if self.render_mode == "rgb_array":
return self._env.physics.render(
height=self.render_height,
width=self.render_width,
camera_id=self.camera_id,
scene_callback=self.render_scene_callback,
**self.render_kwargs,
)
elif self.render_mode == "depth_array":
return self._env.physics.render(
height=self.render_height,
width=self.render_width,
camera_id=self.camera_id,
depth=True,
scene_callback=self.render_scene_callback,
**self.render_kwargs,
)
elif self.render_mode == "multi_camera":
physics = self._env.physics
num_cameras = physics.model.ncam
num_columns = int(math.ceil(math.sqrt(num_cameras)))
num_rows = int(math.ceil(float(num_cameras) / num_columns))

# 240 and 320 are the default values in dm-control
height = self.render_kwargs.get("height", 240)
width = self.render_kwargs.get("width", 320)
frame = np.zeros(
(num_rows * self.render_height, num_columns * self.render_width, 3),
(num_rows * height, num_columns * width, 3),
dtype=np.uint8,
)
assert (
"camera_id" not in self.render_kwargs
), "The camera_id is specified in `multi_camera` render so don't include it in the render_kwargs"
for col in range(num_columns):
for row in range(num_rows):
camera_id = row * num_columns + col
if camera_id >= num_cameras:
break
subframe = physics.render(
height=self.render_height,
width=self.render_width,
camera_id=camera_id,
scene_callback=self.render_scene_callback,
**self.render_kwargs,
)
frame[
row * self.render_height : (row + 1) * self.render_height,
col * self.render_width : (col + 1) * self.render_width,
row * height : (row + 1) * height,
col * width : (col + 1) * width,
] = subframe
return frame

Expand Down
22 changes: 14 additions & 8 deletions tests/test_dm_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ def test_pickle(env_id):
env_2.close()


@pytest.mark.parametrize("camera_id", [0, 1])
@pytest.mark.parametrize("camera_id", [-1, 0, 1])
def test_rendering_camera_id(camera_id):
"""Test that dm-control rendering works."""
env = gym.make(
DM_CONTROL_ENV_IDS[0],
render_mode="rgb_array",
camera_id=camera_id,
render_kwargs=dict(camera_id=camera_id),
)
env.reset()
frames = []
Expand All @@ -157,8 +157,10 @@ def test_rendering_multiple_cameras(height, width):
env = gym.make(
DM_CONTROL_ENV_IDS[0],
render_mode="multi_camera",
render_height=height,
render_width=width,
render_kwargs=dict(
height=height,
width=width,
),
)
env.reset()
frames = []
Expand All @@ -175,8 +177,10 @@ def test_rendering_depth(height, width):
env = gym.make(
DM_CONTROL_ENV_IDS[0],
render_mode="depth_array",
render_height=height,
render_width=width,
render_kwargs=dict(
height=height,
width=width,
),
)
env.reset()
frames = []
Expand All @@ -193,8 +197,10 @@ def test_render_height_widths(height, width):
env = gym.make(
DM_CONTROL_ENV_IDS[0],
render_mode="rgb_array",
render_height=height,
render_width=width,
render_kwargs=dict(
height=height,
width=width,
),
)
env.reset()
frame = env.render()
Expand Down

0 comments on commit 69bcfbf

Please sign in to comment.