Skip to content

Commit

Permalink
progress on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jpn-- committed Nov 10, 2023
1 parent 7985f2c commit 5e90e65
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 120 deletions.
14 changes: 3 additions & 11 deletions activitysim/abm/models/atwork_subtour_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd

from activitysim.abm.models.util.vectorize_tour_scheduling import (
TourSchedulingSettings,
vectorize_subtour_scheduling,
)
from activitysim.core import config, estimation, expressions, simulate
Expand All @@ -23,23 +24,14 @@
DUMP = False


class AtworkSubtourSchedulingSettings(PydanticReadable):
"""
Settings for the `atwork_subtour_scheduling` component.
"""

sharrow_skip: bool = True
"""Skip Sharow""" # TODO Check this again


@workflow.step
def atwork_subtour_scheduling(
state: workflow.State,
tours: pd.DataFrame,
persons_merged: pd.DataFrame,
tdd_alts: pd.DataFrame,
skim_dict: SkimDict | SkimDataset,
model_settings: AtworkSubtourSchedulingSettings | None = None,
model_settings: TourSchedulingSettings | None = None,
model_settings_file_name: str = "tour_scheduling_atwork.yaml",
trace_label: str = "atwork_subtour_scheduling",
) -> None:
Expand All @@ -56,7 +48,7 @@ def atwork_subtour_scheduling(
return

if model_settings is None:
model_settings = AtworkSubtourSchedulingSettings.read_settings_file(
model_settings = TourSchedulingSettings.read_settings_file(
state.filesystem,
model_settings_file_name,
)
Expand Down
26 changes: 14 additions & 12 deletions activitysim/abm/models/joint_tour_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd

from activitysim.abm.models.util.vectorize_tour_scheduling import (
TourSchedulingSettings,
vectorize_joint_tour_scheduling,
)
from activitysim.core import (
Expand All @@ -24,16 +25,17 @@
logger = logging.getLogger(__name__)


class JointTourSchedulingSettings(LogitComponentSettings, extra="forbid"):
"""
Settings for the `joint_tour_scheduling` component.
"""

preprocessor: PreprocessorSettings | None = None
"""Setting for the preprocessor."""

sharrow_skip: bool = False
"""Setting to skip sharrow"""
# class JointTourSchedulingSettings(LogitComponentSettings, extra="forbid"):
# """
# Settings for the `joint_tour_scheduling` component.
# """
#
# preprocessor: PreprocessorSettings | None = None
# """Setting for the preprocessor."""
#
# sharrow_skip: bool = False
# """Setting to skip sharrow"""
#


@workflow.step
Expand All @@ -42,7 +44,7 @@ def joint_tour_scheduling(
tours: pd.DataFrame,
persons_merged: pd.DataFrame,
tdd_alts: pd.DataFrame,
model_settings: JointTourSchedulingSettings | None = None,
model_settings: TourSchedulingSettings | None = None,
model_settings_file_name: str = "joint_tour_scheduling.yaml",
trace_label: str = "joint_tour_scheduling",
) -> None:
Expand All @@ -51,7 +53,7 @@ def joint_tour_scheduling(
"""

if model_settings is None:
model_settings = JointTourSchedulingSettings.read_settings_file(
model_settings = TourSchedulingSettings.read_settings_file(
state.filesystem,
model_settings_file_name,
)
Expand Down
46 changes: 34 additions & 12 deletions activitysim/abm/models/stop_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from __future__ import annotations

import logging
from typing import Any
from pathlib import Path
from typing import Any, Literal

import pandas as pd

Expand All @@ -24,20 +25,44 @@
logger = logging.getLogger(__name__)


class StopFrequencySpecSegmentSettings(LogitComponentSettings, extra="allow"):
# this class specifically allows "extra" settings because ActivitySim
# is set up to have the name of the segment column be identified with
# an arbitrary key.
SPEC: Path
COEFFICIENTS: Path


class StopFrequencySettings(LogitComponentSettings, extra="forbid"):
"""
Settings for the `free_parking` component.
Settings for the stop frequency component.
"""

LOGIT_TYPE: Literal["MNL"] = "MNL"
"""Logit model mathematical form.
* "MNL"
Multinomial logit model.
"""

preprocessor: PreprocessorSettings | None = None
"""Setting for the preprocessor."""

SPEC_SEGMENTS: dict[str, Any] = {}
# TODO Check this again
SPEC_SEGMENTS: list[StopFrequencySpecSegmentSettings] = {}

SPEC: Path | None = None
"""Utility specification filename.
This is sometimes alternatively called the utility expressions calculator
(UEC). It is a CSV file giving all the functions for the terms of a
linear-in-parameters utility expression. If SPEC_SEGMENTS is given, then
this unsegmented SPEC should be omitted.
"""

SEGMENT_COL: str = "primary_purpose"

# CONSTANTS TODO Check this again
CONSTANTS: dict[str, Any] = {}
"""Named constants usable in the utility expressions."""


@workflow.step
Expand Down Expand Up @@ -136,8 +161,7 @@ def stop_frequency(

choices_list = []
for segment_settings in spec_segments:
segment_name = segment_settings[segment_col]
segment_value = segment_settings[segment_col]
segment_name = segment_value = getattr(segment_settings, segment_col)

chooser_segment = tours_merged[tours_merged[segment_col] == segment_value]

Expand All @@ -153,16 +177,14 @@ def stop_frequency(
state, model_name=segment_name, bundle_name="stop_frequency"
)

segment_spec = state.filesystem.read_model_spec(
file_name=segment_settings["SPEC"]
)
segment_spec = state.filesystem.read_model_spec(file_name=segment_settings.SPEC)
assert segment_spec is not None, (
"spec for segment_type %s not found" % segment_name
)

coefficients_file_name = segment_settings["COEFFICIENTS"]
coefficients_file_name = segment_settings.COEFFICIENTS
coefficients_df = state.filesystem.read_model_coefficients(
file_name=coefficients_file_name
file_name=str(coefficients_file_name)
)
segment_spec = simulate.eval_coefficients(
state, segment_spec, coefficients_df, estimator
Expand Down
71 changes: 37 additions & 34 deletions activitysim/abm/models/trip_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import warnings
from builtins import range
from typing import List, Literal

import numpy as np
import pandas as pd
Expand All @@ -15,7 +16,7 @@
)
from activitysim.abm.models.util.trip import cleanup_failed_trips, failed_trip_cohorts
from activitysim.core import chunk, config, estimation, expressions, tracing, workflow
from activitysim.core.configuration.base import PydanticReadable
from activitysim.core.configuration.base import PreprocessorSettings, PydanticReadable
from activitysim.core.util import reindex

logger = logging.getLogger(__name__)
Expand All @@ -41,18 +42,18 @@
DEPARTURE_MODE = "departure"
DURATION_MODE = "stop_duration"
RELATIVE_MODE = "relative"
PROBS_JOIN_COLUMNS_DEPARTURE_BASED = [
PROBS_JOIN_COLUMNS_DEPARTURE_BASED: list[str] = [
"primary_purpose",
"outbound",
"tour_hour",
"trip_num",
]
PROBS_JOIN_COLUMNS_DURATION_BASED = ["outbound", "stop_num"]
PROBS_JOIN_COLUMNS_RELATIVE_BASED = ["outbound", "periods_left"]
PROBS_JOIN_COLUMNS_DURATION_BASED: list[str] = ["outbound", "stop_num"]
PROBS_JOIN_COLUMNS_RELATIVE_BASED: list[str] = ["outbound", "periods_left"]


def _logic_version(model_settings):
logic_version = model_settings.get("logic_version", None)
def _logic_version(model_settings: TripSchedulingSettings):
logic_version = model_settings.logic_version
if logic_version is None:
warnings.warn(
"The trip_scheduling component now has a logic_version setting "
Expand Down Expand Up @@ -196,7 +197,7 @@ def schedule_trips_in_leg(
outbound,
trips,
probs_spec,
model_settings,
model_settings: TripSchedulingSettings,
is_last_iteration,
trace_label,
*,
Expand All @@ -220,29 +221,25 @@ def schedule_trips_in_leg(
depart choice for trips, indexed by trip_id
"""

failfix = model_settings.get(FAILFIX, FAILFIX_DEFAULT)
depart_alt_base = model_settings.get("DEPART_ALT_BASE", 0)
scheduling_mode = model_settings.get("scheduling_mode", "departure")
preprocessor_settings = model_settings.get("preprocessor", None)

if scheduling_mode == "departure":
probs_join_cols = model_settings.get(
"probs_join_cols", PROBS_JOIN_COLUMNS_DEPARTURE_BASED
)
elif scheduling_mode == "stop_duration":
probs_join_cols = model_settings.get(
"probs_join_cols", PROBS_JOIN_COLUMNS_DURATION_BASED
)
elif scheduling_mode == "relative":
probs_join_cols = model_settings.get(
"probs_join_cols", PROBS_JOIN_COLUMNS_RELATIVE_BASED
)
else:
logger.error(
"Invalid scheduling mode specified: {0}.".format(scheduling_mode),
"Please select one of ['departure', 'stop_duration', 'relative'] and try again.",
)
raise ValueError(f"Invalid scheduling mode specified: {scheduling_mode}")
failfix = model_settings.FAILFIX
depart_alt_base = model_settings.DEPART_ALT_BASE
scheduling_mode = model_settings.scheduling_mode
preprocessor_settings = model_settings.preprocessor

probs_join_cols = model_settings.probs_join_cols
if probs_join_cols is None:
if scheduling_mode == "departure":
probs_join_cols = PROBS_JOIN_COLUMNS_DEPARTURE_BASED
elif scheduling_mode == "stop_duration":
probs_join_cols = PROBS_JOIN_COLUMNS_DURATION_BASED
elif scheduling_mode == "relative":
probs_join_cols = PROBS_JOIN_COLUMNS_RELATIVE_BASED
else:
logger.error(
"Invalid scheduling mode specified: {0}.".format(scheduling_mode),
"Please select one of ['departure', 'stop_duration', 'relative'] and try again.",
)
raise ValueError(f"Invalid scheduling mode specified: {scheduling_mode}")

# logger.debug("%s scheduling %s trips" % (trace_label, trips.shape[0]))

Expand Down Expand Up @@ -451,6 +448,14 @@ class TripSchedulingSettings(PydanticReadable):
"""Integer to add to probs column index to get time period it represents.
e.g. depart_alt_base = 5 means first column (column 0) represents 5 am"""

scheduling_mode: Literal["departure", "stop_duration", "relative"] = "departure"

probs_join_cols: list[str] | None = None

preprocessor: PreprocessorSettings | None = None

logic_version: int | None = None


@workflow.step(copy_tables=False)
def trip_scheduling(
Expand Down Expand Up @@ -560,7 +565,7 @@ def trip_scheduling(
pd.Series(list(range(len(tours))), tours.index), trips_df.tour_id
)

assert "DEPART_ALT_BASE" in model_settings
assert model_settings.DEPART_ALT_BASE
failfix = model_settings.FAILFIX

max_iterations = model_settings.MAX_ITERATIONS
Expand Down Expand Up @@ -609,9 +614,7 @@ def trip_scheduling(
failed = choices.reindex(trips_chunk.index).isnull()
logger.info("%s %s failed", trace_label_i, failed.sum())

if (failed.sum() > 0) & (
model_settings.get("scheduling_mode") == "relative"
):
if (failed.sum() > 0) & (model_settings.scheduling_mode == "relative"):
raise RuntimeError("failed trips with relative scheduling mode")

if not is_last_iteration:
Expand Down
Loading

0 comments on commit 5e90e65

Please sign in to comment.