Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorder effect arguments; allow scalar treatments #49

Merged
3 commits merged on May 3, 2019
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import abc
import numpy as np
from .utilities import tensordot, ndim, reshape
from .utilities import tensordot, ndim, reshape, shape


class BaseCateEstimator:
Expand Down Expand Up @@ -40,7 +40,7 @@ def fit(self, Y, T, X=None, W=None, Z=None):
pass

@abc.abstractmethod
def effect(self, T0, T1, X=None):
def effect(self, X=None, T0=0, T1=1):
"""
Calculate the heterogeneous treatment effect τ(·,·,·).

Expand Down Expand Up @@ -117,7 +117,7 @@ def const_marginal_effect(self, X=None):
"""
pass

def effect(self, T0, T1, X=None):
def effect(self, X=None, T0=0, T1=1):
"""
Calculate the heterogeneous treatment effect τ(·,·,·).

Expand All @@ -144,8 +144,15 @@ def effect(self, T0, T1, X=None):
"""
# TODO: what if input is sparse? - there's no equivalent to einsum,
# but tensordot can't be applied to this problem because we don't sum over m
dT = T1 - T0
# TODO: if T0 or T1 are scalars, we'll promote them to vectors;
# should it be possible to promote them to 2D arrays if that's what we saw during training?
eff = self.const_marginal_effect(X)
m = shape(eff)[0]
if ndim(T0) == 0:
T0 = np.repeat(T0, m)
if ndim(T1) == 0:
T1 = np.repeat(T1, m)
dT = T1 - T0
einsum_str = 'myt,mt->my'
if ndim(dT) == 1:
einsum_str = einsum_str.replace('t', '')
Expand Down
8 changes: 7 additions & 1 deletion econml/deepiv.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def calc_grad(t, x):

self._marginal_effect_model = Model([t_in, x_in], L.Lambda(lambda tx: calc_grad(*tx))([t_in, x_in]))

def effect(self, T0, T1, X=None):
def effect(self, X=None, T0=0, T1=1):
"""
Calculate the heterogeneous treatment effect τ(·,·,·).

Expand All @@ -384,6 +384,12 @@ def effect(self, T0, T1, X=None):
Note that when Y is a vector rather than a 2-dimensional array, the corresponding
singleton dimension will be collapsed (so this method will return a vector)
"""
if ndim(T0) == 0:
T0 = np.repeat(T0, 1 if X is None else shape(X)[0])
if ndim(T1) == 0:
T1 = np.repeat(T1, 1 if X is None else shape(X)[0])
if X is None:
X = np.empty((shape(T0)[0], 0))
return self._effect_model.predict([T1, X]) - self._effect_model.predict([T0, X])

def marginal_effect(self, T, X=None):
Expand Down
8 changes: 6 additions & 2 deletions econml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,15 @@ def const_marginal_effect(self, X=None):
# need to override effect in case of discrete treatments
# TODO: should this logic be moved up to the LinearCateEstimator class and
# removed from here and from the OrthoForest implementation?
def effect(self, T0, T1, X):
def effect(self, X, T0=0, T1=1):
if ndim(T0) == 0:
T0 = np.repeat(T0, 1 if X is None else shape(X)[0])
if ndim(T1) == 0:
T1 = np.repeat(T1, 1 if X is None else shape(X)[0])
if self._discrete_treatment:
T0 = self._one_hot_encoder.transform(reshape(self._label_encoder.transform(T0), (-1, 1)))[:, 1:]
T1 = self._one_hot_encoder.transform(reshape(self._label_encoder.transform(T1), (-1, 1)))[:, 1:]
return super().effect(T0, T1, X)
return super().effect(X, T0, T1)

def score(self, Y, T, X=None, W=None):
if self._discrete_treatment:
Expand Down
8 changes: 6 additions & 2 deletions econml/ortho_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def fit(self, Y, T, X, W=None):
# Call `fit` from parent class
return super(DiscreteTreatmentOrthoForest, self).fit(Y, T, X, W)

