Skip to content

Commit

Permalink
MMM Component Notebook (#748)
Browse files Browse the repository at this point in the history
* initial notebook

* add to example

* push up some feedback

* change the number of channels

* updates

* final message

* add more feedback

* remove since wasnt working the way I wanted
  • Loading branch information
wd60622 authored Jun 18, 2024
1 parent 0144bd3 commit c994210
Show file tree
Hide file tree
Showing 4 changed files with 2,756 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/notebooks/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mmm/mmm_example
mmm/mmm_budget_allocation_example
mmm/mmm_lift_test
mmm/mmm_tvp_example
mmm/mmm_components
:::

:::{toctree}
Expand Down
2,725 changes: 2,725 additions & 0 deletions docs/source/notebooks/mmm/mmm_components.ipynb

Large diffs are not rendered by default.

16 changes: 15 additions & 1 deletion tests/mmm/components/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def test_selections(coords, expected) -> None:
assert list(selections(coords)) == expected


def test_change_instance_function_priors_has_no_impact(
def test_change_instance_function_priors_has_no_impact_new_instance(
new_transformation_class,
) -> None:
"""What happens in the MMM logic."""
Expand All @@ -297,3 +297,17 @@ def test_change_instance_function_priors_has_no_impact(
"a": {"dist": "HalfNormal", "kwargs": {"sigma": 1}},
"b": {"dist": "HalfNormal", "kwargs": {"sigma": 1}},
}


def test_change_instance_function_priors_has_no_impact_on_class(
new_transformation_class,
) -> None:
instance = new_transformation_class()

for _, config in instance.function_priors.items():
config["dims"] = "channel"

assert new_transformation_class.default_priors == {
"a": {"dist": "HalfNormal", "kwargs": {"sigma": 1}},
"b": {"dist": "HalfNormal", "kwargs": {"sigma": 1}},
}
15 changes: 15 additions & 0 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,3 +1050,18 @@ def test_initialize_alternative_with_classes() -> None:
assert isinstance(mmm.adstock, DelayedAdstock)
assert mmm.adstock.l_max == 10
assert isinstance(mmm.saturation, MichaelisMentenSaturation)


def test_initialize_defaults_channel_media_dims() -> None:
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock=DelayedAdstock(l_max=10),
saturation=MichaelisMentenSaturation(),
)

for transform in [mmm.adstock, mmm.saturation]:
for config in transform.function_priors.values():
assert config["dims"] == "channel"

0 comments on commit c994210

Please sign in to comment.