Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove warnings during tests #823

Merged
merged 11 commits into from
Jul 11, 2024
18 changes: 12 additions & 6 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_config import ModelConfig, parse_model_config
from pymc_marketing.utils import from_netcdf


class CLVModel(ModelBuilder):
Expand Down Expand Up @@ -186,17 +187,22 @@ def load(cls, fname: str):
>>> imported_model = MyModel.load(name)
"""
filepath = Path(str(fname))
idata = az.from_netcdf(filepath)
idata = from_netcdf(filepath)
return cls._build_with_idata(idata)

@classmethod
def _build_with_idata(cls, idata: az.InferenceData):
dataset = idata.fit_data.to_dataframe()
model = cls(
dataset,
model_config=json.loads(idata.attrs["model_config"]), # type: ignore
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
)
model = cls(
dataset,
model_config=json.loads(idata.attrs["model_config"]), # type: ignore
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
model.idata = idata
model.build_model() # type: ignore
if model.id != idata.attrs["id"]:
Expand Down
4 changes: 2 additions & 2 deletions pymc_marketing/clv/models/pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _extract_predictive_variables(
purchase_coefficient = self.fit_result["purchase_coefficient"]
alpha = alpha_scale * np.exp(
-xarray.dot(
purchase_coefficient, purchase_xarray, dims="purchase_covariate"
purchase_coefficient, purchase_xarray, dim="purchase_covariate"
)
)
alpha.name = "alpha"
Expand All @@ -429,7 +429,7 @@ def _extract_predictive_variables(
dropout_coefficient = self.fit_result["dropout_coefficient"]
beta = beta_scale * np.exp(
-xarray.dot(
dropout_coefficient, dropout_xarray, dims="dropout_covariate"
dropout_coefficient, dropout_xarray, dim="dropout_covariate"
)
)
beta.name = "beta"
Expand Down
3 changes: 1 addition & 2 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def __init__(
),
model_config: dict | None = Field(None, description="Model configuration."),
sampler_config: dict | None = Field(None, description="Sampler configuration."),
**kwargs,
) -> None:
self.date_column: str = date_column
self.channel_columns: list[str] | tuple[str] = channel_columns
Expand Down Expand Up @@ -392,7 +391,7 @@ def plot_posterior_predictive(

if original_scale:
likelihood_hdi = self.get_target_transformer().inverse_transform(
Xt=likelihood_hdi
likelihood_hdi
)

ax.fill_between(
Expand Down
12 changes: 5 additions & 7 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,19 @@
"No budget bounds provided. Using default bounds (0, total_budget) for each channel.",
stacklevel=2,
)
else:
if not isinstance(budget_bounds, dict):
raise TypeError("`budget_bounds` should be a dictionary.")
elif not isinstance(budget_bounds, dict):
raise TypeError("`budget_bounds` should be a dictionary.")

Check warning on line 171 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L171

Added line #L171 was not covered by tests

if custom_constraints is None:
constraints = {"type": "eq", "fun": lambda x: np.sum(x) - total_budget}
warnings.warn(
"Using default equality constraint: The sum of all budgets should be equal to the total budget.",
stacklevel=2,
)
elif not isinstance(custom_constraints, dict):
raise TypeError("`custom_constraints` should be a dictionary.")

Check warning on line 180 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L180

Added line #L180 was not covered by tests
else:
if not isinstance(custom_constraints, dict):
raise TypeError("`custom_constraints` should be a dictionary.")
else:
constraints = custom_constraints
constraints = custom_constraints

num_channels = len(self.parameters.keys())
initial_guess = np.ones(num_channels) * total_budget / num_channels
Expand Down
97 changes: 63 additions & 34 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
from pymc_marketing.mmm.budget_optimizer import BudgetOptimizer
from pymc_marketing.mmm.components.adstock import (
AdstockTransformation,
GeometricAdstock,
_get_adstock_function,
)
from pymc_marketing.mmm.components.saturation import (
LogisticSaturation,
SaturationTransformation,
_get_saturation_function,
)
Expand All @@ -54,6 +56,7 @@
from pymc_marketing.mmm.validating import ValidateControlColumns
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import Prior
from pymc_marketing.utils import from_netcdf

__all__ = ["BaseMMM", "MMM", "DelayedSaturatedMMM"]

Expand All @@ -78,17 +81,20 @@ def __init__(
channel_columns: list[str] = Field(
min_length=1, description="Column names of the media channel variables."
),
adstock_max_lag: int = Field(
...,
gt=0,
description="Number of lags to consider in the adstock transformation.",
),
adstock: str | InstanceOf[AdstockTransformation] = Field(
..., description="Type of adstock transformation to apply."
),
saturation: str | InstanceOf[SaturationTransformation] = Field(
..., description="Type of saturation transformation to apply."
),
adstock_max_lag: int | None = Field(
None,
gt=0,
description=(
"Number of lags to consider in the adstock transformation. "
"Defaults to the max lag of the adstock transformation."
),
),
time_varying_intercept: bool = Field(
False, description="Whether to consider time-varying intercept."
),
Expand Down Expand Up @@ -118,7 +124,6 @@ def __init__(
adstock_first: bool = Field(
True, description="Whether to apply adstock first."
),
**kwargs,
) -> None:
"""Constructor method.

