diff --git a/mesa/agent.py b/mesa/agent.py index e29b233cffc..0bb2f21875a 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -7,10 +7,15 @@ # Remove this __future__ import once the oldest supported Python is 3.10 from __future__ import annotations +import contextlib +import operator +import warnings +import weakref +from collections.abc import MutableSet, Sequence from random import Random # mypy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -41,12 +46,13 @@ def __init__(self, unique_id: int, model: Model) -> None: self.model = model self.pos: Position | None = None - # Register the agent with the model using defaultdict - self.model.agents[type(self)][self] = None + # register agent + self.model._agents[type(self)][self] = None def remove(self) -> None: """Remove and delete the agent from the model.""" - self.model.agents[type(self)].pop(self) + with contextlib.suppress(KeyError): + self.model._agents[type(self)].pop(self) def step(self) -> None: """A single step of the agent.""" @@ -57,3 +63,279 @@ def advance(self) -> None: @property def random(self) -> Random: return self.model.random + + +class AgentSet(MutableSet, Sequence): + """ + .. warning:: + The AgentSet is experimental. It may be changed or removed in any and all future releases, including + patch releases. + We would love to hear what you think about this new feature. If you have any thoughts, share them with + us here: https://github.com/projectmesa/mesa/discussions/1919 + + A collection class that represents an ordered set of agents within an agent-based model (ABM). This class + extends both MutableSet and Sequence, providing set-like functionality with order preservation and + sequence operations. + + Attributes: + model (Model): The ABM model instance to which this AgentSet belongs. + + Methods: + __len__, __iter__, __contains__, select, shuffle, sort, _update, do, get, __getitem__, + add, discard, remove, __getstate__, __setstate__, random + + Note: + The AgentSet maintains weak references to agents, allowing for efficient management of agent lifecycles + without preventing garbage collection. It is associated with a specific model instance, enabling + interactions with the model's environment and other agents.The implementation uses a WeakKeyDictionary to store agents, + which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet. + """ + + def __init__(self, agents: Iterable[Agent], model: Model): + """ + Initializes the AgentSet with a collection of agents and a reference to the model. + + Args: + agents (Iterable[Agent]): An iterable of Agent objects to be included in the set. + model (Model): The ABM model instance to which this AgentSet belongs. + """ + self.model = model + + if not self.model.agentset_experimental_warning_given: + self.model.agentset_experimental_warning_given = True + warnings.warn( + "The AgentSet is experimental. It may be changed or removed in any and all future releases, including patch releases.\n" + "We would love to hear what you think about this new feature. If you have any thoughts, share them with us here: https://github.com/projectmesa/mesa/discussions/1919", + FutureWarning, + stacklevel=2, + ) + + self._agents = weakref.WeakKeyDictionary() + for agent in agents: + self._agents[agent] = None + + def __len__(self) -> int: + """Return the number of agents in the AgentSet.""" + return len(self._agents) + + def __iter__(self) -> Iterator[Agent]: + """Provide an iterator over the agents in the AgentSet.""" + return self._agents.keys() + + def __contains__(self, agent: Agent) -> bool: + """Check if an agent is in the AgentSet. Can be used like `agent in agentset`.""" + return agent in self._agents + + def select( + self, + filter_func: Callable[[Agent], bool] | None = None, + n: int = 0, + inplace: bool = False, + agent_type: type[Agent] | None = None, + ) -> AgentSet: + """ + Select a subset of agents from the AgentSet based on a filter function and/or quantity limit. + + Args: + filter_func (Callable[[Agent], bool], optional): A function that takes an Agent and returns True if the + agent should be included in the result. Defaults to None, meaning no filtering is applied. + n (int, optional): The number of agents to select. If 0, all matching agents are selected. Defaults to 0. + inplace (bool, optional): If True, modifies the current AgentSet; otherwise, returns a new AgentSet. Defaults to False. + agent_type (type[Agent], optional): The class type of the agents to select. Defaults to None, meaning no type filtering is applied. + + Returns: + AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated. + """ + + def agent_generator(): + count = 0 + for agent in self: + if filter_func and not filter_func(agent): + continue + if agent_type and not isinstance(agent, agent_type): + continue + yield agent + count += 1 + # default of n is zero, zo evaluates to False + if n and count >= n: + break + + agents = agent_generator() + + return AgentSet(agents, self.model) if not inplace else self._update(agents) + + def shuffle(self, inplace: bool = False) -> AgentSet: + """ + Randomly shuffle the order of agents in the AgentSet. + + Args: + inplace (bool, optional): If True, shuffles the agents in the current AgentSet; otherwise, returns a new shuffled AgentSet. Defaults to False. + + Returns: + AgentSet: A shuffled AgentSet. Returns the current AgentSet if inplace is True. + """ + shuffled_agents = list(self) + self.random.shuffle(shuffled_agents) + + return ( + AgentSet(shuffled_agents, self.model) + if not inplace + else self._update(shuffled_agents) + ) + + def sort( + self, + key: Callable[[Agent], Any] | str, + ascending: bool = False, + inplace: bool = False, + ) -> AgentSet: + """ + Sort the agents in the AgentSet based on a specified attribute or custom function. + + Args: + key (Callable[[Agent], Any] | str): A function or attribute name based on which the agents are sorted. + ascending (bool, optional): If True, the agents are sorted in ascending order. Defaults to False. + inplace (bool, optional): If True, sorts the agents in the current AgentSet; otherwise, returns a new sorted AgentSet. Defaults to False. + + Returns: + AgentSet: A sorted AgentSet. Returns the current AgentSet if inplace is True. + """ + if isinstance(key, str): + key = operator.attrgetter(key) + + sorted_agents = sorted(self._agents.keys(), key=key, reverse=not ascending) + + return ( + AgentSet(sorted_agents, self.model) + if not inplace + else self._update(sorted_agents) + ) + + def _update(self, agents: Iterable[Agent]): + """Update the AgentSet with a new set of agents. + This is a private method primarily used internally by other methods like select, shuffle, and sort. + """ + _agents = weakref.WeakKeyDictionary() + for agent in agents: + _agents[agent] = None + + self._agents = _agents + return self + + def do( + self, method_name: str, *args, return_results: bool = False, **kwargs + ) -> AgentSet | list[Any]: + """ + Invoke a method on each agent in the AgentSet. + + Args: + method_name (str): The name of the method to call on each agent. + return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls. + *args: Variable length argument list passed to the method being called. + **kwargs: Arbitrary keyword arguments passed to the method being called. + + Returns: + AgentSet | list[Any]: The results of the method calls if return_results is True, otherwise the AgentSet itself. + """ + res = [getattr(agent, method_name)(*args, **kwargs) for agent in self._agents] + + return res if return_results else self + + def get(self, attr_name: str) -> list[Any]: + """ + Retrieve a specified attribute from each agent in the AgentSet. + + Args: + attr_name (str): The name of the attribute to retrieve from each agent. + + Returns: + list[Any]: A list of attribute values from each agent in the set. + """ + return [getattr(agent, attr_name) for agent in self._agents] + + def __getitem__(self, item: int | slice) -> Agent: + """ + Retrieve an agent or a slice of agents from the AgentSet. + + Args: + item (int | slice): The index or slice for selecting agents. + + Returns: + Agent | list[Agent]: The selected agent or list of agents based on the index or slice provided. + """ + return list(self._agents.keys())[item] + + def add(self, agent: Agent): + """ + Add an agent to the AgentSet. + + Args: + agent (Agent): The agent to add to the set. + + Note: + This method is an implementation of the abstract method from MutableSet. + """ + self._agents[agent] = None + + def discard(self, agent: Agent): + """ + Remove an agent from the AgentSet if it exists. + + This method does not raise an error if the agent is not present. + + Args: + agent (Agent): The agent to remove from the set. + + Note: + This method is an implementation of the abstract method from MutableSet. + """ + with contextlib.suppress(KeyError): + del self._agents[agent] + + def remove(self, agent: Agent): + """ + Remove an agent from the AgentSet. + + This method raises an error if the agent is not present. + + Args: + agent (Agent): The agent to remove from the set. + + Note: + This method is an implementation of the abstract method from MutableSet. + """ + del self._agents[agent] + + def __getstate__(self): + """ + Retrieve the state of the AgentSet for serialization. + + Returns: + dict: A dictionary representing the state of the AgentSet. + """ + return {"agents": list(self._agents.keys()), "model": self.model} + + def __setstate__(self, state): + """ + Set the state of the AgentSet during deserialization. + + Args: + state (dict): A dictionary representing the state to restore. + """ + self.model = state["model"] + self._update(state["agents"]) + + @property + def random(self) -> Random: + """ + Provide access to the model's random number generator. + + Returns: + Random: The random number generator associated with the model. + """ + return self.model.random + + +# consider adding for performance reasons +# for Sequence: __reversed__, index, and count +# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__ diff --git a/mesa/model.py b/mesa/model.py index 96398b446f8..22e654483dd 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -7,12 +7,14 @@ # Remove this __future__ import once the oldest supported Python is 3.10 from __future__ import annotations +import itertools import random from collections import defaultdict # mypy from typing import Any +from mesa.agent import AgentSet from mesa.datacollection import DataCollector @@ -27,8 +29,20 @@ class Model: running: A boolean indicating if the model should continue running. schedule: An object to manage the order and execution of agent steps. current_id: A counter for assigning unique IDs to agents. - agents: A defaultdict mapping each agent type to a dict of its instances. - Agent instances are saved in the nested dict keys, with the values being None. + _agents: A defaultdict mapping each agent type to a dict of its instances. + This private attribute is used internally to manage agents. + + Properties: + agents: An AgentSet containing all agents in the model, generated from the _agents attribute. + agent_types: A list of different agent types present in the model. + + Methods: + get_agents_of_type: Returns an AgentSet of agents of the specified type. + run_model: Runs the model's simulation until a defined end condition is reached. + step: Executes a single step of the model's simulation process. + next_id: Generates and returns the next unique identifier for an agent. + reset_randomizer: Resets the model's random number generator with a new or existing seed. + initialize_data_collector: Sets up the data collector for the model, requiring an initialized scheduler and agents. """ def __new__(cls, *args: Any, **kwargs: Any) -> Any: @@ -51,12 +65,27 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.running = True self.schedule = None self.current_id = 0 - self.agents: defaultdict[type, dict] = defaultdict(dict) + self._agents: defaultdict[type, dict] = defaultdict(dict) + + # Warning flags for current experimental features. These make sure a warning is only printed once per model. + self.agentset_experimental_warning_given = False @property - def agent_types(self) -> list: + def agents(self) -> AgentSet: + """Provides an AgentSet of all agents in the model, combining agents from all types.""" + all_agents = itertools.chain( + *(agents_by_type.keys() for agents_by_type in self._agents.values()) + ) + return AgentSet(all_agents, self) + + @property + def agent_types(self) -> list[type]: """Return a list of different agent types.""" - return list(self.agents.keys()) + return list(self._agents.keys()) + + def get_agents_of_type(self, agenttype: type) -> AgentSet: + """Retrieves an AgentSet containing all agents of the specified type.""" + return AgentSet(self._agents[agenttype].values(), self) def run_model(self) -> None: """Run the model until the end condition is reached. Overload as diff --git a/tests/test_agent.py b/tests/test_agent.py index 561208f25af..3cc8f9adc4e 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,16 +1,208 @@ -from mesa.agent import Agent +import pickle + +import pytest + +from mesa.agent import Agent, AgentSet from mesa.model import Model -def test_agent_removal(): - class TestAgent(Agent): - pass +class TestAgent(Agent): + def get_unique_identifier(self): + return self.unique_id + +def test_agent_removal(): model = Model() agent = TestAgent(model.next_id(), model) # Check if the agent is added - assert agent in model.agents[type(agent)] + assert agent in model.agents agent.remove() # Check if the agent is removed - assert agent not in model.agents[type(agent)] + assert agent not in model.agents + + +def test_agentset(): + # create agentset + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + + agentset = AgentSet(agents, model) + + assert agents[0] in agentset + assert len(agentset) == len(agents) + assert all(a1 == a2 for a1, a2 in zip(agentset[0:5], agents[0:5])) + + for a1, a2 in zip(agentset, agents): + assert a1 == a2 + + def test_function(agent): + return agent.unique_id > 5 + + assert len(agentset.select(test_function)) == 5 + assert len(agentset.select(test_function, n=2)) == 2 + assert len(agentset.select(test_function, inplace=True)) == 5 + assert agentset.select(inplace=True) == agentset + assert all(a1 == a2 for a1, a2 in zip(agentset.select(), agentset)) + assert all(a1 == a2 for a1, a2 in zip(agentset.select(n=5), agentset[:5])) + + assert len(agentset.shuffle().select(n=5)) == 5 + + def test_function(agent): + return agent.unique_id + + assert all( + a1 == a2 + for a1, a2 in zip(agentset.sort(test_function, ascending=False), agentset[::-1]) + ) + assert all( + a1 == a2 + for a1, a2 in zip(agentset.sort("unique_id", ascending=False), agentset[::-1]) + ) + + assert all( + a1 == a2.unique_id for a1, a2 in zip(agentset.get("unique_id"), agentset) + ) + assert all( + a1 == a2.unique_id + for a1, a2 in zip( + agentset.do("get_unique_identifier", return_results=True), agentset + ) + ) + assert agentset == agentset.do("get_unique_identifier") + + agentset.discard(agents[0]) + assert agents[0] not in agentset + agentset.discard(agents[0]) # check if no error is raised on discard + + with pytest.raises(KeyError): + agentset.remove(agents[0]) + + agentset.add(agents[0]) + assert agents[0] in agentset + + # because AgentSet uses weakrefs, we need hard refs as well.... + other_agents, another_set = pickle.loads( # noqa: S301 + pickle.dumps([agents, AgentSet(agents, model)]) + ) + assert all( + a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents) + ) + assert len(another_set) == len(other_agents) + + +def test_agentset_initialization(): + model = Model() + empty_agentset = AgentSet([], model) + assert len(empty_agentset) == 0 + + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + assert len(agentset) == 10 + + +def test_agentset_serialization(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(5)] + agentset = AgentSet(agents, model) + + serialized = pickle.dumps(agentset) + deserialized = pickle.loads(serialized) # noqa: S301 + + original_ids = [agent.unique_id for agent in agents] + deserialized_ids = [agent.unique_id for agent in deserialized] + + assert deserialized_ids == original_ids + + +def test_agent_membership(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(5)] + agentset = AgentSet(agents, model) + + assert agents[0] in agentset + assert TestAgent(model.next_id(), model) not in agentset + + +def test_agent_add_remove_discard(): + model = Model() + agent = TestAgent(model.next_id(), model) + agentset = AgentSet([], model) + + agentset.add(agent) + assert agent in agentset + + agentset.remove(agent) + assert agent not in agentset + + agentset.add(agent) + agentset.discard(agent) + assert agent not in agentset + + with pytest.raises(KeyError): + agentset.remove(agent) + + +def test_agentset_get_item(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + assert agentset[0] == agents[0] + assert agentset[-1] == agents[-1] + assert agentset[1:3] == agents[1:3] + + with pytest.raises(IndexError): + _ = agentset[20] + + +def test_agentset_do_method(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + with pytest.raises(AttributeError): + agentset.do("non_existing_method") + + +def test_agentset_get_attribute(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + unique_ids = agentset.get("unique_id") + assert unique_ids == [agent.unique_id for agent in agents] + + with pytest.raises(AttributeError): + agentset.get("non_existing_attribute") + + +class OtherAgentType(Agent): + def get_unique_identifier(self): + return self.unique_id + + +def test_agentset_select_by_type(): + model = Model() + # Create a mix of agents of two different types + test_agents = [TestAgent(model.next_id(), model) for _ in range(4)] + other_agents = [OtherAgentType(model.next_id(), model) for _ in range(6)] + + # Combine the two types of agents + mixed_agents = test_agents + other_agents + agentset = AgentSet(mixed_agents, model) + + # Test selection by type + selected_test_agents = agentset.select(agent_type=TestAgent) + assert len(selected_test_agents) == len(test_agents) + assert all(isinstance(agent, TestAgent) for agent in selected_test_agents) + assert len(selected_test_agents) == 4 + + selected_other_agents = agentset.select(agent_type=OtherAgentType) + assert len(selected_other_agents) == len(other_agents) + assert all(isinstance(agent, OtherAgentType) for agent in selected_other_agents) + assert len(selected_other_agents) == 6 + + # Test with no type specified (should select all agents) + all_agents = agentset.select() + assert len(all_agents) == len(mixed_agents) diff --git a/tests/test_model.py b/tests/test_model.py index c54634d8352..874d45f935f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -49,5 +49,5 @@ class TestAgent(Agent): model = Model() test_agent = TestAgent(model.next_id(), model) - assert test_agent in model.agents[type(test_agent)] + assert test_agent in model.agents assert type(test_agent) in model.agent_types