Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust ATE inference to get the exact stderr when the final stage is linear. #418

Merged
merged 13 commits into from
Mar 3, 2021
Merged
4 changes: 2 additions & 2 deletions econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def const_marginal_effect_inference(self, X):
pred = pred.reshape((-1,) + self._d_y + self._d_t)
pred_stderr = np.sqrt(np.diagonal(pred_var, axis1=2, axis2=3).reshape((-1,) + self._d_y + self._d_t))
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect')
pred_stderr=pred_stderr, mean_pred_stderr=None, inf_type='effect')

def effect_interval(self, X, *, T0, T1, alpha=0.1):
    """Return the (alpha/2, 1 - alpha/2) confidence interval for the effect of moving from T0 to T1 at X.

    Delegates to ``effect_inference`` and extracts the interval from the
    resulting inference object.
    """
    inference = self.effect_inference(X, T0=T0, T1=T1)
    return inference.conf_int(alpha=alpha)
Expand All @@ -97,7 +97,7 @@ def effect_inference(self, X, *, T0, T1):
pred = pred.reshape((-1,) + self._d_y)
pred_stderr = np.sqrt(pred_var.reshape((-1,) + self._d_y))
return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect')
pred_stderr=pred_stderr, mean_pred_stderr=None, inf_type='effect')


class CausalForestDML(_BaseDML):
Expand Down
2 changes: 1 addition & 1 deletion econml/inference/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def normal_inference(*args, **kwargs):
stderr = stderr(*args, **kwargs)
return NormalInferenceResults(
d_t=d_t, d_y=d_y, pred=pred,
pred_stderr=stderr, inf_type=inf_type,
pred_stderr=stderr, mean_pred_stderr=None, inf_type=inf_type,
fname_transformer=fname_transformer,
feature_names=self._wrapped.cate_feature_names(),
output_names=self._wrapped.cate_output_names(),
Expand Down
114 changes: 83 additions & 31 deletions econml/inference/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def const_marginal_effect_inference(self, X):
warn("Final model doesn't have a `prediction_stderr` method, "
"only point estimates will be returned.")
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect',
pred_stderr=pred_stderr, mean_pred_stderr=None, inf_type='effect',
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names(),
treatment_names=self._est.cate_treatment_names())
Expand Down Expand Up @@ -193,9 +193,10 @@ def effect_inference(self, X, *, T0, T1):
e_pred = np.einsum(einsum_str, cme_pred, dT)
e_stderr = np.einsum(einsum_str, cme_stderr, np.abs(dT)) if cme_stderr is not None else None
d_y = self._d_y[0] if self._d_y else 1

# d_t=None here since we measure the effect across all Ts
return NormalInferenceResults(d_t=None, d_y=d_y, pred=e_pred,
pred_stderr=e_stderr, inf_type='effect',
pred_stderr=e_stderr, mean_pred_stderr=None, inf_type='effect',
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names())

Expand Down Expand Up @@ -240,15 +241,38 @@ def effect_inference(self, X, *, T0, T1):
X = np.ones((T0.shape[0], 1))
elif self.featurizer is not None:
X = self.featurizer.transform(X)
e_pred = self._predict(cross_product(X, T1 - T0))
e_stderr = self._prediction_stderr(cross_product(X, T1 - T0))
XT = cross_product(X, T1 - T0)
e_pred = self._predict(XT)
e_stderr = self._prediction_stderr(XT)
d_y = self._d_y[0] if self._d_y else 1

mean_XT = XT.mean(axis=0, keepdims=True)
mean_pred_stderr = self._prediction_stderr(mean_XT) # shape[0] will always be 1 here
# squeeze the first axis
mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0) if mean_pred_stderr is not None else None
# d_t=None here since we measure the effect across all Ts
return NormalInferenceResults(d_t=None, d_y=d_y, pred=e_pred,
pred_stderr=e_stderr, inf_type='effect',
pred_stderr=e_stderr, mean_pred_stderr=mean_pred_stderr, inf_type='effect',
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names())

