Skip to content

Commit

Permalink
CLV Plotting API (#728)
Browse files Browse the repository at this point in the history
* plot_probability_alive_matrix

* docstrings

* plot_frequency_recency_matrix

* delete dead code

* docstring quick fix
  • Loading branch information
ColtAllen authored Jun 10, 2024
1 parent f3be754 commit dc117be
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 78 deletions.
98 changes: 37 additions & 61 deletions pymc_marketing/clv/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _create_frequency_recency_meshes(

def plot_frequency_recency_matrix(
model: BetaGeoModel | ParetoNBDModel,
t=1,
future_t: int = 1,
max_frequency: int | None = None,
max_recency: int | None = None,
title: str | None = None,
Expand All @@ -182,20 +182,19 @@ def plot_frequency_recency_matrix(
**kwargs,
) -> plt.Axes:
"""
Plot recency frequency matrix as heatmap.
Plot a figure of expected transactions in T next units of time by a customer's frequency and recency.
Plot expected transactions in *future_t* time periods as a heatmap
based on customer population *frequency* and *recency*.
Parameters
----------
model: CLV model
A fitted CLV model.
t: float, optional
Next units of time to make predictions for
future_t: int, optional
Future time periods over which to run predictions.
max_frequency: int, optional
The maximum frequency to plot. Default is max observed frequency.
The maximum *frequency* to plot. Defaults to max observed *frequency*.
max_recency: int, optional
The maximum recency to plot. This also determines the age of the customer.
Default to max observed age.
The maximum *recency* to plot. This also determines the age of the customer. Defaults to max observed *recency*.
title: str, optional
Figure title
xlabel: str, optional
Expand All @@ -222,38 +221,25 @@ def plot_frequency_recency_matrix(
max_recency=max_recency,
)

# FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
# We should harmonize them!
if isinstance(model, ParetoNBDModel):
transaction_data = pd.DataFrame(
{
"customer_id": np.arange(mesh_recency.size), # placeholder
"frequency": mesh_frequency.ravel(),
"recency": mesh_recency.ravel(),
"T": max_recency,
}
)
# create dataframe for model input
transaction_data = pd.DataFrame(
{
"customer_id": np.arange(mesh_recency.size), # placeholder
"frequency": mesh_frequency.ravel(),
"recency": mesh_recency.ravel(),
"T": max_recency,
}
)

Z = (
model.expected_purchases(
data=transaction_data,
future_t=t,
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)
else:
Z = (
model.expected_num_purchases(
customer_id=np.arange(mesh_recency.size), # placeholder
frequency=mesh_frequency.ravel(),
recency=mesh_recency.ravel(),
T=max_recency,
t=t,
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
# run model predictions to create heatmap values
Z = (
model.expected_purchases(
data=transaction_data,
future_t=future_t,
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)

if ax is None:
ax = plt.subplot(111)
Expand All @@ -262,7 +248,7 @@ def plot_frequency_recency_matrix(
if title is None:
title = (
"Expected Number of Future Purchases for {} Unit{} of Time,".format(
t, "s"[t == 1 :]
future_t, "s"[future_t == 1 :]
)
+ "\nby Frequency and Recency of a Customer"
)
Expand Down Expand Up @@ -292,19 +278,16 @@ def plot_probability_alive_matrix(
**kwargs,
) -> plt.Axes:
"""
Plot probability alive matrix as heatmap.
Plot a figure of the probability a customer is alive based on their
frequency and recency.
Plot probability alive matrix as a heatmap based on customer population *frequency* and *recency*.
Parameters
----------
model: CLV model
A fitted CLV model.
max_frequency: int, optional
The maximum frequency to plot. Default is max observed frequency.
The maximum *frequency* to plot. Defaults to max observed *frequency*.
max_recency: int, optional
The maximum recency to plot. This also determines the age of the customer.
Default to max observed age.
The maximum *recency* to plot. This also determines the age of the customer. Defaults to max observed *recency*.
title: str, optional
Figure title
xlabel: str, optional
Expand All @@ -331,6 +314,8 @@ def plot_probability_alive_matrix(
max_frequency=max_frequency,
max_recency=max_recency,
)

# create dataframe for model input
transaction_data = pd.DataFrame(
{
"customer_id": np.arange(mesh_recency.size), # placeholder
Expand All @@ -339,22 +324,13 @@ def plot_probability_alive_matrix(
"T": max_recency,
}
)
# FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
if isinstance(model, ParetoNBDModel):
Z = (
model.expected_probability_alive(
data=transaction_data,
future_t=0, # TODO: This is a required parameter if data is provided.
)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)
else:
Z = (
model.expected_probability_alive(data=transaction_data)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)

# run model predictions to create heatmap values
Z = (
model.expected_probability_alive(data=transaction_data)
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)

interpolation = kwargs.pop("interpolation", "none")

Expand Down
22 changes: 5 additions & 17 deletions tests/clv/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,28 @@ class MockModel:
def __init__(self, data: pd.DataFrame):
self.data = data

def _mock_posterior(self, customer_id: np.ndarray | pd.Series) -> xr.DataArray:
n_customers = len(customer_id)
def _mock_posterior(self, data: pd.DataFrame) -> xr.DataArray:
n_customers = len(data)
n_chains = 4
n_draws = 10
chains = np.arange(n_chains)
draws = np.arange(n_draws)
return xr.DataArray(
data=np.ones((n_customers, n_chains, n_draws)),
coords={"customer_id": customer_id, "chain": chains, "draw": draws},
coords={"customer_id": data["customer_id"], "chain": chains, "draw": draws},
dims=["customer_id", "chain", "draw"],
)

def expected_probability_alive(self, data: np.ndarray | pd.Series):
return self._mock_posterior(data["customer_id"])
return self._mock_posterior(data)

def expected_purchases(
self,
customer_id: np.ndarray | pd.Series,
data: pd.DataFrame,
*,
future_t: np.ndarray | pd.Series | TensorVariable,
):
return self._mock_posterior(customer_id)

# TODO: This is required until CLV API is standardized.
def expected_num_purchases(
self,
customer_id: np.ndarray | pd.Series,
t: np.ndarray | pd.Series | TensorVariable,
frequency: np.ndarray | pd.Series | TensorVariable,
recency: np.ndarray | pd.Series | TensorVariable,
T: np.ndarray | pd.Series | TensorVariable,
):
return self._mock_posterior(customer_id)
return self._mock_posterior(data)


@pytest.fixture
Expand Down

0 comments on commit dc117be

Please sign in to comment.