Skip to content

Commit

Permalink
Ensure groups work with DRIV
Browse files Browse the repository at this point in the history
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
  • Loading branch information
kbattocchi committed Aug 14, 2023
1 parent 96ab4b4 commit 2cbd48e
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
39 changes: 39 additions & 0 deletions econml/tests/test_dmliv.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,42 @@ def true_fn(X):
ate_lb, ate_ub = est.effect_interval()
np.testing.assert_array_less(ate_lb, true_ate)
np.testing.assert_array_less(true_ate, ate_ub)

def test_groups(self):
n = 500
d_w = 10
d_x = 3
y = np.random.normal(size=(n,))
W = np.random.normal(size=(n, d_w))
X = np.random.normal(size=(n, d_x))
T = np.random.choice(["a", "b"], size=(n,))
Z = np.random.choice(["c", "d"], size=(n,))
groups = [i // 3 for i in range(n)]

est_list = [
OrthoIV(
projection=False,
discrete_treatment=True,
discrete_instrument=True,
),
OrthoIV(
projection=True,
discrete_treatment=True,
discrete_instrument=True,
),
DMLIV(
model_final=LinearRegression(fit_intercept=False),
discrete_treatment=True,
discrete_instrument=True,
),
NonParamDMLIV(
model_final=RandomForestRegressor(),
discrete_treatment=True,
discrete_instrument=True,
),
]

for est in est_list:
est.fit(y, T, Z=Z, X=X, W=W, groups=groups)
score = est.score(y, T, Z=Z, X=X, W=W, groups=groups)
eff = est.const_marginal_effect(X)
47 changes: 47 additions & 0 deletions econml/tests/test_driv.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,50 @@ def true_fn(X):
np.testing.assert_array_less(true_coef, coef_ub)
np.testing.assert_array_less(intercept_lb, 0)
np.testing.assert_array_less(0, intercept_ub)

def test_groups(self):
n = 500
d_w = 10
d_x = 3
y = np.random.normal(size=(n,))
W = np.random.normal(size=(n, d_w))
X = np.random.normal(size=(n, d_x))
T = np.random.choice(["a", "b"], size=(n,))
Z = np.random.choice(["c", "d"], size=(n,))
groups = [i // 3 for i in range(n)]

est_list = [
DRIV(
flexible_model_effect=StatsModelsLinearRegression(fit_intercept=False),
model_final=StatsModelsLinearRegression(fit_intercept=False),
discrete_instrument=True,
discrete_treatment=True
),
LinearDRIV(
flexible_model_effect=StatsModelsLinearRegression(fit_intercept=False),
discrete_instrument=True,
discrete_treatment=True
),
SparseLinearDRIV(
flexible_model_effect=StatsModelsLinearRegression(fit_intercept=False),
fit_cate_intercept=True,
discrete_instrument=True,
discrete_treatment=True
),
ForestDRIV(
flexible_model_effect=StatsModelsLinearRegression(fit_intercept=False),
discrete_instrument=True,
discrete_treatment=True
),
IntentToTreatDRIV(
flexible_model_effect=StatsModelsLinearRegression(fit_intercept=False)
),
LinearIntentToTreatDRIV(
flexible_model_effect=StatsModelsLinearRegression(fit_intercept=False)
)
]

for est in est_list:
est.fit(y, T, Z=Z, X=X, W=W, groups=groups)
score = est.score(y, T, Z=Z, X=X, W=W, groups=groups)
eff = est.const_marginal_effect(X)

0 comments on commit 2cbd48e

Please sign in to comment.