def const_marginal_effect_inference(self, X):
    # Augment the base inference results with the exact standard error of the
    # mean constant-marginal-effect estimate, evaluated at the mean of the
    # (featurized) X crossed with each unit treatment.
    inf_res = super().const_marginal_effect_inference(X)

    # set the mean_pred_stderr
    if X is None:
        # No heterogeneity features: use a single constant row in place of X.
        X = np.ones((1, 1))
    elif self.featurizer is not None:
        X = self.featurizer.transform(X)
    # Cross the mean featurized row with every unit treatment so the final
    # model's stderr is evaluated exactly at the population-mean input.
    X_mean, T_mean = broadcast_unit_treatments(X.mean(axis=0).reshape(1, -1), self.d_t)
    mean_XT = cross_product(X_mean, T_mean)
    # NOTE(review): presumably returns None when the final model cannot
    # produce prediction stderrs — the guard below suggests so; confirm.
    mean_pred_stderr = self._prediction_stderr(mean_XT)
    if mean_pred_stderr is not None:
        mean_pred_stderr = reshape_treatmentwise_effects(mean_pred_stderr,
                                                         self._d_t, self._d_y)  # shape[0] will always be 1 here
        # Drop the leading singleton row axis before storing on the result.
        inf_res.mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0)
    return inf_res

def coef__interval(self, *, alpha=0.1):
lo, hi = self.model_final.coef__interval(alpha)
lo_int, hi_int = self.model_final.intercept__interval(alpha)
Expand Down Expand Up @@ -285,6 +309,7 @@ def coef__inference(self):
fname_transformer = self._est.cate_feature_names

return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=coef, pred_stderr=coef_stderr,
mean_pred_stderr=None,
inf_type='coefficient', fname_transformer=fname_transformer,
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names(),
Expand Down Expand Up @@ -323,6 +348,7 @@ def intercept__inference(self):
intercept_stderr = None

return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=intercept, pred_stderr=intercept_stderr,
mean_pred_stderr=None,
inf_type='intercept',
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names(),
Expand Down Expand Up @@ -380,11 +406,7 @@ def fit(self, estimator, *args, **kwargs):
self.fit_cate_intercept = estimator.fit_cate_intercept

def const_marginal_effect_interval(self, X, *, alpha=0.1):
if (X is not None) and (self.featurizer is not None):
X = self.featurizer.transform(X)
preds = np.array([tuple(map(lambda x: x.reshape((-1,) + self._d_y), mdl.predict_interval(X, alpha=alpha)))
for mdl in self.fitted_models_final])
return tuple(np.moveaxis(preds, [0, 1], [-1, 0])) # send treatment to the end, pull bounds to the front
return self.const_marginal_effect_inference(X).conf_int(alpha=alpha)

def const_marginal_effect_inference(self, X):
if (X is not None) and (self.featurizer is not None):
Expand All @@ -401,22 +423,14 @@ def const_marginal_effect_inference(self, X):
"Only point estimates will be available.")
pred_stderr = None
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect',
pred_stderr=pred_stderr, mean_pred_stderr=None,
inf_type='effect',
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names(),
treatment_names=self._est.cate_treatment_names())

def effect_interval(self, X, *, T0, T1, alpha=0.1):
X, T0, T1 = self._est._expand_treatments(X, T0, T1)
if np.any(np.any(T0 > 0, axis=1)):
raise AttributeError("Can only calculate intervals of effects with respect to baseline treatment!")
ind = inverse_onehot(T1)
lower, upper = self.const_marginal_effect_interval(X, alpha=alpha)
lower = np.concatenate([np.zeros(lower.shape[0:-1] + (1,)), lower], -1)
upper = np.concatenate([np.zeros(upper.shape[0:-1] + (1,)), upper], -1)
if X is None: # Then const_marginal_effect_interval will return a single row
lower, upper = np.repeat(lower, T0.shape[0], axis=0), np.repeat(upper, T0.shape[0], axis=0)
return lower[np.arange(T0.shape[0]), ..., ind], upper[np.arange(T0.shape[0]), ..., ind]
return self.effect_inference(X, T0=T0, T1=T1).conf_int(alpha=alpha)

def effect_inference(self, X, *, T0, T1):
X, T0, T1 = self._est._expand_treatments(X, T0, T1)
Expand All @@ -434,9 +448,10 @@ def effect_inference(self, X, *, T0, T1):
pred_stderr = np.repeat(pred_stderr, T0.shape[0], axis=0) if pred_stderr is not None else None
pred = pred[np.arange(T0.shape[0]), ..., ind]
pred_stderr = pred_stderr[np.arange(T0.shape[0]), ..., ind] if pred_stderr is not None else None