Expand All @@ -128,12 +133,13 @@ def __init__(
Column name of the date variable. Must be parsable using ~pandas.to_datetime.
channel_columns : List[str]
Column names of the media channel variables.
adstock_max_lag : int, optional
Number of lags to consider in the adstock transformation.
adstock : str | AdstockTransformation
Type of adstock transformation to apply.
saturation : str | SaturationTransformation
Type of saturation transformation to apply.
adstock_max_lag : int, optional
Number of lags to consider in the adstock transformation. Defaults to the
max lag of the adstock transformation.
time_varying_intercept : bool, optional
Whether to consider time-varying intercept, by default False.
Because the `time-varying` variable is centered around 1 and acts as a multiplier,
Expand All @@ -159,14 +165,24 @@ def __init__(
Whether to apply adstock first, by default True.
"""
self.control_columns = control_columns
self.adstock_max_lag = adstock_max_lag
self.time_varying_intercept = time_varying_intercept
self.time_varying_media = time_varying_media
self.date_column = date_column
self.validate_data = validate_data

self.adstock_first = adstock_first
self.adstock = _get_adstock_function(function=adstock, l_max=adstock_max_lag)

if adstock_max_lag is not None:
warnings.warn(
"The `adstock_max_lag` parameter is deprecated. Use `adstock` directly",
DeprecationWarning,
stacklevel=1,
)
adstock_kwargs = {"l_max": adstock_max_lag}
else:
adstock_kwargs = {}

self.adstock = _get_adstock_function(function=adstock, **adstock_kwargs)
self.saturation = _get_saturation_function(function=saturation)

model_config = model_config or {}
Expand All @@ -184,7 +200,6 @@ def __init__(
channel_columns=channel_columns,
model_config=model_config,
sampler_config=sampler_config,
adstock_max_lag=adstock_max_lag,
)

self.yearly_seasonality = yearly_seasonality
Expand Down Expand Up @@ -287,7 +302,7 @@ def _save_input_params(self, idata) -> None:
idata.attrs["adstock_first"] = json.dumps(self.adstock_first)
idata.attrs["control_columns"] = json.dumps(self.control_columns)
idata.attrs["channel_columns"] = json.dumps(self.channel_columns)
idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag)
idata.attrs["adstock_max_lag"] = json.dumps(self.adstock.l_max)
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
Expand Down Expand Up @@ -358,7 +373,11 @@ def build_model(

.. code-block:: python

from pymc_marketing.mmm import MMM
from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation
MMM,
)
from pymc_marketing.prior import Prior

custom_config = {
Expand All @@ -374,12 +393,13 @@ def build_model(
model = MMM(
date_column="date_week",
channel_columns=["x1", "x2"],
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
model_config=custom_config,
)
Expand Down Expand Up @@ -634,7 +654,7 @@ def load(cls, fname: str):
"""

filepath = Path(fname)
idata = az.from_netcdf(filepath)
idata = from_netcdf(filepath)
model_config = cls._model_config_formatting(
json.loads(idata.attrs["model_config"])
)
Expand Down Expand Up @@ -844,22 +864,25 @@ class MMM(
import numpy as np
import pandas as pd

from pymc_marketing.mmm import MMM
from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation
MMM,
)

data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
data = pd.read_csv(data_url, parse_dates=["date_week"])

mmm = MMM(
date_column="date_week",
adstock="geometric",
saturation="logistic",
channel_columns=["x1", "x2"],
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
)

Expand All @@ -880,27 +903,30 @@ class MMM(

import numpy as np

from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation
MMM,
)
from pymc_marketing.prior import Prior
from pymc_marketing.mmm import MMM

my_model_config = {
"beta_channel": Prior("LogNormal", mu=np.array([2, 1]), sigma=1),
"likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=2)),
}

mmm = MMM(
adstock="geometric",
saturation="logistic",
model_config=my_model_config,
date_column="date_week",
channel_columns=["x1", "x2"],
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
model_config=my_model_config,
)

As you can see, we can configure all prior and likelihood distributions via the `model_config`.
Expand Down Expand Up @@ -1765,18 +1791,21 @@ def add_lift_test_measurements(
import pandas as pd
import numpy as np

from pymc_marketing.mmm import MMM
from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation,
MMM,
)

model = MMM(
adstock="geometric",
saturation="logistic",
date_column="date_week",
channel_columns=["x1", "x2"],
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
control_columns=[
"event_1",
"event_2",
],
adstock_max_lag=8,
yearly_seasonality=2,
)

Expand Down Expand Up @@ -2206,7 +2235,6 @@ def __init__(
control_columns: list[str] | None = None,
yearly_seasonality: int | None = None,
adstock_first: bool = True,
**kwargs,
) -> None:
"""
Wrapper function for DelayedSaturatedMMM class initializer.
Expand All @@ -2217,22 +2245,23 @@ def __init__(
warnings.warn(
"The DelayedSaturatedMMM class is deprecated. Please use the MMM class instead.",
DeprecationWarning,
stacklevel=2,
stacklevel=1,
)

adstock = GeometricAdstock(l_max=adstock_max_lag)
saturation = LogisticSaturation()

super().__init__(
date_column=date_column,
channel_columns=channel_columns,
adstock_max_lag=adstock_max_lag,
time_varying_intercept=time_varying_intercept,
time_varying_media=time_varying_media,
model_config=model_config,
sampler_config=sampler_config,
validate_data=validate_data,
control_columns=control_columns,
yearly_seasonality=yearly_seasonality,
adstock="geometric",
saturation="logistic",
adstock=adstock,
saturation=saturation,
adstock_first=adstock_first,
**kwargs,
)
Loading
Loading