From dc117beecf5a1518c0d5b1cd0fb9e9904c6d9ca9 Mon Sep 17 00:00:00 2001 From: Colt Allen <10178857+ColtAllen@users.noreply.github.com> Date: Mon, 10 Jun 2024 09:54:01 -0600 Subject: [PATCH] CLV Plotting API (#728) * plot_probability_alive_matrix * docstrings * plot_frequency_recency_matrix * delete dead code * docstring quick fix --- pymc_marketing/clv/plotting.py | 98 +++++++++++++--------------------- tests/clv/test_plotting.py | 22 ++------ 2 files changed, 42 insertions(+), 78 deletions(-) diff --git a/pymc_marketing/clv/plotting.py b/pymc_marketing/clv/plotting.py index bc63ee322..bb54b8982 100644 --- a/pymc_marketing/clv/plotting.py +++ b/pymc_marketing/clv/plotting.py @@ -172,7 +172,7 @@ def _create_frequency_recency_meshes( def plot_frequency_recency_matrix( model: BetaGeoModel | ParetoNBDModel, - t=1, + future_t: int = 1, max_frequency: int | None = None, max_recency: int | None = None, title: str | None = None, @@ -182,20 +182,19 @@ def plot_frequency_recency_matrix( **kwargs, ) -> plt.Axes: """ - Plot recency frequency matrix as heatmap. - Plot a figure of expected transactions in T next units of time by a customer's frequency and recency. + Plot expected transactions in *future_t* time periods as a heatmap + based on customer population *frequency* and *recency*. Parameters ---------- model: CLV model A fitted CLV model. - t: float, optional - Next units of time to make predictions for + future_t: float, optional + Future time periods over which to run predictions. max_frequency: int, optional - The maximum frequency to plot. Default is max observed frequency. + The maximum *frequency* to plot. Defaults to max observed *frequency*. max_recency: int, optional - The maximum recency to plot. This also determines the age of the customer. - Default to max observed age. + The maximum *recency* to plot. This also determines the age of the customer. Defaults to max observed *recency*. title: str, optional Figure title xlabel: str, optional @@ -222,38 +221,25 @@ def plot_frequency_recency_matrix( max_recency=max_recency, ) - # FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel - # We should harmonize them! - if isinstance(model, ParetoNBDModel): - transaction_data = pd.DataFrame( - { - "customer_id": np.arange(mesh_recency.size), # placeholder - "frequency": mesh_frequency.ravel(), - "recency": mesh_recency.ravel(), - "T": max_recency, - } - ) + # create dataframe for model input + transaction_data = pd.DataFrame( + { + "customer_id": np.arange(mesh_recency.size), # placeholder + "frequency": mesh_frequency.ravel(), + "recency": mesh_recency.ravel(), + "T": max_recency, + } + ) - Z = ( - model.expected_purchases( - data=transaction_data, - future_t=t, - ) - .mean(("draw", "chain")) - .values.reshape(mesh_recency.shape) - ) - else: - Z = ( - model.expected_num_purchases( - customer_id=np.arange(mesh_recency.size), # placeholder - frequency=mesh_frequency.ravel(), - recency=mesh_recency.ravel(), - T=max_recency, - t=t, - ) - .mean(("draw", "chain")) - .values.reshape(mesh_recency.shape) + # run model predictions to create heatmap values + Z = ( + model.expected_purchases( + data=transaction_data, + future_t=future_t, ) + .mean(("draw", "chain")) + .values.reshape(mesh_recency.shape) + ) if ax is None: ax = plt.subplot(111) @@ -262,7 +248,7 @@ def plot_frequency_recency_matrix( if title is None: title = ( "Expected Number of Future Purchases for {} Unit{} of Time,".format( - t, "s"[t == 1 :] + future_t, "s"[future_t == 1 :] ) + "\nby Frequency and Recency of a Customer" ) @@ -292,19 +278,16 @@ def plot_probability_alive_matrix( **kwargs, ) -> plt.Axes: """ - Plot probability alive matrix as heatmap. - Plot a figure of the probability a customer is alive based on their - frequency and recency. + Plot probability alive matrix as a heatmap based on customer population *frequency* and *recency*. Parameters ---------- model: CLV model A fitted CLV model. max_frequency: int, optional - The maximum frequency to plot. Default is max observed frequency. + The maximum *frequency* to plot. Defaults to max observed *frequency*. max_recency: int, optional - The maximum recency to plot. This also determines the age of the customer. - Default to max observed age. + The maximum *recency* to plot. This also determines the age of the customer. Defaults to max observed *recency*. title: str, optional Figure title xlabel: str, optional @@ -331,6 +314,8 @@ def plot_probability_alive_matrix( max_frequency=max_frequency, max_recency=max_recency, ) + + # create dataframe for model input transaction_data = pd.DataFrame( { "customer_id": np.arange(mesh_recency.size), # placeholder @@ -339,22 +324,13 @@ def plot_probability_alive_matrix( "T": max_recency, } ) - # FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel - if isinstance(model, ParetoNBDModel): - Z = ( - model.expected_probability_alive( - data=transaction_data, - future_t=0, # TODO: This is a required parameter if data is provided. - ) - .mean(("draw", "chain")) - .values.reshape(mesh_recency.shape) - ) - else: - Z = ( - model.expected_probability_alive(data=transaction_data) - .mean(("draw", "chain")) - .values.reshape(mesh_recency.shape) - ) + + # run model predictions to create heatmap values + Z = ( + model.expected_probability_alive(data=transaction_data) + .mean(("draw", "chain")) + .values.reshape(mesh_recency.shape) + ) interpolation = kwargs.pop("interpolation", "none") diff --git a/tests/clv/test_plotting.py b/tests/clv/test_plotting.py index 7bc3ed297..52aeec7cd 100644 --- a/tests/clv/test_plotting.py +++ b/tests/clv/test_plotting.py @@ -29,40 +29,28 @@ class MockModel: def __init__(self, data: pd.DataFrame): self.data = data - def _mock_posterior(self, customer_id: np.ndarray | pd.Series) -> xr.DataArray: - n_customers = len(customer_id) + def _mock_posterior(self, data: pd.DataFrame) -> xr.DataArray: + n_customers = len(data) n_chains = 4 n_draws = 10 chains = np.arange(n_chains) draws = np.arange(n_draws) return xr.DataArray( data=np.ones((n_customers, n_chains, n_draws)), - coords={"customer_id": customer_id, "chain": chains, "draw": draws}, + coords={"customer_id": data["customer_id"], "chain": chains, "draw": draws}, dims=["customer_id", "chain", "draw"], ) def expected_probability_alive(self, data: np.ndarray | pd.Series): - return self._mock_posterior(data["customer_id"]) + return self._mock_posterior(data) def expected_purchases( self, - customer_id: np.ndarray | pd.Series, data: pd.DataFrame, *, future_t: np.ndarray | pd.Series | TensorVariable, ): - return self._mock_posterior(customer_id) - - # TODO: This is required until CLV API is standardized. - def expected_num_purchases( - self, - customer_id: np.ndarray | pd.Series, - t: np.ndarray | pd.Series | TensorVariable, - frequency: np.ndarray | pd.Series | TensorVariable, - recency: np.ndarray | pd.Series | TensorVariable, - T: np.ndarray | pd.Series | TensorVariable, - ): - return self._mock_posterior(customer_id) + return self._mock_posterior(data) @pytest.fixture