Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(snake): define viewer outside the env class #134

Merged
merged 3 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 21 additions & 147 deletions jumanji/environments/routing/snake/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple

import chex
import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.animation
import matplotlib.artist
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, Rectangle

import jumanji
import jumanji.environments
from jumanji import specs
from jumanji.env import Environment
from jumanji.environments.routing.snake.types import Observation, Position, State
from jumanji.environments.routing.snake.viewer import SnakeViewer
from jumanji.types import TimeStep, restart, termination, transition
from jumanji.viewer import Viewer


class Snake(Environment[State]):
Expand Down Expand Up @@ -92,28 +90,30 @@ class Snake(Environment[State]):
```
"""

FIGURE_NAME = "Snake"
FIGURE_SIZE = (6.0, 6.0)
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], jnp.int32)

def __init__(self, num_rows: int = 12, num_cols: int = 12, time_limit: int = 4000):
def __init__(
self,
num_rows: int = 12,
num_cols: int = 12,
time_limit: int = 4000,
viewer: Optional[Viewer[State]] = None,
):
"""Instantiates a `Snake` environment.

Args:
num_rows: number of rows of the 2D grid. Defaults to 12.
num_cols: number of columns of the 2D grid. Defaults to 12.
time_limit: time_limit of an episode, i.e. number of environment steps before
the episode ends. Defaults to 4000.
viewer: `Viewer` used for rendering. Defaults to `SnakeViewer`.
"""
super().__init__()
self.num_rows = num_rows
self.num_cols = num_cols
self.board_shape = (num_rows, num_cols)
self.time_limit = time_limit

# You must store the created Animation in a variable that lives as long as the animation
# should run. Otherwise, the animation will get garbage-collected.
self._animation: Optional[matplotlib.animation.Animation] = None
self._viewer = viewer or SnakeViewer()

def __repr__(self) -> str:
return "\n".join(
Expand Down Expand Up @@ -380,20 +380,8 @@ def render(self, state: State) -> None:

Args:
state: State object containing the current dynamics of the environment.

"""
self._clear_display()
fig, ax = self._get_fig_ax()
self._draw(ax, state)
self._update_display(fig)

def close(self) -> None:
"""Perform any necessary cleanup.

Environments will automatically :meth:`close()` themselves when
garbage collected or when the program exits.
"""
plt.close(self.FIGURE_NAME)
self._viewer.render(state)

def animate(
self,
Expand All @@ -412,126 +400,12 @@ def animate(
Returns:
Animation object that can be saved as a GIF, MP4, or rendered with HTML.
"""
fig, ax = plt.subplots(num=f"{self.FIGURE_NAME}Anim", figsize=self.FIGURE_SIZE)
self._draw_board(ax)
plt.close(fig)

patches: List[matplotlib.patches.Patch] = []

def make_frame(state: State) -> Any:
while patches:
patches.pop().remove()
patches.extend(self._create_entities(state))
for patch in patches:
ax.add_patch(patch)

# Create the animation object.
matplotlib.rc("animation", html="jshtml")
self._animation = matplotlib.animation.FuncAnimation(
fig,
make_frame,
frames=states,
interval=interval,
)
return self._viewer.animate(states, interval, save_path)

# Save the animation as a gif.
if save_path:
self._animation.save(save_path)

return self._animation

def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
exists = plt.fignum_exists(self.FIGURE_NAME)
if exists:
fig = plt.figure(self.FIGURE_NAME)
ax = fig.get_axes()[0]
else:
fig = plt.figure(self.FIGURE_NAME, figsize=self.FIGURE_SIZE)
fig.set_tight_layout({"pad": False, "w_pad": 0.0, "h_pad": 0.0})
if not plt.isinteractive():
fig.show()
ax = fig.add_subplot()
return fig, ax

def _draw(self, ax: plt.Axes, state: State) -> None:
ax.clear()
self._draw_board(ax)
for patch in self._create_entities(state):
ax.add_patch(patch)

def _draw_board(self, ax: plt.Axes) -> None:
# Draw the square box that delimits the board.
ax.axis("off")
ax.plot([0, 0], [0, self.num_rows], "-k", lw=2)
ax.plot([0, self.num_cols], [self.num_rows, self.num_rows], "-k", lw=2)
ax.plot([self.num_cols, self.num_cols], [self.num_rows, 0], "-k", lw=2)
ax.plot([self.num_cols, 0], [0, 0], "-k", lw=2)

def _create_entities(self, state: State) -> List[matplotlib.patches.Patch]:
"""Loop over the different cells and draws corresponding shapes in the ax object."""
patches = []
linewidth = (
min(
n * size
for n, size in zip((self.num_rows, self.num_cols), self.FIGURE_SIZE)
)
/ 44.0
)
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
"", ["yellowgreen", "forestgreen"]
)
for row in range(self.num_rows):
for col in range(self.num_cols):
if state.body_state[row, col]:
body_cell_patch = Rectangle(
(col, self.num_rows - 1 - row),
1,
1,
edgecolor=cmap(1),
facecolor=cmap(state.body_state[row, col] / state.length),
fill=True,
lw=linewidth,
)
patches.append(body_cell_patch)
head_patch = Circle(
(
state.head_position[1] + 0.5,
self.num_rows - 1 - state.head_position[0] + 0.5,
),
0.3,
edgecolor=cmap(0.5),
facecolor=cmap(0),
fill=True,
lw=linewidth,
)
patches.append(head_patch)
fruit_patch = Circle(
(
state.fruit_position[1] + 0.5,
self.num_rows - 1 - state.fruit_position[0] + 0.5,
),
0.2,
edgecolor="brown",
facecolor="lightcoral",
fill=True,
lw=linewidth,
)
patches.append(fruit_patch)
return patches

def _update_display(self, fig: plt.Figure) -> None:
if plt.isinteractive():
# Required to update render when using Jupyter Notebook.
fig.canvas.draw()
if jumanji.environments.is_colab():
plt.show(self.FIGURE_NAME)
else:
# Required to update render when not using Jupyter Notebook.
fig.canvas.draw_idle()
fig.canvas.flush_events()

def _clear_display(self) -> None:
if jumanji.environments.is_colab():
import IPython.display

IPython.display.clear_output(True)
def close(self) -> None:
"""Perform any necessary cleanup.

Environments will automatically :meth:`close()` themselves when
garbage collected or when the program exits.
"""
self._viewer.close()
Loading