def effect(self, T0, T1, X=None):
def effect(self, X=None, T0=0, T1=1):
"""Calculate the heterogeneous linear CATE θ(·) between two treatment points.

Parameters
Expand All @@ -701,11 +701,15 @@ def effect(self, T0, T1, X=None):
Theta : matrix , shape (n, d_y)
CATE on each outcome for each sample.
"""
if np.ndim(T0) == 0:
T0 = np.repeat(T0, 1 if X is None else np.shape(X)[0])
if np.ndim(T1) == 0:
T1 = np.repeat(T1, 1 if X is None else np.shape(X)[0])
T0 = self._check_treatment(T0)
T1 = self._check_treatment(T1)
T0_encoded = self._label_encoder.transform(T0)
T1_encoded = self._label_encoder.transform(T1)
return super(DiscreteTreatmentOrthoForest, self).effect(T0_encoded, T1_encoded, X)
return super(DiscreteTreatmentOrthoForest, self).effect(X, T0_encoded, T1_encoded)

def _pointwise_effect(self, X_single):
"""
Expand Down
2 changes: 1 addition & 1 deletion econml/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def test_integration():
y_test, T_test, X_test = econml.dgp.dgp_counterfactual_data_multiple_treatments(
n_samples, n_cov, beta, effect, 5. * np.ones(n_treatments))
dml_r2score.append(dml_reg.score(np.concatenate((X_test, T_test), axis=1), y_test))
dml_te.append(dml_reg.effect(np.zeros((1, 0)), np.zeros((1, 1)), np.ones((1, 1))))
dml_te.append(dml_reg.effect(np.zeros((1, 1)), np.ones((1, 1)), np.zeros((1, 0))))

# Estimation with other methods for comparison
direct_reg1.fit(np.concatenate((X, T), axis=1), y)
Expand Down
10 changes: 5 additions & 5 deletions econml/tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def test_with_econml(self):
self.assertEqual(np.shape(est.coef_), np.shape(bound))

# test that we can do the same thing with the results of a method, rather than an attribute
self.assertEqual(np.shape(est.effect(t, t2, x)), np.shape(bs.effect(t, t2, x)))
self.assertEqual(np.shape(est.effect(x, t, t2)), np.shape(bs.effect(x, t, t2)))

# test that we can get an interval for the same attribute for the bootstrap as the original,
# with the same shape for the lower and upper bounds
lower, upper = bs.effect_interval(t, t2, x)
lower, upper = bs.effect_interval(x, t, t2)
for bound in [lower, upper]:
self.assertEqual(np.shape(est.effect(t, t2, x)), np.shape(bound))
self.assertEqual(np.shape(est.effect(x, t, t2)), np.shape(bound))

# test that we can do the same thing once we provide percentile bounds
lower, upper = bs.effect_interval(t, t2, x, lower=10, upper=90)
lower, upper = bs.effect_interval(x, t, t2, lower=10, upper=90)
for bound in [lower, upper]:
self.assertEqual(np.shape(est.effect(t, t2, x)), np.shape(bound))
self.assertEqual(np.shape(est.effect(x, t, t2)), np.shape(bound))
13 changes: 8 additions & 5 deletions econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ def test_cate_api(self):
est.fit(Y, T, X, W)
# just make sure we can call the marginal_effect and effect methods
est.marginal_effect(None, X)
est.effect(0, T, X)
est.effect(X, np.zeros_like(T), T)
est.score(Y, T, X, W)
if d_t == -1:
# for vector-valued T, verify that default scalar T0 and T1 work
est.effect(X)

def test_can_use_vectors(self):
"""Test that we can pass vectors for T and Y (not only 2-dimensional arrays)."""
Expand All @@ -65,9 +68,9 @@ def test_discrete_treatments(self):
# and having the treatments in non-lexicographic order,
# should rule out some basic issues.
dml.fit(np.array([2, 3, 1, 3, 2, 1, 1, 1]), np.array([3, 2, 1, 2, 3, 1, 1, 1]), np.ones((8, 1)))
np.testing.assert_almost_equal(dml.effect(np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]),
np.array([1, 2, 3, 1, 2, 3, 1, 2, 3]),
np.ones((9, 1))),
np.testing.assert_almost_equal(dml.effect(np.ones((9, 1)),
np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]),
np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])),
[0, 2, 1, -2, 0, -1, -1, 1, 0])
dml.score(np.array([2, 3, 1, 3, 2, 1, 1, 1]), np.array([3, 2, 1, 2, 3, 1, 1, 1]), np.ones((8, 1)))

Expand Down Expand Up @@ -147,4 +150,4 @@ def _test_sparse(n_p, d_w, n_r):

np.testing.assert_allclose(a, dml.coef_.reshape(-1))
eff = reshape(t * np.choose(np.tile(p, 2), a), (-1,))
np.testing.assert_allclose(eff, dml.effect(0, t, x))
np.testing.assert_allclose(eff, dml.effect(x, 0, t))
2 changes: 1 addition & 1 deletion econml/tests/test_two_stage_least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_2sls(self):
np2sls = NonparametricTwoStageLeastSquares(HermiteFeatures(
dt), HermiteFeatures(dx), HermiteFeatures(dz), HermiteFeatures(dt, shift=1))
np2sls.fit(y, p, x, z)
effect = np2sls.effect(np.zeros(shape(p_fresh)), p_fresh, x_fresh)
effect = np2sls.effect(x_fresh, np.zeros(shape(p_fresh)), p_fresh)
losses.append(np.mean(np.square(p_fresh * x_fresh - effect)))
marg_effs.append(np2sls.marginal_effect(np.array([[0.3], [0.5], [0.7]]), np.array([[0.4], [0.6], [0.2]])))
print("losses: {}".format(losses))
Expand Down
6 changes: 5 additions & 1 deletion econml/two_stage_least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def fit(self, Y, T, X, Z):
self._model_Y.fit(_add_ones(cross_product(ft_T_hat, ft_X)), Y)
return self

def effect(self, T0, T1, X=None):
def effect(self, X=None, T0=0, T1=1):
"""
Calculate the heterogeneous treatment effect τ(·,·,·).

