Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance DataCollector to validate model_reporters functions #2605

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,88 @@
for name, columns in tables.items():
self._new_table(name, columns)

def _validate_model_reporter(self, name, reporter, model):
"""Validate model reporter and issue warnings if necessary.

Args:
name: Name of the reporter
reporter: Reporter definition
model: Model instance
"""
# Type 1: Lambda function
if isinstance(reporter, types.LambdaType):
try:
# Try to call the lambda with a model instance
reporter(model)
except Exception as e:
warnings.warn(

Check warning on line 151 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L150-L151

Added lines #L150 - L151 were not covered by tests
f"Warning: Lambda reporter '{name}' failed: {e!s}\n"
f"Example of valid lambda: lambda m: len(m.agents)",
UserWarning,
stacklevel=2,
)
Comment on lines +152 to +156
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why issue a warning instead of an exception?

return

# Type 2: Method of class/instance
if callable(reporter) and not isinstance(reporter, types.LambdaType):
try:
# Try to call the method
reporter(model)
except Exception as e:
warnings.warn(
f"Warning: Method reporter '{name}' failed: {e!s}\n"
f"Example of valid method: self.get_agent_count or Model.get_agent_count",
UserWarning,
stacklevel=2,
)
return

# Type 3: Class attributes (string)
if isinstance(reporter, str):
if not hasattr(model, reporter):
warnings.warn(

Check warning on line 176 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L176

Added line #L176 was not covered by tests
f"Warning: Model reporter '{name}' references attribute '{reporter}' "
f"which is not defined in the model.\n"
f"Example of valid attribute: 'model_attribute'",
UserWarning,
stacklevel=2,
)
return

# Type 4: Function with parameters in list
if isinstance(reporter, list):
if not reporter or not callable(reporter[0]):
warnings.warn(

Check warning on line 188 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L188

Added line #L188 was not covered by tests
f"Warning: Invalid function list format for reporter '{name}'.\n"
f"First element must be a callable function.\n"
f"Example: [function, [param1, param2]]",
UserWarning,
stacklevel=2,
)
return

# If none of the above types match
warnings.warn(

Check warning on line 198 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L198

Added line #L198 was not covered by tests
f"Warning: Model reporter '{name}' has invalid type: {type(reporter)}.\n"
f"Must be one of:\n"
f"1. Lambda function: lambda m: len(m.agents)\n"
f"2. Method: self.get_count or Model.get_count\n"
f"3. Attribute name (str): 'model_attribute'\n"
f"4. Function list: [function, [param1, param2]]",
UserWarning,
stacklevel=2,
)

def _new_model_reporter(self, name, reporter):
"""Add a new model-level reporter to collect.

Args:
name: Name of the model-level variable to collect.
reporter: Attribute string, or function object that returns the
variable when given a model instance.
reporter: Can be one of four types:
1. Attribute name (str): "attribute_name"
2. Lambda function: lambda m: len(m.agents)
3. Method: model.get_count or Model.get_count
4. List of [function, [parameters]]
"""
self.model_reporters[name] = reporter
self.model_vars[name] = []
Expand Down Expand Up @@ -263,6 +338,9 @@
"""Collect all the data for the given model object."""
if self.model_reporters:
for var, reporter in self.model_reporters.items():
# Add validation
self._validate_model_reporter(var, reporter, model)

Comment on lines +341 to +343
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this imply that the validation is done every single time you try to collect the data? That seems inefficient and overkill.

# Check if lambda or partial function
if isinstance(reporter, types.LambdaType | partial):
# Use deepcopy to store a copy of the data,
Expand Down
Loading