diff --git a/benchmarks/WolfSheep/wolf_sheep.py b/benchmarks/WolfSheep/wolf_sheep.py index ecb78594272..655d51bba63 100644 --- a/benchmarks/WolfSheep/wolf_sheep.py +++ b/benchmarks/WolfSheep/wolf_sheep.py @@ -227,8 +227,8 @@ def __init__( patch.move_to(cell) def step(self): - self.get_agents_of_type(Sheep).shuffle(inplace=True).do("step") - self.get_agents_of_type(Wolf).shuffle(inplace=True).do("step") + self.agents_by_type[Sheep].shuffle(inplace=True).do("step") + self.agents_by_type[Wolf].shuffle(inplace=True).do("step") if __name__ == "__main__": diff --git a/mesa/experimental/devs/examples/wolf_sheep.py b/mesa/experimental/devs/examples/wolf_sheep.py index 9fa6b3d96c7..99fde6a9c8f 100644 --- a/mesa/experimental/devs/examples/wolf_sheep.py +++ b/mesa/experimental/devs/examples/wolf_sheep.py @@ -231,8 +231,8 @@ def __init__( self.grid.place_agent(patch, pos) def step(self): - self.get_agents_of_type(Sheep).shuffle(inplace=True).do("step") - self.get_agents_of_type(Wolf).shuffle(inplace=True).do("step") + self.agents_by_type[Sheep].shuffle(inplace=True).do("step") + self.agents_by_type[Wolf].shuffle(inplace=True).do("step") if __name__ == "__main__": diff --git a/mesa/model.py b/mesa/model.py index 9387e6d1031..f276cc82284 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -33,11 +33,13 @@ class Model: Properties: agents: An AgentSet containing all agents in the model agent_types: A list of different agent types present in the model. + agents_by_type: A dictionary where the keys are agent types and the values are the corresponding AgentSets. steps: An integer representing the number of steps the model has taken. It increases automatically at the start of each step() call. Methods: get_agents_of_type: Returns an AgentSet of agents of the specified type. + Deprecated: Use agents_by_type[agenttype] instead. run_model: Runs the model's simulation until a defined end condition is reached. step: Executes a single step of the model's simulation process. next_id: Generates and returns the next unique identifier for an agent. @@ -106,24 +108,26 @@ def agent_types(self) -> list[type]: """Return a list of all unique agent types registered with the model.""" return list(self._agents_by_type.keys()) - def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet: - """Retrieves an AgentSet containing all agents of the specified type. - - Args: - agenttype: The type of agent to retrieve. - - Raises: - KeyError: If agenttype does not exist - + @property + def agents_by_type(self) -> dict[type[Agent], AgentSet]: + """A dictionary where the keys are agent types and the values are the corresponding AgentSets.""" + return self._agents_by_type - """ - return self._agents_by_type[agenttype] + def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet: + """Deprecated: Retrieves an AgentSet containing all agents of the specified type.""" + warnings.warn( + f"Model.get_agents_of_type() is deprecated, please replace get_agents_of_type({agenttype})" + f"with the property agents_by_type[{agenttype}].", + DeprecationWarning, + stacklevel=2, + ) + return self.agents_by_type[agenttype] def _setup_agent_registration(self): """helper method to initialize the agent registration datastructures""" self._agents = {} # the hard references to all agents in the model self._agents_by_type: dict[ - type, AgentSet + type[Agent], AgentSet ] = {} # a dict with an agentset for each class of agents self._all_agents = AgentSet([], self) # an agenset with all agents diff --git a/tests/test_model.py b/tests/test_model.py index 837d6b0049a..c92fcb2e4a9 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,4 +1,4 @@ -from mesa.agent import Agent +from mesa.agent import Agent, AgentSet from mesa.model import Model @@ -53,3 +53,19 @@ class TestAgent(Agent): test_agent = TestAgent(model.next_id(), model) assert test_agent in model.agents assert type(test_agent) in model.agent_types + + +def test_agents_by_type(): + class Wolf(Agent): + pass + + class Sheep(Agent): + pass + + model = Model() + wolf = Wolf(1, model) + sheep = Sheep(2, model) + + assert model.agents_by_type[Wolf] == AgentSet([wolf], model) + assert model.agents_by_type[Sheep] == AgentSet([sheep], model) + assert len(model.agents_by_type) == 2 diff --git a/tests/test_time.py b/tests/test_time.py index a9d8e2e6055..4203ca5d222 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -320,7 +320,7 @@ def test_random_activation_counts(self): agent_types = model.agent_types for agent_type in agent_types: assert model.schedule.get_type_count(agent_type) == len( - model.get_agents_of_type(agent_type) + model.agents_by_type[agent_type] ) # def test_add_non_unique_ids(self):