Skip to content

Commit

Permalink
Model: Replace get_agents_of_type method with agents_by_type prop…
Browse files Browse the repository at this point in the history
…erty (#2267)

This PR replaces the Model method `get_agents_of_type()` with an `agents_by_type` property, which directly returns the dict.

Instead of using:
```Python
model.get_agents_of_type(Sheep)
```
You should now use:
```Python
model.agents_of_type[Sheep]
```
Since we also use `agents` and not `get_agents`, it's more intuitive to directly have access to the object itself (in this case the `agents_by_type` dict). It's also is more concise and since you have full dict access, more flexible.

Examples and tests are updated and an deprecation warning is added for `get_agents_of_type()`.
  • Loading branch information
EwoutH authored Sep 3, 2024
1 parent 002d3f4 commit 221084d
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 18 deletions.
4 changes: 2 additions & 2 deletions benchmarks/WolfSheep/wolf_sheep.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def __init__(
patch.move_to(cell)

def step(self):
self.get_agents_of_type(Sheep).shuffle(inplace=True).do("step")
self.get_agents_of_type(Wolf).shuffle(inplace=True).do("step")
self.agents_by_type[Sheep].shuffle(inplace=True).do("step")
self.agents_by_type[Wolf].shuffle(inplace=True).do("step")


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions mesa/experimental/devs/examples/wolf_sheep.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def __init__(
self.grid.place_agent(patch, pos)

def step(self):
self.get_agents_of_type(Sheep).shuffle(inplace=True).do("step")
self.get_agents_of_type(Wolf).shuffle(inplace=True).do("step")
self.agents_by_type[Sheep].shuffle(inplace=True).do("step")
self.agents_by_type[Wolf].shuffle(inplace=True).do("step")


if __name__ == "__main__":
Expand Down
28 changes: 16 additions & 12 deletions mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ class Model:
Properties:
agents: An AgentSet containing all agents in the model
agent_types: A list of different agent types present in the model.
agents_by_type: A dictionary where the keys are agent types and the values are the corresponding AgentSets.
steps: An integer representing the number of steps the model has taken.
It increases automatically at the start of each step() call.
Methods:
get_agents_of_type: Returns an AgentSet of agents of the specified type.
Deprecated: Use agents_by_type[agenttype] instead.
run_model: Runs the model's simulation until a defined end condition is reached.
step: Executes a single step of the model's simulation process.
next_id: Generates and returns the next unique identifier for an agent.
Expand Down Expand Up @@ -106,24 +108,26 @@ def agent_types(self) -> list[type]:
"""Return a list of all unique agent types registered with the model."""
return list(self._agents_by_type.keys())

def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet:
"""Retrieves an AgentSet containing all agents of the specified type.
Args:
agenttype: The type of agent to retrieve.
Raises:
KeyError: If agenttype does not exist
@property
def agents_by_type(self) -> dict[type[Agent], AgentSet]:
"""A dictionary where the keys are agent types and the values are the corresponding AgentSets."""
return self._agents_by_type

"""
return self._agents_by_type[agenttype]
def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet:
"""Deprecated: Retrieves an AgentSet containing all agents of the specified type."""
warnings.warn(
f"Model.get_agents_of_type() is deprecated, please replace get_agents_of_type({agenttype})"
f"with the property agents_by_type[{agenttype}].",
DeprecationWarning,
stacklevel=2,
)
return self.agents_by_type[agenttype]

def _setup_agent_registration(self):
"""helper method to initialize the agent registration datastructures"""
self._agents = {} # the hard references to all agents in the model
self._agents_by_type: dict[
type, AgentSet
type[Agent], AgentSet
] = {} # a dict with an agentset for each class of agents
self._all_agents = AgentSet([], self) # an agenset with all agents

Expand Down
18 changes: 17 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mesa.agent import Agent
from mesa.agent import Agent, AgentSet
from mesa.model import Model


Expand Down Expand Up @@ -53,3 +53,19 @@ class TestAgent(Agent):
test_agent = TestAgent(model.next_id(), model)
assert test_agent in model.agents
assert type(test_agent) in model.agent_types


def test_agents_by_type():
class Wolf(Agent):
pass

class Sheep(Agent):
pass

model = Model()
wolf = Wolf(1, model)
sheep = Sheep(2, model)

assert model.agents_by_type[Wolf] == AgentSet([wolf], model)
assert model.agents_by_type[Sheep] == AgentSet([sheep], model)
assert len(model.agents_by_type) == 2
2 changes: 1 addition & 1 deletion tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def test_random_activation_counts(self):
agent_types = model.agent_types
for agent_type in agent_types:
assert model.schedule.get_type_count(agent_type) == len(
model.get_agents_of_type(agent_type)
model.agents_by_type[agent_type]
)

# def test_add_non_unique_ids(self):
Expand Down

0 comments on commit 221084d

Please sign in to comment.