From de58ddb413dfa8eee277a90eed0713c1b7211efb Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Wed, 28 Aug 2024 19:26:24 +0200 Subject: [PATCH] AgentSet: Add `set` method (#2254) Adds a `set` method to the `AgentSet` class that allows setting a specified attribute of all agents within the set to a given value. This method provides a streamlined way to update agent attributes in bulk, without having to resort to `lambda` or other callable functions. ### Motive The motivation behind this feature is to simplify and expedite the process of modifying attributes across all agents in an `AgentSet`. Previously, this required iterating manually over the agents or using a `lambda` or other callable function. This new method makes the operation more elegant. ### Implementation The `set` method was added to the `AgentSet` class. It iterates through all agents in the set and applies the `setattr` function to assign the specified value to the desired attribute. The method operates in place, modifying the existing agents directly. The (now updated) AgentSet itself is returned, which allows for chaining. You can both set existing Agent variables or set new ones. --- mesa/agent.py | 15 +++++++++++++++ tests/test_agent.py | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/mesa/agent.py b/mesa/agent.py index f137ba09c2c..39abb77d879 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -316,6 +316,21 @@ def get(self, attr_names: str | list[str]) -> list[Any]: for agent in self._agents ] + def set(self, attr_name: str, value: Any) -> AgentSet: + """ + Set a specified attribute to a given value for all agents in the AgentSet. + + Args: + attr_name (str): The name of the attribute to set. + value (Any): The value to set the attribute to. + + Returns: + AgentSet: The AgentSet instance itself, after setting the attribute. + """ + for agent in self: + setattr(agent, attr_name, value) + return self + def __getitem__(self, item: int | slice) -> Agent: """ Retrieve an agent or a slice of agents from the AgentSet. diff --git a/tests/test_agent.py b/tests/test_agent.py index 516902f969b..8c5bf35159b 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -270,6 +270,30 @@ def remove_function(agent): assert len(agentset) == 0 +def test_agentset_set_method(): + # Initialize the model and agents with and without existing attributes + class TestAgentWithAttribute(Agent): + def __init__(self, unique_id, model, age=None): + super().__init__(unique_id, model) + self.age = age + + model = Model() + agents = [TestAgentWithAttribute(model.next_id(), model, age=i) for i in range(5)] + agentset = AgentSet(agents, model) + + # Set a new attribute "health" and an existing attribute "age" for all agents + agentset.set("health", 100).set("age", 50).set("status", "active") + + # Check if all agents have the "health", "age", and "status" attributes correctly set + for agent in agentset: + assert hasattr(agent, "health") + assert agent.health == 100 + assert hasattr(agent, "age") + assert agent.age == 50 + assert hasattr(agent, "status") + assert agent.status == "active" + + def test_agentset_map_str(): model = Model() agents = [TestAgent(model.next_id(), model) for _ in range(10)]