# d_t=None here since we measure the effect across all Ts
return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr,
pred_stderr=pred_stderr, mean_pred_stderr=None,
inf_type='effect',
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names())
Expand All @@ -449,6 +464,21 @@ class LinearModelFinalInferenceDiscrete(GenericModelFinalInferenceDiscrete):
based on the corresponding methods of the underlying model_final estimator.
"""

def const_marginal_effect_inference(self, X):
    # Augment the base inference results with the exact standard error of the
    # mean point estimate, computed from each per-treatment fitted final model
    # evaluated at the mean of the (featurized) X.
    res_inf = super().const_marginal_effect_inference(X)

    # set the mean_pred_stderr
    if (X is not None) and (self.featurizer is not None):
        X = self.featurizer.transform(X)

    # Only possible when the final models expose prediction stderrs;
    # otherwise leave the default mean_pred_stderr untouched.
    if hasattr(self.fitted_models_final[0], 'prediction_stderr'):
        # mean_X is a single row (the column means); None is passed through
        # when there are no heterogeneity features.
        mean_X = X.mean(axis=0).reshape(1, -1) if X is not None else None
        # Stack one stderr array per treatment model, then move the treatment
        # axis to the end to match the (rows, d_y, d_t) result layout.
        mean_pred_stderr = np.moveaxis(np.array([mdl.prediction_stderr(mean_X).reshape((-1,) + self._d_y)
                                                 for mdl in self.fitted_models_final]),
                                       0, -1)  # shape[0] will always be 1 here
        # Drop the leading singleton row axis before storing on the result.
        res_inf.mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0)
    return res_inf

def coef__interval(self, T, *, alpha=0.1):
_, T = self._est._expand_treatments(None, T)
ind = inverse_onehot(T).item() - 1
Expand All @@ -472,8 +502,10 @@ def coef__inference(self, T):
fname_transformer = None
if hasattr(self._est, 'cate_feature_names') and callable(self._est.cate_feature_names):
fname_transformer = self._est.cate_feature_names

# d_t=None here since we measure the effect across all Ts
return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=coef, pred_stderr=coef_stderr,
mean_pred_stderr=None,
inf_type='coefficient', fname_transformer=fname_transformer,
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names())
Expand All @@ -500,7 +532,7 @@ def intercept__inference(self, T):
intercept_stderr = None
# d_t=None here since we measure the effect across all Ts
return NormalInferenceResults(d_t=None, d_y=self.d_y, pred=self.fitted_models_final[ind].intercept_,
pred_stderr=intercept_stderr,
pred_stderr=intercept_stderr, mean_pred_stderr=None,
inf_type='intercept',
feature_names=self._est.cate_feature_names(),
output_names=self._est.cate_output_names())
Expand Down Expand Up @@ -748,7 +780,6 @@ def summary_frame(self, alpha=0.1, value=0, decimals=3,

elif self.inf_type == 'intercept':
res.index = res.index.set_levels(['cate_intercept'], level="X")

if self._d_t == 1:
res.index = res.index.droplevel("T")
if self.d_y == 1:
Expand Down Expand Up @@ -786,6 +817,7 @@ def population_summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_n
output_names = self.output_names if output_names is None else output_names
if self.inf_type == 'effect':
return PopulationSummaryResults(pred=self.point_estimate, pred_stderr=self.stderr,
mean_pred_stderr=None,
d_t=self.d_t, d_y=self.d_y,
alpha=alpha, value=value, decimals=decimals, tol=tol,
output_names=output_names, treatment_names=treatment_names)
Expand Down Expand Up @@ -839,16 +871,21 @@ class NormalInferenceResults(InferenceResults):
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions should be collapsed
(e.g. if both are vectors, then the input of this argument will also be a vector)
mean_pred_stderr: None or array-like or scalar, shape (d_y, d_t) or (d_y,)
The standard error of the mean point estimate; this is derived from the coefficient stderr when the final
stage is a linear model, otherwise it is None.
This is the exact standard error of the mean, which is not conservative.
inf_type: string
The type of inference result.
It could be either 'effect', 'coefficient' or 'intercept'.
fname_transformer: None or predefined function
The transform function to get the corresponding feature names from featurizer
"""

