Skip to content

Commit

Permalink
feat: Implement exclude_none_values in DataCollector
Browse files Browse the repository at this point in the history
  • Loading branch information
rht authored and tpike3 committed May 27, 2023
1 parent 6f08b07 commit f78f80f
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 7 deletions.
27 changes: 26 additions & 1 deletion mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ class DataCollector:
one and stores the results.
"""

def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
def __init__(
self,
model_reporters=None,
agent_reporters=None,
tables=None,
exclude_none_values=False,
):
"""Instantiate a DataCollector with lists of model and agent reporters.
Both model_reporters and agent_reporters accept a dictionary mapping a
variable name to either an attribute name, or a method.
Expand All @@ -74,6 +80,8 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
model_reporters: Dictionary of reporter names and attributes/funcs
agent_reporters: Dictionary of reporter names and attributes/funcs.
tables: Dictionary of table names to lists of column names.
exclude_none_values: Boolean of whether to drop records which values
are None, in the final result.
Notes:
If you want to pickle your model you must not use lambda functions.
Expand All @@ -97,6 +105,7 @@ class attributes of a model
self.model_vars = {}
self._agent_records = {}
self.tables = {}
self.exclude_none_values = exclude_none_values

if model_reporters is not None:
for name, reporter in model_reporters.items():
Expand Down Expand Up @@ -151,7 +160,23 @@ def _new_table(self, table_name, table_columns):
def _record_agents(self, model):
"""Record agents data in a mapping of functions and agents."""
rep_funcs = self.agent_reporters.values()
if self.exclude_none_values:
# Drop records which values are None.

def get_reports(agent):
_prefix = (agent.model.schedule.steps, agent.unique_id)
reports = (rep(agent) for rep in rep_funcs)
reports_without_none = tuple(r for r in reports if r is not None)
if len(reports_without_none) == 0:
return None
return _prefix + reports_without_none

agent_records = (get_reports(agent) for agent in model.schedule.agents)
agent_records_without_none = (r for r in agent_records if r is not None)
return agent_records_without_none

if all(hasattr(rep, "attribute_name") for rep in rep_funcs):
# This branch is for performance optimization purpose.
prefix = ["model.schedule.steps", "unique_id"]
attributes = [func.attribute_name for func in rep_funcs]
get_reports = attrgetter(*prefix + attributes)
Expand Down
7 changes: 6 additions & 1 deletion mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ def reset_randomizer(self, seed: int | None = None) -> None:
self._seed = seed

def initialize_data_collector(
self, model_reporters=None, agent_reporters=None, tables=None
self,
model_reporters=None,
agent_reporters=None,
tables=None,
exclude_none_values=False,
) -> None:
if not hasattr(self, "schedule") or self.schedule is None:
raise RuntimeError(
Expand All @@ -80,6 +84,7 @@ def initialize_data_collector(
model_reporters=model_reporters,
agent_reporters=agent_reporters,
tables=tables,
exclude_none_values=exclude_none_values,
)
# Collect data for the first time during initialization.
self.datacollector.collect(self)
65 changes: 60 additions & 5 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,35 @@ def write_final_values(self):
self.model.datacollector.add_table_row("Final_Values", row)


class DifferentMockAgent(MockAgent):
# We define a different MockAgent to test for attributes that are present
# only in 1 type of agent, but not the other.
def __init__(self, unique_id, model, val=0):
super().__init__(unique_id, model, val=val)
self.val3 = val + 42


class MockModel(Model):
"""
Minimalistic model for testing purposes.
"""

schedule = BaseScheduler(None)

def __init__(self):
def __init__(self, test_exclude_none_values=False):
self.schedule = BaseScheduler(self)
self.model_val = 100

for i in range(10):
a = MockAgent(i, self, val=i)
self.schedule.add(a)
self.n = 10
for i in range(self.n):
self.schedule.add(MockAgent(i, self, val=i))
if test_exclude_none_values:
self.schedule.add(DifferentMockAgent(self.n + i, self, val=i))
if test_exclude_none_values:
# Only DifferentMockAgent has val3.
agent_reporters = {"value": lambda a: a.val, "value3": "val3"}
else:
agent_reporters = {"value": lambda a: a.val, "value2": "val2"}
self.initialize_data_collector(
{
"total_agents": lambda m: m.schedule.get_agent_count(),
Expand All @@ -54,8 +69,9 @@ def __init__(self):
"model_calc_comp": [self.test_model_calc_comp, [3, 4]],
"model_calc_fail": [self.test_model_calc_comp, [12, 0]],
},
{"value": lambda a: a.val, "value2": "val2"},
agent_reporters,
{"Final_Values": ["agent_id", "final_value"]},
exclude_none_values=test_exclude_none_values,
)

def test_model_calc_comp(self, input1, input2):
Expand Down Expand Up @@ -195,5 +211,44 @@ def test_initialize_before_agents_added_to_scheduler(self):
)


class TestDataCollectorExcludeNone(unittest.TestCase):
def setUp(self):
"""
Create the model and run it a set number of steps.
"""
self.model = MockModel(test_exclude_none_values=True)
for i in range(7):
if i == 4:
self.model.schedule.remove(self.model.schedule._agents[3])
self.model.step()

def test_agent_records(self):
"""
Test agent-level variable collection.
"""
data_collector = self.model.datacollector
agent_table = data_collector.get_agent_vars_dataframe()

assert len(data_collector._agent_records) == 8
for step, records in data_collector._agent_records.items():
if step < 5:
assert len(records) == 20
else:
assert len(records) == 19

for values in records:
agent_id = values[1]
if agent_id < self.model.n:
assert len(values) == 3
else:
# Agents with agent_id >= self.model.n are
# DifferentMockAgent, which additionally contains val3.
assert len(values) == 4

assert "value" in list(agent_table.columns)
assert "value2" not in list(agent_table.columns)
assert "value3" in list(agent_table.columns)


if __name__ == "__main__":
unittest.main()

0 comments on commit f78f80f

Please sign in to comment.