Skip to content

Commit

Permalink
created inverse_scaled_logistic_saturation and the corresponding class
Browse files Browse the repository at this point in the history
  • Loading branch information
Arthur authored and arthurmello committed Jul 11, 2024
1 parent 717702a commit 0baf5f4
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from pymc_marketing.mmm.components.saturation import (
HillSaturation,
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
SaturationTransformation,
Expand All @@ -45,6 +46,7 @@
"GeometricAdstock",
"HillSaturation",
"LogisticSaturation",
"InverseScaledLogisticSaturation",
"MMM",
"MMMModelBuilder",
"MichaelisMentenSaturation",
Expand Down
35 changes: 35 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def function(self, x, b):
from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
hill_saturation,
inverse_scaled_logistic_saturation,
logistic_saturation,
michaelis_menten,
tanh_saturation,
Expand Down Expand Up @@ -201,6 +202,39 @@ def function(self, x, lam, beta):
}


class InverseScaledLogisticSaturation(SaturationTransformation):
"""Wrapper around inverse scaled logistic saturation function.
For more information, see :func:`pymc_marketing.mmm.transformers.inverse_scaled_logistic_saturation`.
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import InverseScaledLogisticSaturation
rng = np.random.default_rng(0)
adstock = InverseScaledLogisticSaturation()
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()
"""

lookup_name = "inverse_scaled_logistic"

def function(self, x, lam, beta):
return beta * inverse_scaled_logistic_saturation(x, lam)

default_priors = {
"lam": Prior("Gamma", alpha=3, beta=1),
"beta": Prior("HalfNormal", sigma=2),
}


class TanhSaturation(SaturationTransformation):
"""Wrapper around tanh saturation function.
Expand Down Expand Up @@ -339,6 +373,7 @@ class HillSaturation(SaturationTransformation):
cls.lookup_name: cls
for cls in [
LogisticSaturation,
InverseScaledLogisticSaturation,
TanhSaturation,
TanhSaturationBaselined,
MichaelisMentenSaturation,
Expand Down
46 changes: 46 additions & 0 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,52 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5):
return (1 - pt.exp(-lam * x)) / (1 + pt.exp(-lam * x))


def inverse_scaled_logistic_saturation(
x, lam: npt.NDArray[np.float64] | float = 0.5, eps: float = np.log(3)
):
"""Inverse scaled logistic saturation transformation.
.. math::
f(x) = \\frac{1 - e^{-x*\epsilon/\lambda}}{1 + e^{-x*\epsilon/\lambda}}
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
import arviz as az
from pymc_marketing.mmm.transformers import inverse_scaled_logistic_saturation
plt.style.use('arviz-darkgrid')
lam = np.array([0.25, 0.5, 1, 2, 4])
x = np.linspace(0, 5, 100)
ax = plt.subplot(111)
for l in lam:
y = inverse_scaled_logistic_saturation(x, lam=l).eval()
plt.plot(x, y, label=f'lam = {l}')
plt.xlabel('spend', fontsize=12)
plt.ylabel('f(spend)', fontsize=12)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()
Parameters
----------
x : tensor
Input tensor.
lam : float or array-like, optional, by default 0.5
Saturation parameter.
eps : float or array-like, optional, by default ln(3)
Scaling parameter.
Returns
-------
tensor
Transformed tensor.
""" # noqa: W605
return logistic_saturation(x, eps / lam)


class TanhSaturationParameters(NamedTuple):
"""Container for tanh saturation parameters.
Expand Down
3 changes: 3 additions & 0 deletions tests/mmm/components/test_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from pymc_marketing.mmm.components.saturation import (
HillSaturation,
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
TanhSaturation,
Expand All @@ -40,6 +41,7 @@ def model() -> pm.Model:
def saturation_functions():
return [
LogisticSaturation(),
InverseScaledLogisticSaturation(),
TanhSaturation(),
TanhSaturationBaselined(),
MichaelisMentenSaturation(),
Expand Down Expand Up @@ -93,6 +95,7 @@ def test_support_for_lift_test_integrations(saturation) -> None:
@pytest.mark.parametrize(
"name, saturation_cls",
[
("inverse_scaled_logistic", InverseScaledLogisticSaturation),
("logistic", LogisticSaturation),
("tanh", TanhSaturation),
("tanh_baselined", TanhSaturationBaselined),
Expand Down

0 comments on commit 0baf5f4

Please sign in to comment.