-
Notifications
You must be signed in to change notification settings - Fork 713
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adjust ate inference to get the exact stderr when final stage is linear. #418
Conversation
1d776c9
to
6580770
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments, mostly minor. But shouldn't there also be something that tests that we get back a non-None mean_stderr for the linear case?
Sure. I will add a test around that. I also added some context about this PR to explain why I do it in this way. Please also help confirm the whether logic seems right to you. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
d95f84a
to
4c28958
Compare
Add some context for this implementation:
In order to get the analytical inference of E(theta(x)), we need to know stderr of E(theta(x)), however previously we could only get the conservative upper bound of it. Given the fact that only in linear cate scenario, E(theta(x))=theta(E(x)), and we could get STD(theta(E(x))) by calling
prediction_stderr(X.mean(axis=0))
, which will return the exact standard error of the mean.For
const_marginal_effect
we learn theta(x) alone, then the stderr could be adjusted when final stage is linear in both continues and categorical scenario. Foreffect
, in continues scenario we learn cross product of X and T (as XT), we get the stderr by callingprediction_stderr(XT.mean(axis=0))
, in categorical scenario we learn the effect of a specific treatment comparing to baseline, for each sample, the effect could be different given what the treatment is. We couldn't learn theta(XT.mean(axis=0)) unless T1 is a constant or a list of constant. So I only adjusteffect_inference
only when T1 is a constant or list of constant.