diff --git a/docs/source/notebooks/clv/bg_nbd.ipynb b/docs/source/notebooks/clv/bg_nbd.ipynb index 949071f66..5456b3453 100644 --- a/docs/source/notebooks/clv/bg_nbd.ipynb +++ b/docs/source/notebooks/clv/bg_nbd.ipynb @@ -5,56 +5,59 @@ "id": "51e3591e", "metadata": {}, "source": [ - "# BG/NBD model\n", + "# BG/NBD Model\n", "\n", - "Comparison with lifetimes" + "In this notebook we show how to fit a BG/NBD model in PyMC-Marketing. We compare the results with the [`lifetimes`](https://github.com/CamDavidsonPilon/lifetimes) package (no longer maintained). The model is presented in the paper: Fader, P. S., Hardie, B. G., & Lee, K. L. (2005). [“Counting your customers” the easy way: An alternative to the Pareto/NBD model. Marketing science, 24(2), 275-284.](http://www.brucehardie.com/papers/bgnbd_2004-04-20.pdf)" ] }, { "cell_type": "markdown", - "id": "80697def", + "id": "68f7ba7e", "metadata": {}, "source": [ - "**Reference**: Fader, P. S., Hardie, B. G., & Lee, K. L. (2005). “Counting your customers” the easy way: An alternative to the Pareto/NBD model. Marketing science, 24(2), 275-284.\n", - "\n", - "http://www.brucehardie.com/papers/bgnbd_2004-04-20.pdf" + "## Prepare Notebook" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "81c950fb", "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", + "import arviz as az\n", + "import matplotlib.pyplot as plt\n", "import pandas as pd\n", - "from lifetimes import BetaGeoFitter" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "8ee1b882", - "metadata": {}, - "outputs": [], - "source": [ - "from lifetimes.datasets import load_cdnow_summary" + "import xarray as xr\n", + "from fastprogress.fastprogress import progress_bar\n", + "from lifetimes import BetaGeoFitter\n", + "\n", + "from pymc_marketing import clv\n", + "\n", + "# Plotting configuration\n", + "az.style.use(\"arviz-darkgrid\")\n", + "plt.rcParams[\"figure.figsize\"] = [12, 7]\n", + "plt.rcParams[\"figure.dpi\"] = 100\n", + "plt.rcParams[\"figure.facecolor\"] = \"white\"\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%config InlineBackend.figure_format = \"retina\"" ] }, { - "cell_type": "code", - "execution_count": 3, - "id": "4a2dfdcd", + "cell_type": "markdown", + "id": "6e4b3b25", "metadata": {}, - "outputs": [], "source": [ - "from pymc_marketing import clv" + "## Read Data\n", + "\n", + "We use the `CDNOW` dataset (see lifetimes [quick-start](https://lifetimes.readthedocs.io/en/latest/Quickstart.html))." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "a99638b5", "metadata": {}, "outputs": [ @@ -82,153 +85,232 @@ "
<xarray.Dataset>\n", + "<xarray.Dataset> Size: 136kB\n", "Dimensions: (chain: 4, draw: 1000)\n", "Coordinates:\n", - " * chain (chain) int64 0 1 2 3\n", - " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999\n", + " * chain (chain) int64 32B 0 1 2 3\n", + " * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n", "Data variables:\n", - " a (chain, draw) float64 0.822 0.6749 0.8323 ... 1.277 1.33 1.371\n", - " b (chain, draw) float64 2.499 2.024 2.501 2.605 ... 4.321 5.055 4.828\n", - " alpha (chain, draw) float64 4.078 4.307 4.878 4.738 ... 5.253 5.385 4.073\n", - " r (chain, draw) float64 0.2342 0.2255 0.247 ... 0.277 0.2572 0.221\n", + " a (chain, draw) float64 32kB 0.8328 0.9293 0.9555 ... 1.042 1.112\n", + " b (chain, draw) float64 32kB 2.824 2.756 2.953 ... 3.807 4.156 3.535\n", + " alpha (chain, draw) float64 32kB 4.397 4.364 4.305 ... 4.039 4.173 4.47\n", + " r (chain, draw) float64 32kB 0.2511 0.2528 0.2424 ... 0.2196 0.2261\n", "Attributes:\n", - " created_at: 2023-06-23T15:54:25.912633\n", - " arviz_version: 0.15.1\n", + " created_at: 2024-04-05T07:20:26.323594\n", + " arviz_version: 0.17.1\n", " inference_library: pymc\n", - " inference_library_version: 5.5.0\n", - " sampling_time: 10.613389253616333\n", - " tuning_steps: 1000
<xarray.Dataset>\n", + "<xarray.Dataset> Size: 496kB\n", "Dimensions: (chain: 4, draw: 1000)\n", "Coordinates:\n", - " * chain (chain) int64 0 1 2 3\n", - " * draw (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n", + " * chain (chain) int64 32B 0 1 2 3\n", + " * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999\n", "Data variables: (12/17)\n", - " reached_max_treedepth (chain, draw) bool False False False ... False False\n", - " index_in_trajectory (chain, draw) int64 -3 -3 8 2 -2 -2 ... -1 -4 3 3 11\n", - " energy_error (chain, draw) float64 -0.141 0.1389 ... -0.05982\n", - " perf_counter_diff (chain, draw) float64 0.001286 0.002343 ... 0.005221\n", - " acceptance_rate (chain, draw) float64 0.9927 0.9129 ... 0.965 0.9955\n", - " diverging (chain, draw) bool False False False ... False False\n", + " tree_depth (chain, draw) int64 32kB 4 2 4 4 2 4 ... 4 1 4 4 3 3\n", + " perf_counter_start (chain, draw) float64 32kB 1.8e+04 ... 1.801e+04\n", + " energy (chain, draw) float64 32kB 9.584e+03 ... 9.589e+03\n", + " index_in_trajectory (chain, draw) int64 32kB -2 -2 3 -3 -2 ... 0 6 -7 1 3\n", + " perf_counter_diff (chain, draw) float64 32kB 0.009088 ... 0.004634\n", + " step_size (chain, draw) float64 32kB 0.3874 0.3874 ... 0.4099\n", " ... ...\n", - " perf_counter_start (chain, draw) float64 1.824e+06 ... 1.824e+06\n", - " max_energy_error (chain, draw) float64 -0.141 0.1389 ... -0.1374\n", - " lp (chain, draw) float64 -9.582e+03 ... -9.585e+03\n", - " step_size_bar (chain, draw) float64 0.2761 0.2761 ... 0.3116 0.3116\n", - " largest_eigval (chain, draw) float64 nan nan nan nan ... nan nan nan\n", - " n_steps (chain, draw) float64 3.0 7.0 15.0 ... 15.0 7.0 15.0\n", + " smallest_eigval (chain, draw) float64 32kB nan nan nan ... nan nan\n", + " diverging (chain, draw) bool 4kB False False ... False False\n", + " max_energy_error (chain, draw) float64 32kB 0.8383 0.3148 ... -1.02\n", + " lp (chain, draw) float64 32kB -9.583e+03 ... -9.585e+03\n", + " reached_max_treedepth (chain, draw) bool 4kB False False ... False False\n", + " acceptance_rate (chain, draw) float64 32kB 0.7259 0.9069 ... 0.9577\n", "Attributes:\n", - " created_at: 2023-06-23T15:54:25.921146\n", - " arviz_version: 0.15.1\n", + " created_at: 2024-04-05T07:20:26.344549\n", + " arviz_version: 0.17.1\n", " inference_library: pymc\n", - " inference_library_version: 5.5.0\n", - " sampling_time: 10.613389253616333\n", - " tuning_steps: 1000
<xarray.Dataset>\n", + "<xarray.Dataset> Size: 94kB\n", "Dimensions: (index: 2357)\n", "Coordinates:\n", - " * index (index) int64 0 1 2 3 4 5 6 ... 2351 2352 2353 2354 2355 2356\n", + " * index (index) int64 19kB 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356\n", "Data variables:\n", - " customer_id (index) int64 0 1 2 3 4 5 6 ... 2351 2352 2353 2354 2355 2356\n", - " frequency (index) int64 2 1 0 0 0 7 1 0 2 0 5 0 ... 2 7 1 2 0 0 0 5 0 4 0\n", - " recency (index) float64 30.43 1.71 0.0 0.0 0.0 ... 24.29 0.0 26.57 0.0\n", - " T (index) float64 38.86 38.86 38.86 38.86 ... 27.0 27.0 27.0 27.0
<xarray.Dataset>\n", + "<xarray.Dataset> Size: 48B\n", "Dimensions: (chain: 1, draw: 1)\n", "Coordinates:\n", - " * chain (chain) int64 0\n", - " * draw (draw) int64 0\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 8B 0\n", "Data variables:\n", - " a (chain, draw) float64 0.793\n", - " b (chain, draw) float64 2.426\n", - " alpha (chain, draw) float64 4.414\n", - " r (chain, draw) float64 0.2426\n", + " a (chain, draw) float64 8B 0.793\n", + " b (chain, draw) float64 8B 2.426\n", + " alpha (chain, draw) float64 8B 4.414\n", + " r (chain, draw) float64 8B 0.2426\n", "Attributes:\n", - " created_at: 2023-06-23T15:54:27.552171\n", - " arviz_version: 0.15.1\n", + " created_at: 2024-04-05T07:20:31.004335\n", + " arviz_version: 0.17.1\n", " inference_library: pymc\n", - " inference_library_version: 5.5.0
<xarray.Dataset>\n", + "<xarray.Dataset> Size: 94kB\n", "Dimensions: (index: 2357)\n", "Coordinates:\n", - " * index (index) int64 0 1 2 3 4 5 6 ... 2351 2352 2353 2354 2355 2356\n", + " * index (index) int64 19kB 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356\n", "Data variables:\n", - " customer_id (index) int64 0 1 2 3 4 5 6 ... 2351 2352 2353 2354 2355 2356\n", - " frequency (index) int64 2 1 0 0 0 7 1 0 2 0 5 0 ... 2 7 1 2 0 0 0 5 0 4 0\n", - " recency (index) float64 30.43 1.71 0.0 0.0 0.0 ... 24.29 0.0 26.57 0.0\n", - " T (index) float64 38.86 38.86 38.86 38.86 ... 27.0 27.0 27.0 27.0
\n", + " | customer_id | \n", + "frequency | \n", + "recency | \n", + "T | \n", + "
---|---|---|---|---|
1 | \n", + "1 | \n", + "1 | \n", + "1.71 | \n", + "38.86 | \n", + "
6 | \n", + "6 | \n", + "1 | \n", + "5.00 | \n", + "38.86 | \n", + "
10 | \n", + "10 | \n", + "5 | \n", + "24.43 | \n", + "38.86 | \n", + "
18 | \n", + "18 | \n", + "3 | \n", + "28.29 | \n", + "38.71 | \n", + "
45 | \n", + "45 | \n", + "12 | \n", + "34.43 | \n", + "38.57 | \n", + "
1412 | \n", + "1412 | \n", + "14 | \n", + "30.29 | \n", + "31.57 | \n", + "