Skip to content

Commit

Permalink
adjust effect inference for discrete when T1 is a constant or a list …
Browse files Browse the repository at this point in the history
…of constant
  • Loading branch information
heimengqi committed Feb 27, 2021
1 parent 5039d4d commit 4c28958
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
12 changes: 12 additions & 0 deletions econml/inference/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,18 @@ def const_marginal_effect_inference(self, X):
res_inf.mean_pred_stderr = np.squeeze(mean_pred_stderr, axis=0)
return res_inf

def effect_inference(self, X, *, T0, T1):
res_inf = super().effect_inference(X, T0=T0, T1=T1)

# replace the mean_pred_stderr if T1 and T0 is a constant or a constant of vector
_, _, T1 = self._est._expand_treatments(X, T0, T1)
ind = inverse_onehot(T1)
if len(set(ind)) == 1:
unique_ind = ind[0] - 1
mean_pred_stderr = self.const_marginal_effect_inference(X).mean_pred_stderr[..., unique_ind]
res_inf.mean_pred_stderr = mean_pred_stderr
return res_inf

def coef__interval(self, T, *, alpha=0.1):
_, T = self._est._expand_treatments(None, T)
ind = inverse_onehot(T).item() - 1
Expand Down
5 changes: 2 additions & 3 deletions econml/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,12 @@ def test_mean_pred_stderr(self):
for est in ests:
est.fit(Y, T, X=X, W=W)
assert est.const_marginal_effect_inference(X).population_summary().mean_pred_stderr is not None
# only is not None when T1 is a constant or a list of constant
assert est.effect_inference(X).population_summary().mean_pred_stderr is not None
if est.__class__.__name__ == "LinearDRLearner":
assert est.coef__inference(T=1).mean_pred_stderr is None
# can't get the exact stderr of the mean effect for discrete treatment
assert est.effect_inference(X).population_summary().mean_pred_stderr is None
else:
assert est.coef__inference().mean_pred_stderr is None
assert est.effect_inference(X).population_summary().mean_pred_stderr is not None

class _NoFeatNamesEst:
def __init__(self, cate_est):
Expand Down

0 comments on commit 4c28958

Please sign in to comment.