Ensure groups work with DRIV
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
kbattocchi committed Aug 14, 2023
1 parent a081a3e commit 063fc6a
Showing 2 changed files with 111 additions and 0 deletions.
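
The change adds group-aware cross-fitting tests to the IV and DRIV estimators: the `groups` argument passed to `fit` marks observations that must stay together in a single cross-fitting fold (here, consecutive triplets). A minimal sketch of the splitting property the tests rely on, using scikit-learn's GroupKFold; whether EconML uses this exact splitter internally is an assumption:

import numpy as np
from sklearn.model_selection import GroupKFold

n = 12
X = np.random.normal(size=(n, 3))
groups = [i // 3 for i in range(n)]  # consecutive triplets share a group id

# A group-aware splitter never puts members of the same group on both sides
# of a train/test split, which is the property the tests below depend on.
for train_idx, test_idx in GroupKFold(n_splits=2).split(X, groups=groups):
    assert {groups[i] for i in train_idx}.isdisjoint({groups[i] for i in test_idx})
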
51 changes: 51 additions & 0 deletions econml/tests/test_dmliv.py
@@ -187,3 +187,54 @@ def true_fn(X):
        ate_lb, ate_ub = est.effect_interval()
        np.testing.assert_array_less(ate_lb, true_ate)
        np.testing.assert_array_less(true_ate, ate_ub)

    def test_groups(self):
        n = 500
        d_w = 10
        d_x = 3
        y = np.random.normal(size=(n,))
        W = np.random.normal(size=(n, d_w))
        X = np.random.normal(size=(n, d_x))
        T = np.random.choice(["a", "b"], size=(n,))
        Z = np.random.choice(["c", "d"], size=(n,))
        groups = [i // 3 for i in range(n)]

        est_list = [
            OrthoIV(
                projection=False,
                discrete_treatment=True,
                discrete_instrument=True,
                model_y_xw=LinearRegression(),
                model_t_xw=LogisticRegression(),
                model_z_xw=LogisticRegression(),
            ),
            OrthoIV(
                projection=True,
                discrete_treatment=True,
                discrete_instrument=True,
                model_y_xw=LinearRegression(),
                model_t_xw=LogisticRegression(),
                model_t_xwz=LogisticRegression(),
            ),
            DMLIV(
                model_final=LinearRegression(fit_intercept=False),
                discrete_treatment=True,
                discrete_instrument=True,
                model_y_xw=LinearRegression(),
                model_t_xw=LogisticRegression(),
                model_t_xwz=LogisticRegression(),
            ),
            NonParamDMLIV(
                model_final=RandomForestRegressor(),
                discrete_treatment=True,
                discrete_instrument=True,
                model_y_xw=LinearRegression(),
                model_t_xw=LogisticRegression(),
                model_t_xwz=LogisticRegression(),
            ),
        ]

        for est in est_list:
            est.fit(y, T, Z=Z, X=X, W=W, groups=groups)
            score = est.score(y, T, Z=Z, X=X, W=W, groups=groups)
            eff = est.const_marginal_effect(X)
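
The test only checks that fit, score, and const_marginal_effect run when groups are supplied. In user code the same pattern extends to inference; a hedged sketch using the first estimator from the list above, assuming the econml.iv.dml import path and that the default inference supports intervals:

import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression
from econml.iv.dml import OrthoIV  # import path assumed

n = 500
y = np.random.normal(size=(n,))
W = np.random.normal(size=(n, 10))
X = np.random.normal(size=(n, 3))
T = np.random.choice(["a", "b"], size=(n,))
Z = np.random.choice(["c", "d"], size=(n,))
groups = [i // 3 for i in range(n)]  # e.g. repeated observations per unit

est = OrthoIV(
    projection=False,
    discrete_treatment=True,
    discrete_instrument=True,
    model_y_xw=LinearRegression(),
    model_t_xw=LogisticRegression(),
    model_z_xw=LogisticRegression(),
)
# groups keeps each unit's observations in the same cross-fitting fold
est.fit(y, T, Z=Z, X=X, W=W, groups=groups)

theta = est.const_marginal_effect(X)             # theta(X), as in the test
lb, ub = est.effect_interval(X, T0="a", T1="b")  # CIs, assuming default inference provides them
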
60 changes: 60 additions & 0 deletions econml/tests/test_driv.py
@@ -12,6 +12,7 @@
import pickle
from scipy import special
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression, LogisticRegression
import unittest


@@ -238,3 +239,62 @@ def true_fn(X):
        np.testing.assert_array_less(true_coef, coef_ub)
        np.testing.assert_array_less(intercept_lb, 0)
        np.testing.assert_array_less(0, intercept_ub)

    def test_groups(self):
        n = 500
        d_w = 10
        d_x = 3
        y = np.random.normal(size=(n,))
        W = np.random.normal(size=(n, d_w))
        X = np.random.normal(size=(n, d_x))
        T = np.random.choice(["a", "b"], size=(n,))
        Z = np.random.choice(["c", "d"], size=(n,))
        groups = [i // 3 for i in range(n)]

        est_list = [
            DRIV(
                discrete_instrument=True,
                discrete_treatment=True,
                model_y_xw=LinearRegression(),
                model_t_xwz=LogisticRegression(),
                model_t_xw=LogisticRegression(),
                model_tz_xw=LogisticRegression(),
            ),
            LinearDRIV(
                discrete_instrument=True,
                discrete_treatment=True,
                model_y_xw=LinearRegression(),
                model_t_xwz=LogisticRegression(),
                model_t_xw=LogisticRegression(),
                model_tz_xw=LogisticRegression(),
            ),
            SparseLinearDRIV(
                discrete_instrument=True,
                discrete_treatment=True,
                model_y_xw=LinearRegression(),
                model_t_xwz=LogisticRegression(),
                model_t_xw=LogisticRegression(),
                model_tz_xw=LogisticRegression(),
            ),
            ForestDRIV(
                discrete_instrument=True,
                discrete_treatment=True,
                model_y_xw=LinearRegression(),
                model_t_xwz=LogisticRegression(),
                model_t_xw=LogisticRegression(),
                model_tz_xw=LogisticRegression(),
            ),
            IntentToTreatDRIV(
                model_y_xw=LinearRegression(),
                model_t_xwz=LogisticRegression()
            ),
            LinearIntentToTreatDRIV(
                model_y_xw=LinearRegression(),
                model_t_xwz=LogisticRegression()
            )
        ]

        for est in est_list:
            est.fit(y, T, Z=Z, X=X, W=W, groups=groups)
            score = est.score(y, T, Z=Z, X=X, W=W, groups=groups)
            eff = est.const_marginal_effect(X)
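
As with the DMLIV tests, these only exercise fit, score, and const_marginal_effect. A hedged usage sketch for the intent-to-treat variant, assuming the econml.iv.dr import path; the coef__interval accessor is the standard linear-final-model API and is not part of this commit:

import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression
from econml.iv.dr import LinearIntentToTreatDRIV  # import path assumed

n = 500
y = np.random.normal(size=(n,))
W = np.random.normal(size=(n, 10))
X = np.random.normal(size=(n, 3))
T = np.random.choice(["a", "b"], size=(n,))
Z = np.random.choice(["c", "d"], size=(n,))
groups = [i // 3 for i in range(n)]  # cluster id shared by consecutive triplets

est = LinearIntentToTreatDRIV(
    model_y_xw=LinearRegression(),
    model_t_xwz=LogisticRegression(),
)
# group-aware cross-fitting: a cluster never straddles a train/test split
est.fit(y, T, Z=Z, X=X, W=W, groups=groups)

eff = est.effect(X, T0="a", T1="b")      # CATE point estimates at X
lb, ub = est.coef__interval(alpha=0.05)  # CI for the linear CATE coefficients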
