Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix for Solara deepcopy bug #2460

Merged
merged 17 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mesa/examples/basic/boltzmann_wealth_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
"max": 100,
"step": 1,
},
"seed": {
"type": "InputText",
"value": 42,
"label": "Random Seed",
},
"width": 10,
"height": 10,
}
Expand All @@ -30,7 +35,7 @@


# Create initial model instance
model1 = BoltzmannWealthModel(50, 10, 10)
model = BoltzmannWealthModel(50, 10, 10)

Check warning on line 38 in mesa/examples/basic/boltzmann_wealth_model/app.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/boltzmann_wealth_model/app.py#L38

Added line #L38 was not covered by tests

# Create visualization elements. The visualization elements are solara components
# that receive the model instance as a "prop" and display it in a certain way.
Expand All @@ -49,7 +54,7 @@
# solara run app.py
# It will automatically update and display any changes made to this file
page = SolaraViz(
model1,
model,
components=[SpaceGraph, GiniPlot],
model_params=model_params,
name="Boltzmann Wealth Model",
Expand Down
93 changes: 52 additions & 41 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from __future__ import annotations

import asyncio
import copy
import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal
Expand All @@ -48,7 +47,6 @@ def SolaraViz(
| Literal["default"] = "default",
play_interval: int = 100,
model_params=None,
seed: float = 0,
name: str | None = None,
):
"""Solara visualization component.
Expand All @@ -69,8 +67,6 @@ def SolaraViz(
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
model_params (dict, optional): Parameters for (re-)instantiating a model.
Can include user-adjustable parameters and fixed parameters. Defaults to None.
seed (int, optional): Seed for the random number generator. This ensures reproducibility
of the model's behavior. Defaults to 0.
name (str | None, optional): Name of the visualization. Defaults to the models class name.

Returns:
Expand All @@ -88,7 +84,9 @@ def SolaraViz(
value results in faster stepping, while a higher value results in slower stepping.
"""
if components == "default":
components = [components_altair.make_space_altair()]
components = [components_altair.make_altair_space()]
if model_params is None:
model_params = {}

# Convert model to reactive
if not isinstance(model, solara.Reactive):
Expand All @@ -109,20 +107,23 @@ def step():

solara.use_effect(connect_to_model, [model.value])

# set up reactive model_parameters shared by ModelCreator and ModelController
reactive_model_parameters = solara.use_reactive({})

with solara.AppBar():
solara.AppBarTitle(name if name else model.value.__class__.__name__)

with solara.Sidebar(), solara.Column():
with solara.Card("Controls"):
ModelController(model, play_interval)

if model_params is not None:
with solara.Card("Model Parameters"):
ModelCreator(
model,
model_params,
seed=seed,
)
ModelController(
model,
model_parameters=reactive_model_parameters,
play_interval=play_interval,
)
with solara.Card("Model Parameters"):
ModelCreator(
model, model_params, model_parameters=reactive_model_parameters
)
with solara.Card("Information"):
ShowSteps(model.value)

Expand Down Expand Up @@ -173,24 +174,23 @@ def ComponentsView(


@solara.component
def ModelController(model: solara.Reactive[Model], play_interval=100):
def ModelController(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work! Maybe more of a clarification question: I previously suggested adding a * here to have keyword-only arguments. Maybe I am wrong here and thats why you ignored it, but my understanding was that this is how it works

def foo(bar, baz=2):
    ...
# all of this works:
foo(1)
foo(3, baz=4)
foo(5, 6)
# Compared to
def foo2(bar, *, baz):
    pass

foo2(1) # works
foo2(3, baz=4) # works
foo2(3, 4) # doesnt work

My thinking behind this was to have API stability. If we later change the signature to

def ModelController(model, awesome_new_arg, model_parameters=None)

we don't have a breaking change. right now it would be breaking if model_parameters is only provided as a positional argument.

Does that make sense? Is that how it works?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering what you ment, but this makes good sense. I'll update it accordingly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, feel free to merge afterwards

model: solara.Reactive[Model],
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int = 100,
):
"""Create controls for model execution (step, play, pause, reset).

Args:
model (solara.Reactive[Model]): Reactive model instance
play_interval (int, optional): Interval for playing the model steps in milliseconds.
model: Reactive model instance
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.

"""
playing = solara.use_reactive(False)
running = solara.use_reactive(True)
original_model = solara.use_reactive(None)

def save_initial_model():
"""Save the initial model for comparison."""
original_model.set(copy.deepcopy(model.value))
playing.value = False
force_update()

solara.use_effect(save_initial_model, [model.value])
if model_parameters is None:
model_parameters = solara.use_reactive({})

async def step():
while playing.value and running.value:
Expand All @@ -210,7 +210,7 @@ def do_reset():
"""Reset the model to its initial state."""
playing.value = False
running.value = True
model.value = copy.deepcopy(original_model.value)
model.value = model.value = model.value.__class__(**model_parameters.value)

def do_play_pause():
"""Toggle play/pause."""
Expand Down Expand Up @@ -269,17 +269,21 @@ def check_param_is_fixed(param):


@solara.component
def ModelCreator(model, model_params, seed=1):
def ModelCreator(
model: solara.Reactive[Model],
user_params: dict,
model_parameters: dict | solara.Reactive[dict] = None,
):
"""Solara component for creating and managing a model instance with user-defined parameters.

This component allows users to create a model instance with specified parameters and seed.
It provides an interface for adjusting model parameters and reseeding the model's random
number generator.

Args:
model (solara.Reactive[Model]): A reactive model instance. This is the main model to be created and managed.
model_params (dict): Dictionary of model parameters. This includes both user-adjustable parameters and fixed parameters.
seed (int, optional): Initial seed for the random number generator. Defaults to 1.
model: A reactive model instance. This is the main model to be created and managed.
user_params: Parameters for (re-)instantiating a model. Can include user-adjustable parameters and fixed parameters. Defaults to None.
model_parameters: reactive parameters for reinitializing the model

Returns:
solara.component: A Solara component that renders the model creation and management interface.
Expand All @@ -300,24 +304,25 @@ def ModelCreator(model, model_params, seed=1):
- The component provides an interface for adjusting user-defined parameters and reseeding the model.

"""
if model_parameters is None:
model_parameters = solara.use_reactive({})

solara.use_effect(
lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
[model.value],
)
user_params, fixed_params = split_model_params(user_params)

user_params, fixed_params = split_model_params(model_params)

model_parameters, set_model_parameters = solara.use_state(
{
**fixed_params,
**{k: v.get("value") for k, v in user_params.items()},
}
)
# set model_parameters to the default values for all parameters
model_parameters.value = {
**fixed_params,
**{k: v.get("value") for k, v in user_params.items()},
}

def on_change(name, value):
new_model_parameters = {**model_parameters, name: value}
new_model_parameters = {**model_parameters.value, name: value}
model.value = model.value.__class__(**new_model_parameters)
set_model_parameters(new_model_parameters)
model_parameters.value = new_model_parameters

UserInputs(user_params, on_change=on_change)

Expand Down Expand Up @@ -409,6 +414,12 @@ def change_handler(value, name=name):
on_value=change_handler,
value=options.get("value"),
)
elif input_type == "InputText":
solara.InputText(
label=label,
on_value=change_handler,
value=options.get("value"),
)
else:
raise ValueError(f"{input_type} is not a supported input type")

Expand Down
11 changes: 9 additions & 2 deletions tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def test_call_space_drawer(mocker): # noqa: D103
mesa.visualization.components.altair_components, "SpaceAltair"
)

model = mesa.Model()
class MockModel(mesa.Model):
def __init__(self, seed=None):
super().__init__(seed=seed)

model = MockModel()
mocker.patch.object(mesa.Model, "__init__", return_value=None)

agent_portrayal = {
Expand All @@ -112,7 +116,10 @@ def test_call_space_drawer(mocker): # noqa: D103
# initialize with space drawer unspecified (use default)
# component must be rendered for code to run
solara.render(
SolaraViz(model, components=[make_mpl_space_component(agent_portrayal)])
SolaraViz(
model,
components=[make_mpl_space_component(agent_portrayal)],
)
)
# should call default method with class instance and agent portrayal
mock_space_matplotlib.assert_called_with(
Expand Down
Loading