From 0baf5f48b9baa6761c3fef36d22a298c5f11ecdd Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 10 Jul 2024 20:09:54 +0200 Subject: [PATCH] created inverse_scaled_logistic_saturation and the corresponding class --- pymc_marketing/mmm/__init__.py | 2 + pymc_marketing/mmm/components/saturation.py | 35 ++++++++++++++++ pymc_marketing/mmm/transformers.py | 46 +++++++++++++++++++++ tests/mmm/components/test_saturation.py | 3 ++ 4 files changed, 86 insertions(+) diff --git a/pymc_marketing/mmm/__init__.py b/pymc_marketing/mmm/__init__.py index 2d75a1b07..dafc319f2 100644 --- a/pymc_marketing/mmm/__init__.py +++ b/pymc_marketing/mmm/__init__.py @@ -23,6 +23,7 @@ ) from pymc_marketing.mmm.components.saturation import ( HillSaturation, + InverseScaledLogisticSaturation, LogisticSaturation, MichaelisMentenSaturation, SaturationTransformation, @@ -45,6 +46,7 @@ "GeometricAdstock", "HillSaturation", "LogisticSaturation", + "InverseScaledLogisticSaturation", "MMM", "MMMModelBuilder", "MichaelisMentenSaturation", diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index d93f8ca60..0d2f31063 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -76,6 +76,7 @@ def function(self, x, b): from pymc_marketing.mmm.components.base import Transformation from pymc_marketing.mmm.transformers import ( hill_saturation, + inverse_scaled_logistic_saturation, logistic_saturation, michaelis_menten, tanh_saturation, @@ -201,6 +202,39 @@ def function(self, x, lam, beta): } +class InverseScaledLogisticSaturation(SaturationTransformation): + """Wrapper around inverse scaled logistic saturation function. + + For more information, see :func:`pymc_marketing.mmm.transformers.inverse_scaled_logistic_saturation`. + + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import InverseScaledLogisticSaturation + + rng = np.random.default_rng(0) + + adstock = InverseScaledLogisticSaturation() + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + + """ + + lookup_name = "inverse_scaled_logistic" + + def function(self, x, lam, beta): + return beta * inverse_scaled_logistic_saturation(x, lam) + + default_priors = { + "lam": Prior("Gamma", alpha=3, beta=1), + "beta": Prior("HalfNormal", sigma=2), + } + + class TanhSaturation(SaturationTransformation): """Wrapper around tanh saturation function. @@ -339,6 +373,7 @@ class HillSaturation(SaturationTransformation): cls.lookup_name: cls for cls in [ LogisticSaturation, + InverseScaledLogisticSaturation, TanhSaturation, TanhSaturationBaselined, MichaelisMentenSaturation, diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py index 5c036445c..bdf80c7f4 100644 --- a/pymc_marketing/mmm/transformers.py +++ b/pymc_marketing/mmm/transformers.py @@ -478,6 +478,52 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5): return (1 - pt.exp(-lam * x)) / (1 + pt.exp(-lam * x)) +def inverse_scaled_logistic_saturation( + x, lam: npt.NDArray[np.float64] | float = 0.5, eps: float = np.log(3) +): + """Inverse scaled logistic saturation transformation. + + .. math:: + f(x) = \\frac{1 - e^{-x*\epsilon/\lambda}}{1 + e^{-x*\epsilon/\lambda}} + + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + import arviz as az + from pymc_marketing.mmm.transformers import inverse_scaled_logistic_saturation + plt.style.use('arviz-darkgrid') + lam = np.array([0.25, 0.5, 1, 2, 4]) + x = np.linspace(0, 5, 100) + ax = plt.subplot(111) + for l in lam: + y = inverse_scaled_logistic_saturation(x, lam=l).eval() + plt.plot(x, y, label=f'lam = {l}') + plt.xlabel('spend', fontsize=12) + plt.ylabel('f(spend)', fontsize=12) + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) + ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) + plt.show() + + Parameters + ---------- + x : tensor + Input tensor. + lam : float or array-like, optional, by default 0.5 + Saturation parameter. + eps : float or array-like, optional, by default ln(3) + Scaling parameter. + + Returns + ------- + tensor + Transformed tensor. + """ # noqa: W605 + return logistic_saturation(x, eps / lam) + + class TanhSaturationParameters(NamedTuple): """Container for tanh saturation parameters. diff --git a/tests/mmm/components/test_saturation.py b/tests/mmm/components/test_saturation.py index fc78b3628..cea753912 100644 --- a/tests/mmm/components/test_saturation.py +++ b/tests/mmm/components/test_saturation.py @@ -22,6 +22,7 @@ from pymc_marketing.mmm.components.saturation import ( HillSaturation, + InverseScaledLogisticSaturation, LogisticSaturation, MichaelisMentenSaturation, TanhSaturation, @@ -40,6 +41,7 @@ def model() -> pm.Model: def saturation_functions(): return [ LogisticSaturation(), + InverseScaledLogisticSaturation(), TanhSaturation(), TanhSaturationBaselined(), MichaelisMentenSaturation(), @@ -93,6 +95,7 @@ def test_support_for_lift_test_integrations(saturation) -> None: @pytest.mark.parametrize( "name, saturation_cls", [ + ("inverse_scaled_logistic", InverseScaledLogisticSaturation), ("logistic", LogisticSaturation), ("tanh", TanhSaturation), ("tanh_baselined", TanhSaturationBaselined),