def __init__(self, d_t, d_y, pred, pred_stderr, inf_type, fname_transformer=None,
def __init__(self, d_t, d_y, pred, pred_stderr, mean_pred_stderr, inf_type, fname_transformer=None,
feature_names=None, output_names=None, treatment_names=None):
self.pred_stderr = pred_stderr
self.mean_pred_stderr = mean_pred_stderr
super().__init__(d_t, d_y, pred, inf_type, fname_transformer, feature_names, output_names, treatment_names)

@property
Expand Down Expand Up @@ -914,11 +951,20 @@ def pvalue(self, value=0):
"""
return norm.sf(np.abs(self.zstat(value)), loc=0, scale=1) * 2

def population_summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None):
    # Build the base population summary, then attach this result's exact
    # stderr of the mean (when available) so the summary can report it
    # instead of a conservative bound derived from per-row stderrs.
    pop_summ = super().population_summary(alpha=alpha, value=value, decimals=decimals,
                                          tol=tol, output_names=output_names, treatment_names=treatment_names)
    pop_summ.mean_pred_stderr = self.mean_pred_stderr
    return pop_summ
# Reuse the base-class docstring verbatim for the override.
population_summary.__doc__ = InferenceResults.population_summary.__doc__

def _expand_outputs(self, n_rows):
assert shape(self.pred)[0] == shape(self.pred_stderr)[0] == 1
pred = np.repeat(self.pred, n_rows, axis=0)
pred_stderr = np.repeat(self.pred_stderr, n_rows, axis=0) if self.pred_stderr is not None else None
return NormalInferenceResults(self.d_t, self.d_y, pred, pred_stderr, self.inf_type,
return NormalInferenceResults(self.d_t, self.d_y, pred, pred_stderr,
self.mean_pred_stderr,
self.inf_type,
self.fname_transformer, self.feature_names,
self.output_names, self.treatment_names)

Expand Down Expand Up @@ -1038,6 +1084,10 @@ class PopulationSummaryResults:
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions should be collapsed
(e.g. if both are vectors, then the input of this argument will also be a vector)
mean_pred_stderr: None or array-like or scalar, shape (d_y, d_t) or (d_y,)
The standard error of the mean point estimate; this is derived from the coefficient stderr when the final
stage is a linear model, otherwise it is None.
This is the exact standard error of the mean, which is not conservative.
alpha: optional float in [0, 1] (default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.
Expand All @@ -1054,10 +1104,11 @@ class PopulationSummaryResults:

"""

def __init__(self, pred, pred_stderr, d_t, d_y, alpha, value, decimals, tol,
def __init__(self, pred, pred_stderr, mean_pred_stderr, d_t, d_y, alpha, value, decimals, tol,
output_names=None, treatment_names=None):
self.pred = pred
self.pred_stderr = pred_stderr
self.mean_pred_stderr = mean_pred_stderr
self.d_t = d_t
# For effect summaries, d_t is None, but the result arrays behave as if d_t=1
self._d_t = d_t or 1
Expand Down Expand Up @@ -1105,7 +1156,9 @@ def stderr_mean(self):
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will be a scalar)
"""
if self.pred_stderr is None:
if self.mean_pred_stderr is not None:
return self.mean_pred_stderr
elif self.pred_stderr is None:
raise AttributeError("Only point estimates are available!")
return np.sqrt(np.mean(self.pred_stderr**2, axis=0))

Expand Down Expand Up @@ -1311,13 +1364,13 @@ def _print(self, *, alpha=None, value=None, decimals=None, tol=None, output_name
self._format_res(self.pvalue(value=value), decimals),
self._format_res(self.conf_int_mean(alpha=alpha)[0], decimals),
self._format_res(self.conf_int_mean(alpha=alpha)[1], decimals)))

if treatment_names is None:
treatment_names = ['T' + str(i) for i in range(self._d_t)]
if output_names is None:
output_names = ['Y' + str(i) for i in range(self.d_y)]

myheaders1 = ['mean_point', 'stderr_mean', 'zstat', 'pvalue', 'ci_mean_lower', 'ci_mean_upper']

mystubs = self._get_stub_names(self.d_y, self._d_t, treatment_names, output_names)
title1 = "Uncertainty of Mean Point Estimate"

Expand All @@ -1330,13 +1383,12 @@ def _print(self, *, alpha=None, value=None, decimals=None, tol=None, output_name

smry = Summary()
smry.add_table(res1, myheaders1, mystubs, title1)
if self.pred_stderr is not None:
if self.pred_stderr is not None and self.mean_pred_stderr is None:
text1 = "Note: The stderr_mean is a conservative upper bound."
smry.add_extra_txt([text1])
smry.add_table(res2, myheaders2, mystubs, title2)

if self.pred_stderr is not None:

# 3. Total Variance of Point Estimate
res3 = np.hstack((self._format_res(self.stderr_point, self.decimals),
self._format_res(self.conf_int_point(alpha=alpha, tol=tol)[0],
Expand Down
Loading