Skip to content

Commit

Permalink
BayDAG Contribution #2: Enhanced Disaggregate Accessibility Merging (#…
Browse files Browse the repository at this point in the history
…768)

* get all disaggregate accessibility values

* updated settings to work with Pydantic

* KEEP_COLS setting

* keep_cols update and returning tables
  • Loading branch information
dhensle authored Mar 28, 2024
1 parent bf074a5 commit 97e45b7
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 15 deletions.
37 changes: 37 additions & 0 deletions activitysim/abm/models/disaggregate_accessibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ class DisaggregateAccessibilitySettings(PydanticReadable, extra="forbid"):
procedure work.
"""

KEEP_COLS: list[str] | None = None
"""
Disaggreate accessibility table is grouped by the "by" cols above and the KEEP_COLS are averaged
across the group. Initializing the below as NA if not in the auto ownership level, they are skipped
in the groupby mean and the values are correct.
(It's a way to avoid having to update code to reshape the table and introduce new functionality there.)
If none, will keep all of the columns with "accessibility" in the name.
"""

FROM_TEMPLATES: bool = False
annotate_proto_tables: list[DisaggregateAccessibilityAnnotateSettings] = []
"""
Expand All @@ -164,6 +173,11 @@ class DisaggregateAccessibilitySettings(PydanticReadable, extra="forbid"):
"""
NEAREST_METHOD: str = "skims"

postprocess_proto_tables: list[DisaggregateAccessibilityAnnotateSettings] = []
"""
List of preprocessor settings to apply to the proto-population tables after generation.
"""


def read_disaggregate_accessibility_yaml(
state: workflow.State, file_name
Expand Down Expand Up @@ -846,6 +860,10 @@ def compute_disaggregate_accessibility(
state.tracing.register_traceable_table(tablename, df)
del df

disagg_model_settings = read_disaggregate_accessibility_yaml(
state, "disaggregate_accessibility.yaml"
)

# Run location choice
logsums = get_disaggregate_logsums(
state,
Expand Down Expand Up @@ -906,4 +924,23 @@ def compute_disaggregate_accessibility(
for k, df in logsums.items():
state.add_table(k, df)

# available post-processing
for annotations in disagg_model_settings.postprocess_proto_tables:
tablename = annotations.tablename
df = state.get_dataframe(tablename)
assert df is not None
assert annotations is not None
assign_columns(
state,
df=df,
model_settings={
**annotations.annotate.dict(),
**disagg_model_settings.suffixes.dict(),
},
trace_label=tracing.extend_trace_label(
"disaggregate_accessibility.postprocess", tablename
),
)
state.add_table(tablename, df)

return
30 changes: 15 additions & 15 deletions activitysim/abm/tables/disaggregate_accessibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def maz_centroids(state: workflow.State):


@workflow.table
def proto_disaggregate_accessibility(state: workflow.State):
def proto_disaggregate_accessibility(state: workflow.State) -> pd.DataFrame:
# Read existing accessibilities, but is not required to enable model compatibility
df = input.read_input_table(
state, "proto_disaggregate_accessibility", required=False
Expand All @@ -130,7 +130,7 @@ def proto_disaggregate_accessibility(state: workflow.State):


@workflow.table
def disaggregate_accessibility(state: workflow.State):
def disaggregate_accessibility(state: workflow.State) -> pd.DataFrame:
"""
This step initializes pre-computed disaggregate accessibility and merges it onto the full synthetic population.
Function adds merged all disaggregate accessibility tables to the pipeline but returns nothing.
Expand Down Expand Up @@ -169,17 +169,17 @@ def disaggregate_accessibility(state: workflow.State):
)
merging_params = model_settings.MERGE_ON
nearest_method = model_settings.NEAREST_METHOD
accessibility_cols = [
x for x in proto_accessibility_df.columns if "accessibility" in x
]

if model_settings.KEEP_COLS is None:
keep_cols = [x for x in proto_accessibility_df.columns if "accessibility" in x]
else:
keep_cols = model_settings.KEEP_COLS

# Parse the merging parameters
assert merging_params is not None

# Check if already assigned!
if set(accessibility_cols).intersection(persons_merged_df.columns) == set(
accessibility_cols
):
if set(keep_cols).intersection(persons_merged_df.columns) == set(keep_cols):
return

# Find the nearest zone (spatially) with accessibilities calculated
Expand Down Expand Up @@ -211,7 +211,7 @@ def disaggregate_accessibility(state: workflow.State):
# because it will get slightly different logsums for households in the same zone.
# This is because different destination zones were selected. To resolve, get mean by cols.
right_df = (
proto_accessibility_df.groupby(merge_cols)[accessibility_cols]
proto_accessibility_df.groupby(merge_cols)[keep_cols]
.mean()
.sort_values(nearest_cols)
.reset_index()
Expand Down Expand Up @@ -244,9 +244,9 @@ def disaggregate_accessibility(state: workflow.State):
)

# Predict the nearest person ID and pull the logsums
matched_logsums_df = right_df.loc[clf.predict(x_pop)][
accessibility_cols
].reset_index(drop=True)
matched_logsums_df = right_df.loc[clf.predict(x_pop)][keep_cols].reset_index(
drop=True
)
merge_df = pd.concat(
[left_df.reset_index(drop=False), matched_logsums_df], axis=1
).set_index("person_id")
Expand Down Expand Up @@ -278,9 +278,9 @@ def disaggregate_accessibility(state: workflow.State):

# Check that it was correctly left-joined
assert all(persons_merged_df[merge_cols] == merge_df[merge_cols])
assert any(merge_df[accessibility_cols].isnull())
assert any(merge_df[keep_cols].isnull())

# Inject merged accessibilities so that it can be included in persons_merged function
state.add_table("disaggregate_accessibility", merge_df[accessibility_cols])
state.add_table("disaggregate_accessibility", merge_df[keep_cols])

return merge_df[accessibility_cols]
return merge_df[keep_cols]

0 comments on commit 97e45b7

Please sign in to comment.