Make optimization based on the entire posterior and not on the marginal mean parameters. #1151
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1151      +/-   ##
==========================================
- Coverage   95.65%   95.27%   -0.38%
==========================================
  Files          39       40       +1
  Lines        4096     4234     +138
==========================================
+ Hits         3918     4034     +116
- Misses        178      200      +22
juanitorduz commented on 2024-11-12T20:06:43Z (ReviewNB): Line #3. plt.title("Response Distribution at 95% Confidence Level"); We should use the term HDI (highest density interval).
juanitorduz commented on 2024-11-12T20:06:44Z (ReviewNB): Same comment about the HDI in the title.
It would also be nice to have them both plotted in the same figure to compare them :)
juanitorduz commented on 2024-11-12T20:06:45Z (ReviewNB): Same comment about the title.
juanitorduz commented on 2024-11-12T20:06:46Z (ReviewNB): Did anything actually change in the example notebook, or can we simply remove the changes?
f18e215 to e8e9588
All notebooks updated to use the new optimization method.
Ready for final review, everything should be done.
Thanks @cetagostini ! I will review this one soon (I promise). One thing you could improve while we do it is making sure we test all functions. For example, I see from the coverage report we do not have a test for the function (
juanitorduz commented on 2024-11-29T18:59:16Z (ReviewNB)
juanitorduz commented on 2024-11-29T18:59:17Z (ReviewNB): We should revert this as we want to use the Prior class.
juanitorduz commented on 2024-11-29T18:59:18Z (ReviewNB): wrong correction
@cetagostini can we revert the examples in the example notebook? This seems like an old version and has many errors and typos 🙏 (can you please check the other notebooks in detail as well to see if we have the same issue?)
I tried to give a detailed look and left some comments 🙏
        "method": "SLSQP",
        "options": {"ftol": 1e-9, "maxiter": 1_000},
    }
    minimize_kwargs = self.DEFAULT_MINIMIZE_KWARGS
Can we consider #1193 here?
Like it, will apply it.
Mypy doesn't like:

    if minimize_kwargs is None:
        minimize_kwargs = {
            **self.DEFAULT_MINIMIZE_KWARGS,
            **(minimize_kwargs or {}),
        }

Change to:

    if minimize_kwargs is None:
        minimize_kwargs = self.DEFAULT_MINIMIZE_KWARGS.copy()
    else:
        minimize_kwargs = {**self.DEFAULT_MINIMIZE_KWARGS, **minimize_kwargs}
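The suggested pattern can be tried in isolation. The following is a minimal sketch, assuming a module-level `DEFAULT_MINIMIZE_KWARGS` with the values quoted in the diff and a hypothetical helper named `resolve_minimize_kwargs` (in the PR the logic lives inside a method and reads `self.DEFAULT_MINIMIZE_KWARGS`).

```python
from typing import Any, Dict, Optional

# Values copied from the diff above; the module-level placement is an assumption.
DEFAULT_MINIMIZE_KWARGS: Dict[str, Any] = {
    "method": "SLSQP",
    "options": {"ftol": 1e-9, "maxiter": 1_000},
}


def resolve_minimize_kwargs(
    minimize_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: defaults overridden by user-supplied top-level keys."""
    if minimize_kwargs is None:
        # Shallow copy so callers do not rebind keys on the shared defaults.
        return DEFAULT_MINIMIZE_KWARGS.copy()
    # Later keys win. This is a shallow merge: a user-supplied "options"
    # dict replaces the default "options" dict as a whole.
    return {**DEFAULT_MINIMIZE_KWARGS, **minimize_kwargs}
```

Note that the merge is shallow, so passing only `{"options": {"maxiter": 5}}` would drop the default `ftol`. Whether that is the desired behavior is a design choice the thread does not settle.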
pymc_marketing/mmm/utility.py (Outdated)

    import pytensor.tensor as pt

    UtilityFunction = Callable[[pt.TensorVariable, pt.TensorVariable], float]
Shall we call this UtilityFunctionType? Otherwise it could be interpreted as a class.
Yes, we can do that!
    def average_response(
        samples: pt.TensorVariable, budgets: pt.TensorVariable
    ) -> pt.TensorVariable:
        """Compute the average response of the posterior predictive distribution."""
        return pt.mean(samples)
Do we really need this function? Note we are not using the budget variable at all, right?
We need the same signature across all functions to keep them compatible and scalable. I had an implementation that was more tailored to each function, but it took some flexibility away from users. With @wd60622 we arrived at this design, which makes everything more flexible; the only inconvenience is that every function needs samples and budgets, even when it doesn't use them.
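The trade-off described here can be sketched with plain Python sequences standing in for the PR's `pt.TensorVariable` arguments. The `response_per_unit_budget` and `evaluate` names are hypothetical and only illustrate why a shared signature lets calling code stay generic.

```python
from statistics import fmean
from typing import Callable, Sequence

# Stand-in for the PR's alias; plain sequences replace pytensor tensors here.
UtilityFunctionType = Callable[[Sequence[float], Sequence[float]], float]


def average_response(samples: Sequence[float], budgets: Sequence[float]) -> float:
    """Mean of the response samples; `budgets` is accepted but unused."""
    return fmean(samples)


def response_per_unit_budget(
    samples: Sequence[float], budgets: Sequence[float]
) -> float:
    """Hypothetical utility that does use `budgets`."""
    return fmean(samples) / sum(budgets)


def evaluate(
    utility: UtilityFunctionType,
    samples: Sequence[float],
    budgets: Sequence[float],
) -> float:
    """Caller that never needs to know which arguments a utility uses."""
    return utility(samples, budgets)
```

Because both utilities accept `(samples, budgets)`, `evaluate` can swap one for the other without any branching, at the cost of an unused parameter in `average_response`.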
            raise ValueError("Confidence level must be between 0 and 1.")

        def _tail_distance(
            samples: pt.TensorVariable, budgets: pt.TensorVariable
We are not using the budgets variable, so we should remove it, right?
Replied before!
        ----------
        .. [1] Rockafellar, R.T., & Uryasev, S. (2000). Optimization of Conditional Value-at-Risk.
        """
        if not 0 < confidence_level < 1:
Any problems if it is exactly zero or one?
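One way to reason about the boundary values is with a simple empirical lower-tail average. This is a sketch of a common estimator, not the PR's pytensor implementation; the `lower_tail_mean` name and the rounding rule for the tail size are assumptions.

```python
from statistics import fmean
from typing import Sequence


def lower_tail_mean(samples: Sequence[float], confidence_level: float) -> float:
    """Average of the worst (1 - confidence_level) share of the samples."""
    if not 0 < confidence_level < 1:
        raise ValueError("Confidence level must be between 0 and 1.")
    ordered = sorted(samples)
    # Tail size rounded to the nearest integer, keeping at least one sample.
    tail_size = max(1, round((1 - confidence_level) * len(ordered)))
    return fmean(ordered[:tail_size])
```

In this sketch a confidence level of exactly 1 would leave a tail share of zero, so there would be nothing meaningful to average, and a level of exactly 0 would make the tail the whole sample, which reduces the metric to the plain mean. Both are degenerate for a risk measure, which is one reason to keep the interval open.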

        Parameters
        ----------
        confidence_level : float, optional
We should add "confidence level must be between 0 and 1."
    "channel_1": {
        "adstock_params": {"alpha": 0.5},
        "saturation_params": {"lam": 10, "beta": 0.5},
        "saturation_params": {
should we still test the scalar case?
What do you mean here? What scalar case?

    rng: np.random.Generator = np.random.default_rng(seed=42)

    EXPECTED_RESULTS = {
where are these values coming from?
I ran a notebook to check what the results would be for each function given the parameters. This test validates that they return those values every time, meaning the function behavior is not changing and the results stay consistent.
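The pinned-value approach described in this reply can be illustrated as follows. This sketch swaps the seeded generator from the PR's test for a small fixed sample so that the expected numbers can be verified by hand; `response_spread`, `UTILITIES`, and `check_pinned_results` are hypothetical names.

```python
import math
from statistics import fmean, pstdev
from typing import Callable, Dict, Sequence

# Fixed inputs (not the PR's data) so the pinned values are easy to audit.
SAMPLES = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
BUDGETS = [1.0, 1.0]


def average_response(samples: Sequence[float], budgets: Sequence[float]) -> float:
    """Mean of the samples."""
    return fmean(samples)


def response_spread(samples: Sequence[float], budgets: Sequence[float]) -> float:
    """Hypothetical utility: population standard deviation of the samples."""
    return pstdev(samples)


UTILITIES: Dict[str, Callable[[Sequence[float], Sequence[float]], float]] = {
    "average_response": average_response,
    "response_spread": response_spread,
}

# Mean is 40 / 8 = 5; squared deviations sum to 32, so the variance is 4.
EXPECTED_RESULTS = {"average_response": 5.0, "response_spread": 2.0}


def check_pinned_results() -> None:
    """Fail if any utility drifts away from its pinned value."""
    for name, utility in UTILITIES.items():
        result = utility(SAMPLES, BUDGETS)
        assert math.isclose(result, EXPECTED_RESULTS[name], rel_tol=1e-9), name
```

Recording how each expected value was derived, as the comment above `EXPECTED_RESULTS` does, answers the reviewer's question directly in the test file.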
        UtilityFunction
            A function that calculates the tail distance metric given samples and budgets.
        """
        if not 0 < confidence_level < 1:
We should add a test for these ValueErrors.
Alternatively, we could use @validate_call from pydantic as in https://github.com/pymc-labs/pymc-marketing/blob/main/pymc_marketing/mmm/mmm.py#L74
Agree!
Test in place!
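A test of this kind might look like the sketch below. It assumes a hypothetical factory, `make_tail_utility`, that carries the same range check as the quoted diff; the default of 0.75 and the placeholder inner function are illustrative only. The check is written with plain `try`/`except` so it runs without a test framework; with pytest it would typically use `pytest.raises`.

```python
from typing import Callable, Sequence


def make_tail_utility(
    confidence_level: float = 0.75,
) -> Callable[[Sequence[float], Sequence[float]], float]:
    """Hypothetical factory carrying the range check quoted in the diff."""
    if not 0 < confidence_level < 1:
        raise ValueError("Confidence level must be between 0 and 1.")

    def utility(samples: Sequence[float], budgets: Sequence[float]) -> float:
        # Placeholder body: the worst sample stands in for a real tail metric.
        return min(samples)

    return utility


def raises_value_error(confidence_level: float) -> bool:
    """Return True when building the utility is rejected with ValueError."""
    try:
        make_tail_utility(confidence_level)
    except ValueError:
        return True
    return False
```

Covering both boundaries (0 and 1) as well as values outside the interval keeps the open-interval contract explicit.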
Important Comments:
Thank you for this great and ambitious PR, @cetagostini ! These suggestions are to make sure we deliver a great solution 💪 🙇
Thank you for the detailed review. I need some feedback on a few comments; I'll proceed with the rest 🙌🏻 PS: Sorry for the MMM example, I had issues with the notebook and maybe the rebase reverted those. Let me bring back the one from main @juanitorduz
Thanks @cetagostini ! Git and notebooks are a pain!👌🙇 (btw cool pytensor stuff... I am learning quite a lot from these risk metrics!)
@juanitorduz @wd60622 changes applied.
Thanks @carlosagostini ! This is a great addition! I think we should merge this one and iterate. 🚀
Description
Structural change in how responses are computed in the optimizer: we now use the entire posterior of the model. Additionally, we are creating a new notebook where we can encode certain risk levels.
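Why the two approaches differ can be seen with a toy example. The sketch below assumes a hypothetical saturating response curve and two made-up posterior draws of its parameter; neither is taken from the PR. For a nonlinear curve, evaluating the response at the mean parameter is not the same as averaging the response over the draws.

```python
from statistics import fmean

# Made-up posterior draws of a curve parameter, for illustration only.
LAM_DRAWS = [1.0, 3.0]


def saturation(spend: float, lam: float) -> float:
    """Hypothetical saturating response: spend / (spend + lam)."""
    return spend / (spend + lam)


def response_at_mean_parameter(spend: float) -> float:
    """Plug the marginal mean of the parameter into the curve."""
    return saturation(spend, fmean(LAM_DRAWS))


def mean_response_over_posterior(spend: float) -> float:
    """Evaluate the curve per draw, then average the responses."""
    return fmean(saturation(spend, lam) for lam in LAM_DRAWS)
```

With a spend of 1, the first function gives 1/3 while the second gives 0.375. Keeping the per-draw responses also makes the full response distribution available, which is what risk-aware utilities need.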
📚 Documentation preview 📚: https://pymc-marketing--1151.org.readthedocs.build/en/1151/