Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for drawing discrete grids #2386

Merged
merged 8 commits into from
Oct 21, 2024
63 changes: 52 additions & 11 deletions mesa/visualization/components/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from matplotlib.figure import Figure

import mesa
from mesa.experimental.cell_space import VoronoiGrid
from mesa.experimental.cell_space import OrthogonalMooreGrid, VoronoiGrid
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also work for OrthogonalVonNeumannGrid

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I'll fix that! Just need to figure out how to capture this in the match case statement.

I guess it even will work for hexgrids (although we might want to visualize these a bit differently with offsets for every odd row.

from mesa.space import PropertyLayer
from mesa.visualization.utils import update_counter

Expand Down Expand Up @@ -52,16 +52,19 @@ def SpaceMatplotlib(
if space is None:
space = getattr(model, "space", None)

if isinstance(space, mesa.space._Grid):
_draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @quaquel, I understand that this is already merged. May I ask whether it is intentional to remove the call to _draw_grid() here? Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I made a mistake. I'll put in a PR to fix this asap.

Thanks for the post-merge review!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and fixed via #2398

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Thanks!

elif isinstance(space, mesa.space.ContinuousSpace):
_draw_continuous_space(space, space_ax, agent_portrayal, model)
elif isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)
elif isinstance(space, VoronoiGrid):
_draw_voronoi(space, space_ax, agent_portrayal)
elif space is None and propertylayer_portrayal:
draw_property_layers(space_ax, space, propertylayer_portrayal, model)
# https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching
match space:
case mesa.space._Grid():
_draw_continuous_space(space, space_ax, agent_portrayal, model)
case mesa.space.NetworkGrid():
_draw_network_grid(space, space_ax, agent_portrayal)
case VoronoiGrid():
_draw_voronoi(space, space_ax, agent_portrayal)
case OrthogonalMooreGrid():
_draw_discrete_space_grid(space, space_ax, agent_portrayal)
case None:
if propertylayer_portrayal:
draw_property_layers(space_ax, space, propertylayer_portrayal, model)

solara.FigureMatplotlib(
space_fig, format="png", bbox_inches="tight", dependencies=dependencies
Expand Down Expand Up @@ -291,6 +294,44 @@ def portray(g):
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black


def _draw_discrete_space_grid(space: OrthogonalMooreGrid, space_ax, agent_portrayal):
if space._ndims != 2:
raise ValueError("Space must be 2D")

def portray(g):
x = []
y = []
s = [] # size
c = [] # color

for cell in g.all_cells:
for agent in cell.agents:
data = agent_portrayal(agent)
x.append(cell.coordinate[0])
y.append(cell.coordinate[1])
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
out["s"] = s
if len(c) > 0:
out["c"] = c

return out

space_ax.set_xlim(0, space.width)
space_ax.set_ylim(0, space.height)

# Draw grid lines
for x in range(space.width + 1):
space_ax.axvline(x, color="gray", linestyle=":")
for y in range(space.height + 1):
space_ax.axhline(y, color="gray", linestyle=":")

space_ax.scatter(**portray(space))


def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]):
"""Create a plotting function for a specified measure.

Expand Down
Loading