Add an option to collect data on every adjustment round
rlskoeser committed Jul 30, 2024
1 parent 46394b4 commit 1bd8d22
Showing 1 changed file with 70 additions and 18 deletions.
88 changes: 70 additions & 18 deletions simulatingrisk/hawkdovemulti/batch_run.py
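The new behavior is exposed through a --collect-data command-line option (choices "end" and "adjustment_round", default "end"), which main() passes through batch_run() into each model run. A rough usage sketch, assuming the script is invoked directly from the repository root (the invocation path itself is not shown in this diff):

    # collect data at every adjustment round rather than only at the final step
    python simulatingrisk/hawkdovemulti/batch_run.py --collect-data adjustment_round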
@@ -36,7 +36,7 @@
},
# specific scenarios to allow paired statistical tests
"risk_adjust": {
# ary risk adjustment
# any risk adjustment
"risk_adjustment": ["adopt", "average"],
"risk_distribution": "uniform",
# use model defaults; grid size must be specified
@@ -60,7 +60,7 @@

# method for multiproc running model with a set of params
def run_hawkdovemulti_model(args):
run_id, iteration, params, max_steps = args
run_id, iteration, params, max_steps, data_collection_period = args
# simplified model runner adapted from mesa batch run code

model = HawkDoveMultipleRiskModel(**params)
@@ -74,28 +74,71 @@ def run_hawkdovemulti_model(args):
# and finish data collection to report on whatever was completed
break

# collect data for the last step
# by default, collect data for the last step
# (scheduler is 1-based index but data collection is 0-based)
step = model.schedule.steps - 1

model_data, all_agents_data = _collect_data(model, step)

# combine run id, step, and params, with collected model data
run_data = {"RunId": run_id, "iteration": iteration, "Step": step}
if data_collection_period == "end":
collect_steps = [model.schedule.steps - 1]
elif data_collection_period == "adjustment_round":
# when requested, collect data at every adjustment round
every_n = params.get("adjust_every", 10)
collect_steps = range(0, max_steps, every_n)
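        # illustrative example (values assumed here, not part of this change):
        # with the fallback adjust_every of 10 and max_steps of 125, this is
        # range(0, 125, 10), so data is collected at steps 0, 10, 20, ... 120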

# make a dict of run id and params for combination with model data
run_data = {"RunId": run_id, "iteration": iteration, "Step": "-"}
run_data.update(params)
run_data.update(model_data)
all_model_data = []
all_agent_data = []

# collect data at the specified data collection points
for step in collect_steps:
try:
model_data, agent_data = _collect_data(model, step)
# preserve order: run, iteration, step, params first
# then data collection from model
model_run_data = run_data.copy()
model_run_data["Step"] = step
model_run_data.update(model_data)
all_model_data.append(model_run_data)

# add step to every agent data entry
agent_data = [
{
"Step": step,
**agent_data,
}
for agent_data in agent_data
]
all_agent_data.extend(agent_data)
except IndexError:
# if we requested a step that isn't available, collect last round
# (should capture converged status)
model_data, agent_data = _collect_data(model, -1)
model_run_data = run_data.copy()
model_run_data["Step"] = step
model_run_data.update(model_data)
all_model_data.append(model_run_data)
# add step to every agent data entry
agent_data = [
{
"Step": step,
**agent_data,
}
for agent_data in agent_data
]
all_agent_data.extend(agent_data)
break

agent_data = [
# populate run id and iteration for every row of agent data
all_agent_data = [
{
"RunId": run_id,
"iteration": iteration,
"Step": step,
**agent_data,
}
for agent_data in all_agents_data
for agent_data in all_agent_data
]
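    # note: the runner now returns a list of per-step model data rows plus a
    # flat list of agent data rows, where it previously returned a single
    # model data dict and one list of agent rows for the final step only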

return run_data, agent_data
return all_model_data, all_agent_data


def batch_run(
@@ -108,6 +108,7 @@ def batch_run(
file_prefix,
max_runs,
param_choice,
data_collection_period,
):
run_params = params.get(param_choice)

@@ -125,7 +169,9 @@
run_id = 0
for params in param_combinations:
for iteration in range(iterations):
runs_list.append((run_id, iteration, params, max_steps))
runs_list.append(
(run_id, iteration, params, max_steps, data_collection_period)
)
run_id += 1

# if maximum runs is specified, truncate the list of run arguments
@@ -165,11 +211,11 @@ def batch_run(
if model_dict_writer is None:
# get field names from first entry
model_dict_writer = csv.DictWriter(
model_output_file, model_data.keys()
model_output_file, model_data[0].keys()
)
model_dict_writer.writeheader()

model_dict_writer.writerow(model_data)
model_dict_writer.writerows(model_data)
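            # model_data is now a list of row dicts (one per collected step),
            # so write them all at once with writerows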

if collect_agent_data:
if agent_dict_writer is None:
@@ -247,7 +293,12 @@ def main():
choices=params.keys(),
default="default",
)

parser.add_argument(
"--collect-data",
help="When and how often to collect model and agent data",
choices=["end", "adjustment_round"],
default="end",
)
args = parser.parse_args()
batch_run(
params,
@@ -259,6 +310,7 @@
args.file_prefix,
args.max_runs,
args.params,
args.collect_data,
)

