Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add select_agents method to Model #1911

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections import defaultdict

# mypy
from typing import Any
from typing import Any, Callable

from mesa.datacollection import DataCollector

Expand Down Expand Up @@ -106,3 +106,76 @@ def initialize_data_collector(
)
# Collect data for the first time during initialization.
self.datacollector.collect(self)

def select_agents(
self,
n: int | None = None,
sort: list[str] | None = None,
direction: list[str] | None = None,
filter_func: Callable[[Any], bool] | None = None,
agent_type: type[Any] | list[type[Any]] | None = None,
up_to: bool = True,
) -> list[Any]:
"""
Select agents based on various criteria including type, attributes, and custom filters.

Args:
n: Number of agents to select.
sort: Attributes to sort by.
direction: Sort direction for each attribute in `sort`.
filter_func: A callable to further filter agents.
agent_type: Type(s) of agents to include.
up_to: If True, allows returning up to `n` agents.

Returns:
A list of selected agents.
"""

# If agent_type is specified, fetch only those agents; otherwise, fetch all
if agent_type:
if not isinstance(agent_type, list):
agent_type = [agent_type]
agent_type_set = set(agent_type)
agents_iter = (
agent
for type_key, agents in self.agents.items()
if type_key in agent_type_set
for agent in agents
)
else:
agents_iter = (agent for agents in self.agents.values() for agent in agents)

# Apply functional filter if provided
if filter_func:
agents_iter = filter(filter_func, agents_iter)

# Convert to list if sorting is needed or n is specified
if sort and direction or n is not None:
agents_iter = list(agents_iter)
# If n is specified, shuffle the agents to avoid bias
self.random.shuffle(agents_iter)

# If only a specific number of agents is needed without sorting, limit early
if n is not None and not (sort and direction):
agents_iter = (
agents_iter[: min(n, len(agents_iter))] if up_to else agents_iter[:n]
)

# Sort agents if needed
if sort and direction:

def sort_key(agent):
return tuple(
getattr(agent, attr)
if dir.lower() == "lowest"
else -getattr(agent, attr)
for attr, dir in zip(sort, direction)
)

agents_iter.sort(key=sort_key)

# Select the desired number of agents after sorting
if n is not None and sort and direction:
return agents_iter[: min(n, len(agents_iter))] if up_to else agents_iter[:n]

return list(agents_iter)
106 changes: 106 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from mesa.agent import Agent
from mesa.model import Model

Expand Down Expand Up @@ -51,3 +53,107 @@ class TestAgent(Agent):
test_agent = TestAgent(model.next_id(), model)
assert test_agent in model.agents[type(test_agent)]
assert type(test_agent) in model.agent_types


class TestSelectAgents:
class MockAgent(Agent):
def __init__(self, unique_id, model, type_id, age, wealth):
super().__init__(unique_id, model)
self.type_id = type_id
self.age = age
self.wealth = wealth

@pytest.fixture
def model_with_agents(self):
model = Model()
for i in range(20):
self.MockAgent(i, model, type_id=i % 2, age=i + 20, wealth=100 - i * 2)
return model

def test_basic_selection(self, model_with_agents):
selected_agents = model_with_agents.select_agents()
assert len(selected_agents) == 20

def test_selection_with_n(self, model_with_agents):
selected_agents = model_with_agents.select_agents(n=5)
assert len(selected_agents) == 5

def test_sorting_and_direction(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
n=3, sort=["wealth"], direction=["highest"]
)
assert [agent.wealth for agent in selected_agents] == [100, 98, 96]

def test_functional_filtering(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
filter_func=lambda agent: agent.age > 30
)
assert all(agent.age > 30 for agent in selected_agents)

def test_type_filtering(self, model_with_agents):
selected_agents = model_with_agents.select_agents(agent_type=self.MockAgent)
assert all(isinstance(agent, self.MockAgent) for agent in selected_agents)

def test_up_to_flag(self, model_with_agents):
selected_agents = model_with_agents.select_agents(n=50, up_to=True)
assert len(selected_agents) == 20

def test_edge_case_empty_model(self):
empty_model = Model()
selected_agents = empty_model.select_agents()
assert len(selected_agents) == 0

def test_error_handling_invalid_sort(self, model_with_agents):
with pytest.raises(AttributeError):
model_with_agents.select_agents(
n=3, sort=["nonexistent_attribute"], direction=["highest"]
)

def test_sorting_with_multiple_criteria(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
n=3, sort=["type_id", "age"], direction=["lowest", "highest"]
)
assert [(agent.type_id, agent.age) for agent in selected_agents] == [
(0, 38),
(0, 36),
(0, 34),
]

def test_direction_with_multiple_criteria(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
n=3, sort=["type_id", "wealth"], direction=["highest", "lowest"]
)
assert [(agent.type_id, agent.wealth) for agent in selected_agents] == [
(1, 62),
(1, 66),
(1, 70),
]

def test_type_filtering_with_multiple_types(self, model_with_agents):
class AnotherMockAgent(Agent):
pass

# Adding different type agents to the model
for i in range(20, 25):
AnotherMockAgent(i, model_with_agents)

selected_agents = model_with_agents.select_agents(
agent_type=[self.MockAgent, AnotherMockAgent]
)
assert len(selected_agents) == 25

def test_selection_when_n_exceeds_agent_count(self, model_with_agents):
selected_agents = model_with_agents.select_agents(n=50)
assert len(selected_agents) == 20

def test_inverse_functional_filtering(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
filter_func=lambda agent: agent.age < 25
)
assert all(agent.age < 25 for agent in selected_agents)

def test_complex_lambda_in_filter(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
filter_func=lambda agent: agent.age > 25 and agent.wealth > 70
)
assert all(agent.age > 25 and agent.wealth > 70 for agent in selected_agents)