Expand All @@ -189,6 +189,10 @@ def effect(self, T0, T1, X=None):
singleton dimension will be collapsed (so this method will return a vector)

"""
if ndim(T0) == 0:
T0 = np.repeat(T0, 1 if X is None else shape(X)[0])
if ndim(T1) == 0:
T1 = np.repeat(T1, 1 if X is None else shape(X)[0])
if X is None:
X = np.empty((shape(T0)[0], 0))
assert shape(T0) == shape(T1)
Expand Down
10 changes: 5 additions & 5 deletions notebooks/Double Machine Learning Examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
"source": [
"est = DMLCateEstimator(model_y=RandomForestRegressor(),model_t=RandomForestRegressor(),random_state=123)\n",
"est.fit(Y, T, X, W)\n",
"te_pred = est.const_marginal_effect(X_test)"
"te_pred = est.effect(X_test)"
]
},
{
Expand All @@ -194,7 +194,7 @@
"source": [
"est1 = DMLCateEstimator(model_y=RandomForestRegressor(),model_t=RandomForestRegressor(),featurizer=PolynomialFeatures(degree=3),random_state=123)\n",
"est1.fit(Y, T, X, W)\n",
"te_pred1=est1.const_marginal_effect(X_test)"
"te_pred1=est1.effect(X_test)"
]
},
{
Expand All @@ -212,7 +212,7 @@
"source": [
"est2 = DMLCateEstimator(model_y=RandomForestRegressor(),model_t=RandomForestRegressor(),model_final=Lasso(alpha=0.1),featurizer=PolynomialFeatures(degree=10),random_state=123)\n",
"est2.fit(Y, T, X, W)\n",
"te_pred2=est2.const_marginal_effect(X_test)"
"te_pred2=est2.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -443,7 +443,7 @@
"metadata": {},
"outputs": [],
"source": [
"te_pred = est.const_marginal_effect(X_test)"
"te_pred = est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -761,7 +761,7 @@
"source": [
"est = DMLCateEstimator(model_y=RandomForestRegressor(),model_t=RandomForestRegressor())\n",
"est.fit(Y, T, X, W)\n",
"te_pred=est.const_marginal_effect(X_test)"
"te_pred=est.effect(X_test)"
]
},
{
Expand Down
8 changes: 4 additions & 4 deletions notebooks/Orthogonal Random Forest Examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@
}
],
"source": [
"treatment_effects = est.const_marginal_effect(X_test)"
"treatment_effects = est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -418,7 +418,7 @@
}
],
"source": [
"treatment_effects = est.const_marginal_effect(X_test)"
"treatment_effects = est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -603,7 +603,7 @@
}
],
"source": [
"treatment_effects = est.const_marginal_effect(X_test)"
"treatment_effects = est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -1004,7 +1004,7 @@
"source": [
"import time\n",
"t0 = time.time()\n",
"te_pred = est.const_marginal_effect(X_test)\n",
"te_pred = est.effect(X_test)\n",
"print(time.time() - t0)"
]
},
Expand Down