Skip to content

Commit

Permalink
Add remove_alts option for model groups
Browse files Browse the repository at this point in the history
This controls whether alternatives are removed from the alternatives
pool between doing prediction for different segments.
When doing LCMs (e.g. probability_mode='single_chooser'
and choice_mode='aggregate') this should be set to True.
For doing something like automobile ownership this should be False.
False is the default.
  • Loading branch information
jiffyclub committed Feb 24, 2015
1 parent e53a066 commit 282a59a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
31 changes: 27 additions & 4 deletions urbansim/models/dcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,13 +776,22 @@ class MNLDiscreteChoiceModelGroup(DiscreteChoiceModel):
Parameters
----------
segmentation_col
segmentation_col : str
Name of a column in the table of choosers. Will be used to perform
a pandas groupby on the choosers table.
remove_alts : bool, optional
Specify how to handle alternatives between prediction for different
models. If False, the alternatives table is not modified between
predictions. If True, alternatives that have been chosen
are removed from the alternatives table before doing another
round of prediction.
name : str, optional
A name that may be used in places to identify this group.
"""
def __init__(self, segmentation_col, name=None):
def __init__(self, segmentation_col, remove_alts=False, name=None):
self.segmentation_col = segmentation_col
self.remove_alts = remove_alts
self.name = name if name is not None else 'MNLDiscreteChoiceModelGroup'
self.models = {}

Expand Down Expand Up @@ -1088,6 +1097,9 @@ def predict(self, choosers, alternatives, debug=False):

for name, df in self._iter_groups(choosers):
choices = self.models[name].predict(df, alternatives, debug=debug)
if self.remove_alts:
alternatives = alternatives.loc[
~alternatives.index.isin(choices)]
results.append(choices)

logger.debug(
Expand Down Expand Up @@ -1177,6 +1189,12 @@ class SegmentedMNLDiscreteChoiceModel(DiscreteChoiceModel):
the alternatives index is used.
default_model_expr : str, iterable, or dict, optional
A patsy model expression. Should contain only a right-hand side.
remove_alts : bool, optional
Specify how to handle alternatives between prediction for different
models. If False, the alternatives table is not modified between
predictions. If True, alternatives that have been chosen
are removed from the alternatives table before doing another
round of prediction.
name : str, optional
An optional string used to identify the model in places.
Expand All @@ -1188,7 +1206,8 @@ def __init__(
alts_fit_filters=None, alts_predict_filters=None,
interaction_predict_filters=None,
estimation_sample_size=None,
choice_column=None, default_model_expr=None, name=None):
choice_column=None, default_model_expr=None, remove_alts=False,
name=None):
self.segmentation_col = segmentation_col
self.sample_size = sample_size
self.probability_mode = probability_mode
Expand All @@ -1201,7 +1220,9 @@ def __init__(
self.estimation_sample_size = estimation_sample_size
self.choice_column = choice_column
self.default_model_expr = default_model_expr
self._group = MNLDiscreteChoiceModelGroup(segmentation_col)
self.remove_alts = remove_alts
self._group = MNLDiscreteChoiceModelGroup(
segmentation_col, remove_alts=remove_alts)
self.name = (name if name is not None else
'SegmentedMNLDiscreteChoiceModel')

Expand Down Expand Up @@ -1240,6 +1261,7 @@ def from_yaml(cls, yaml_str=None, str_or_buffer=None):
cfg['estimation_sample_size'],
cfg['choice_column'],
default_model_expr,
cfg['remove_alts'],
cfg['name'])

if "models" not in cfg:
Expand Down Expand Up @@ -1565,6 +1587,7 @@ def to_dict(self):
'default_config': {
'model_expression': self.default_model_expr,
},
'remove_alts': self.remove_alts,
'fitted': self.fitted,
'models': {
yamlio.to_scalar_safe(name):
Expand Down
37 changes: 31 additions & 6 deletions urbansim/models/tests/test_dcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_mnl_dcm_choice_mode_agg(seed, basic_dcm_fit, choosers, alternatives):
pd.Series(['f', 'a', 'd', 'c'], index=[0, 1, 3, 4]))


def test_mnl_dcm_group(grouped_choosers, alternatives):
def test_mnl_dcm_group(seed, grouped_choosers, alternatives):
model_exp = 'var2 + var1:var3'
sample_size = 4
choosers_predict_filters = ['var1 != 7']
Expand Down Expand Up @@ -296,10 +296,22 @@ def test_mnl_dcm_group(grouped_choosers, alternatives):
sprobs = group.summed_probabilities(grouped_choosers, alternatives)
assert len(sprobs) == len(filtered_alts)

choice_state = np.random.get_state()
choices = group.predict(grouped_choosers, alternatives)

assert len(choices) == len(filtered_choosers)
assert choices.isin(alternatives.index).all()
pdt.assert_series_equal(
choices,
pd.Series(['c', 'a', 'a', 'g'], index=[0, 3, 1, 4]))

# check that we don't get the same alt twice if they are removed
# make sure we're starting from the same random state as the last draw
np.random.set_state(choice_state)
group.remove_alts = True
choices = group.predict(grouped_choosers, alternatives)

pdt.assert_series_equal(
choices,
pd.Series(['c', 'a', 'b', 'g'], index=[0, 3, 1, 4]))


def test_mnl_dcm_segmented_raises():
Expand All @@ -309,7 +321,7 @@ def test_mnl_dcm_segmented_raises():
group.add_segment('x')


def test_mnl_dcm_segmented(grouped_choosers, alternatives):
def test_mnl_dcm_segmented(seed, grouped_choosers, alternatives):
model_exp = 'var2 + var1:var3'
sample_size = 4

Expand Down Expand Up @@ -338,10 +350,22 @@ def test_mnl_dcm_segmented(grouped_choosers, alternatives):
sprobs = group.summed_probabilities(grouped_choosers, alternatives)
assert len(sprobs) == len(alternatives)

choice_state = np.random.get_state()
choices = group.predict(grouped_choosers, alternatives)

assert len(choices) == len(grouped_choosers)
assert choices.isin(alternatives.index).all()
pdt.assert_series_equal(
choices,
pd.Series(['c', 'a', 'b', 'a', 'j'], index=[0, 2, 3, 1, 4]))

# check that we don't get the same alt twice if they are removed
# make sure we're starting from the same random state as the last draw
np.random.set_state(choice_state)
group._group.remove_alts = True
choices = group.predict(grouped_choosers, alternatives)

pdt.assert_series_equal(
choices,
pd.Series(['c', 'a', 'b', 'd', 'j'], index=[0, 2, 3, 1, 4]))


def test_mnl_dcm_segmented_yaml(grouped_choosers, alternatives):
Expand Down Expand Up @@ -370,6 +394,7 @@ def test_mnl_dcm_segmented_yaml(grouped_choosers, alternatives):
'default_config': {
'model_expression': model_exp,
},
'remove_alts': False,
'fitted': False,
'models': {
'x': {
Expand Down

0 comments on commit 282a59a

Please sign in to comment.