Skip to content

Commit

Permalink
replace model with random in AgentSet init (#2350)
Browse files Browse the repository at this point in the history
* replace model with random in AgentSet `__init__`

also closing #2323
  • Loading branch information
quaquel authored Oct 13, 2024
1 parent 0082da2 commit e255a2d
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 49 deletions.
31 changes: 13 additions & 18 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,18 @@ class AgentSet(MutableSet, Sequence):
which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet.
"""

def __init__(self, agents: Iterable[Agent], model: Model):
def __init__(self, agents: Iterable[Agent], random: Random | None = None):
"""Initializes the AgentSet with a collection of agents and a reference to the model.
Args:
agents (Iterable[Agent]): An iterable of Agent objects to be included in the set.
model (Model): The ABM model instance to which this AgentSet belongs.
random (Random): the random number generator
"""
self.model = model
if random is None:
random = (
Random()
) # FIXME see issue 1981, how to get the central rng from model
self.random = random
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})

def __len__(self) -> int:
Expand Down Expand Up @@ -177,7 +181,7 @@ def agent_generator(filter_func, agent_type, at_most):

agents = agent_generator(filter_func, agent_type, at_most)

return AgentSet(agents, self.model) if not inplace else self._update(agents)
return AgentSet(agents, self.random) if not inplace else self._update(agents)

def shuffle(self, inplace: bool = False) -> AgentSet:
"""Randomly shuffle the order of agents in the AgentSet.
Expand All @@ -200,7 +204,7 @@ def shuffle(self, inplace: bool = False) -> AgentSet:
return self
else:
return AgentSet(
(agent for ref in weakrefs if (agent := ref()) is not None), self.model
(agent for ref in weakrefs if (agent := ref()) is not None), self.random
)

def sort(
Expand All @@ -225,7 +229,7 @@ def sort(
sorted_agents = sorted(self._agents.keys(), key=key, reverse=not ascending)

return (
AgentSet(sorted_agents, self.model)
AgentSet(sorted_agents, self.random)
if not inplace
else self._update(sorted_agents)
)
Expand Down Expand Up @@ -477,26 +481,17 @@ def __getstate__(self):
Returns:
dict: A dictionary representing the state of the AgentSet.
"""
return {"agents": list(self._agents.keys()), "model": self.model}
return {"agents": list(self._agents.keys()), "random": self.random}

def __setstate__(self, state):
"""Set the state of the AgentSet during deserialization.
Args:
state (dict): A dictionary representing the state to restore.
"""
self.model = state["model"]
self.random = state["random"]
self._update(state["agents"])

@property
def random(self) -> Random:
"""Provide access to the model's random number generator.
Returns:
Random: The random number generator associated with the model.
"""
return self.model.random

def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
"""Group agents by the specified attribute or return from the callable.
Expand Down Expand Up @@ -529,7 +524,7 @@ def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:

if result_type == "agentset":
return GroupBy(
{k: AgentSet(v, model=self.model) for k, v in groups.items()}
{k: AgentSet(v, random=self.random) for k, v in groups.items()}
)
else:
return GroupBy(groups)
Expand Down
11 changes: 7 additions & 4 deletions mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def __init__(self, *args: Any, seed: float | None = None, **kwargs: Any) -> None
self.running = True
self.steps: int = 0

self._setup_agent_registration()

self._seed = seed
if self._seed is None:
# We explicitly specify the seed here so that we know its value in
Expand All @@ -65,6 +63,9 @@ def __init__(self, *args: Any, seed: float | None = None, **kwargs: Any) -> None
self._user_step = self.step
self.step = self._wrapped_step

# setup agent registration data structures
self._setup_agent_registration()

def _wrapped_step(self, *args: Any, **kwargs: Any) -> None:
"""Automatically increments time and steps after calling the user's step method."""
# Automatically increment time and step counters
Expand Down Expand Up @@ -119,7 +120,9 @@ def _setup_agent_registration(self):
self._agents_by_type: dict[
type[Agent], AgentSet
] = {} # a dict with an agentset for each class of agents
self._all_agents = AgentSet([], self) # an agenset with all agents
self._all_agents = AgentSet(
[], random=self.random
) # an agenset with all agents

def register_agent(self, agent):
"""Register the agent with the model.
Expand Down Expand Up @@ -153,7 +156,7 @@ def register_agent(self, agent):
[
agent,
],
self,
random=self.random,
)

self._all_agents.add(agent)
Expand Down
54 changes: 27 additions & 27 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_agentset():
model = Model()
agents = [AgentTest(model) for _ in range(10)]

agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

assert agents[0] in agentset
assert len(agentset) == len(agents)
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_function(agent):

# because AgentSet uses weakrefs, we need hard refs as well....
other_agents, another_set = pickle.loads( # noqa: S301
pickle.dumps([agents, AgentSet(agents, model)])
pickle.dumps([agents, AgentSet(agents, random=model.random)])
)
assert all(
a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents)
Expand All @@ -129,19 +129,19 @@ def test_function(agent):
def test_agentset_initialization():
"""Test agentset initialization."""
model = Model()
empty_agentset = AgentSet([], model)
empty_agentset = AgentSet([], random=model.random)
assert len(empty_agentset) == 0

agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)
assert len(agentset) == 10


def test_agentset_serialization():
"""Test pickleability of agentset."""
model = Model()
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

serialized = pickle.dumps(agentset)
deserialized = pickle.loads(serialized) # noqa: S301
Expand All @@ -156,7 +156,7 @@ def test_agent_membership():
"""Test agent membership in AgentSet."""
model = Model()
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

assert agents[0] in agentset
assert AgentTest(model) not in agentset
Expand All @@ -166,7 +166,7 @@ def test_agent_add_remove_discard():
"""Test adding, removing and discarding agents from AgentSet."""
model = Model()
agent = AgentTest(model)
agentset = AgentSet([], model)
agentset = AgentSet([], random=model.random)

agentset.add(agent)
assert agent in agentset
Expand All @@ -186,7 +186,7 @@ def test_agentset_get_item():
"""Test integer based access to AgentSet."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

assert agentset[0] == agents[0]
assert agentset[-1] == agents[-1]
Expand All @@ -200,7 +200,7 @@ def test_agentset_do_str():
"""Test AgentSet.do with str."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

with pytest.raises(AttributeError):
agentset.do("non_existing_method")
Expand All @@ -213,7 +213,7 @@ def test_agentset_do_str():
n = 10
model = Model()
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -223,7 +223,7 @@ def test_agentset_do_str():
# setup
model = Model()
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -235,7 +235,7 @@ def test_agentset_do_callable():
"""Test AgentSet.do with callable."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

# Test callable with non-existent function
with pytest.raises(AttributeError):
Expand All @@ -249,7 +249,7 @@ def test_agentset_do_callable():
n = 10
model = Model()
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -260,7 +260,7 @@ def test_agentset_do_callable():
# setup again for lambda function tests
model = Model()
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -278,7 +278,7 @@ def remove_function(agent):
# setup again for actual function tests
model = Model()
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -289,7 +289,7 @@ def remove_function(agent):
# setup again for actual function tests
model = Model()
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)
for agent in agents:
agent.agent_set = agentset

Expand Down Expand Up @@ -354,7 +354,7 @@ def test_agentset_agg():
agent.energy = i + 1
agent.wealth = 10 * (i + 1)

agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

# Test min aggregation
min_energy = agentset.agg("energy", min)
Expand Down Expand Up @@ -391,7 +391,7 @@ def __init__(self, model, age=None):

model = Model()
agents = [TestAgentWithAttribute(model, age=i) for i in range(5)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

# Set a new attribute "health" and an existing attribute "age" for all agents
agentset.set("health", 100).set("age", 50).set("status", "active")
Expand All @@ -410,7 +410,7 @@ def test_agentset_map_str():
"""Test AgentSet.map with strings."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

with pytest.raises(AttributeError):
agentset.do("non_existing_method")
Expand All @@ -423,7 +423,7 @@ def test_agentset_map_callable():
"""Test AgentSet.map with callable."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

# Test callable with non-existent function
with pytest.raises(AttributeError):
Expand All @@ -450,7 +450,7 @@ def test_method(self):
self.called = True

agents = [TestAgentShuffleDo(model) for _ in range(100)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

# Test shuffle_do with a string method name
agentset.shuffle_do("test_method")
Expand All @@ -477,7 +477,7 @@ def test_agentset_get_attribute():
"""Test AgentSet.get for attributes."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

unique_ids = agentset.get("unique_id")
assert unique_ids == [agent.unique_id for agent in agents]
Expand All @@ -491,7 +491,7 @@ def test_agentset_get_attribute():
agent = AgentTest(model)
agent.i = i**2
agents.append(agent)
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

values = agentset.get(["unique_id", "i"])

Expand Down Expand Up @@ -521,7 +521,7 @@ def test_agentset_select_by_type():

# Combine the two types of agents
mixed_agents = test_agents + other_agents
agentset = AgentSet(mixed_agents, model)
agentset = AgentSet(mixed_agents, random=model.random)

# Test selection by type
selected_test_agents = agentset.select(agent_type=AgentTest)
Expand All @@ -544,11 +544,11 @@ def test_agentset_shuffle():
model = Model()
test_agents = [AgentTest(model) for _ in range(12)]

agentset = AgentSet(test_agents, model=model)
agentset = AgentSet(test_agents, random=model.random)
agentset = agentset.shuffle()
assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset))

agentset = AgentSet(test_agents, model=model)
agentset = AgentSet(test_agents, random=model.random)
agentset.shuffle(inplace=True)
assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset))

Expand All @@ -567,7 +567,7 @@ def get_unique_identifier(self):

model = Model()
agents = [TestAgent(model) for _ in range(10)]
agentset = AgentSet(agents, model)
agentset = AgentSet(agents, random=model.random)

groups = agentset.groupby("even")
assert len(groups.groups[True]) == 5
Expand Down

0 comments on commit e255a2d

Please sign in to comment.