CATE validation - uplift uniform confidence bands (#840)
Add support for multiplier bootstrap uniform confidence band error bars for uplift curves
amarvenu authored Mar 19, 2024
1 parent ed4fe33 commit 27d3101
Showing 7 changed files with 381 additions and 149 deletions.
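For orientation before the diffs: a minimal usage sketch of the feature being added. The synthetic data, learner choices, and the DML final-stage model below are illustrative placeholders; the DRTester calls mirror the test code in this commit, where the new err_type argument on the uplift plotting methods (None, 'ucb1', or 'ucb2') selects the multiplier-bootstrap uniform confidence bands.

import numpy as np
from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression

from econml.dml import DML
from econml.validate import DRTester

# Illustrative synthetic data: features X, binary treatment D, outcome Y.
rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 5))
D = rng.binomial(1, 0.5, n)
Y = D * X[:, 0] + rng.normal(size=n)
Xtrain, Xval = X[:n // 2], X[n // 2:]
Dtrain, Dval = D[:n // 2], D[n // 2:]
Ytrain, Yval = Y[:n // 2], Y[n // 2:]

# Any fitted CATE estimator works; DML mirrors the setup in the tests below.
cate = DML(
    model_y=GradientBoostingRegressor(),
    model_t=RandomForestClassifier(),
    model_final=LinearRegression(fit_intercept=False),
    discrete_treatment=True,
).fit(Y=Ytrain, T=Dtrain, X=Xtrain)

tester = DRTester(
    model_regression=GradientBoostingRegressor(),
    model_propensity=RandomForestClassifier(),
    cate=cate,
).fit_nuisance(Xval, Dval, Yval, Xtrain, Dtrain, Ytrain)

qini_res = tester.evaluate_uplift(Xval, Xtrain, metric='qini')
# New in this commit: uniform confidence bands on the uplift curve.
qini_res.plot_uplift(tmt=1, err_type='ucb2')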
14 changes: 14 additions & 0 deletions doc/reference.rst
@@ -147,6 +147,20 @@ CATE Interpreters
     econml.cate_interpreter.SingleTreeCateInterpreter
     econml.cate_interpreter.SingleTreePolicyInterpreter
 
+.. _validation_api:
+
+CATE Validation
+---------------
+
+.. autosummary::
+    :toctree: _autosummary
+
+    econml.validate.DRTester
+    econml.validate.BLPEvaluationResults
+    econml.validate.CalibrationEvaluationResults
+    econml.validate.UpliftEvaluationResults
+    econml.validate.EvaluationResults
+
 .. _scorers_api:
 
 CATE Scorers
48 changes: 27 additions & 21 deletions econml/tests/test_drtester.py
@@ -5,7 +5,7 @@
 import scipy.stats as st
 from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
 
-from econml.validate.drtester import DRtester
+from econml.validate.drtester import DRTester
 from econml.dml import DML
 
 
@@ -70,7 +70,7 @@ def test_multi(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)
 
         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -123,7 +123,7 @@ def test_binary(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)
 
         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -148,8 +148,8 @@ def test_binary(self):
                 self.assertRaises(ValueError, res.plot_toc, k)
             else:  # real treatment, k = 1
                 self.assertTrue(res.plot_cal(k) is not None)
-                self.assertTrue(res.plot_qini(k) is not None)
-                self.assertTrue(res.plot_toc(k) is not None)
+                self.assertTrue(res.plot_qini(k, 'ucb2') is not None)
+                self.assertTrue(res.plot_toc(k, 'ucb1') is not None)
 
         self.assertLess(res_df.blp_pval.values[0], 0.05)  # heterogeneity
         self.assertGreater(res_df.cal_r_squared.values[0], 0)  # good R2
@@ -171,7 +171,7 @@ def test_nuisance_val_fit(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)
 
         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -193,8 +193,8 @@ def test_nuisance_val_fit(self):
         for kwargs in [{}, {'Xval': Xval}]:
             with self.assertRaises(Exception) as exc:
                 my_dr_tester.evaluate_cal(**kwargs)
-            self.assertTrue(
-                str(exc.exception) == "Must fit nuisance models on training sample data to use calibration test"
+            self.assertEqual(
+                str(exc.exception), "Must fit nuisance models on training sample data to use calibration test"
             )
 
     def test_exceptions(self):
@@ -212,7 +212,7 @@ def test_exceptions(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)
 
         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -223,11 +223,11 @@ def test_exceptions(self):
             with self.assertRaises(Exception) as exc:
                 func()
             if func.__name__ == 'evaluate_cal':
-                self.assertTrue(
-                    str(exc.exception) == "Must fit nuisance models on training sample data to use calibration test"
+                self.assertEqual(
+                    str(exc.exception), "Must fit nuisance models on training sample data to use calibration test"
                 )
             else:
-                self.assertTrue(str(exc.exception) == "Must fit nuisances before evaluating")
+                self.assertEqual(str(exc.exception), "Must fit nuisances before evaluating")
 
         my_dr_tester = my_dr_tester.fit_nuisance(
             Xval, Dval, Yval, Xtrain, Dtrain, Ytrain
@@ -242,12 +242,12 @@ def test_exceptions(self):
             with self.assertRaises(Exception) as exc:
                 func()
             if func.__name__ == 'evaluate_blp':
-                self.assertTrue(
-                    str(exc.exception) == "CATE predictions not yet calculated - must provide Xval"
+                self.assertEqual(
+                    str(exc.exception), "CATE predictions not yet calculated - must provide Xval"
                 )
             else:
-                self.assertTrue(str(exc.exception) ==
-                                "CATE predictions not yet calculated - must provide both Xval, Xtrain")
+                self.assertEqual(str(exc.exception),
+                                 "CATE predictions not yet calculated - must provide both Xval, Xtrain")
 
         for func in [
             my_dr_tester.evaluate_cal,
@@ -256,19 +256,19 @@ def test_exceptions(self):
         ]:
             with self.assertRaises(Exception) as exc:
                 func(Xval=Xval)
-            self.assertTrue(
-                str(exc.exception) == "CATE predictions not yet calculated - must provide both Xval, Xtrain")
+            self.assertEqual(
+                str(exc.exception), "CATE predictions not yet calculated - must provide both Xval, Xtrain")
 
         cal_res = my_dr_tester.evaluate_cal(Xval, Xtrain)
         self.assertGreater(cal_res.cal_r_squared[0], 0)  # good R2
 
         with self.assertRaises(Exception) as exc:
             my_dr_tester.evaluate_uplift(metric='blah')
-        self.assertTrue(
-            str(exc.exception) == "Unsupported metric - must be one of ['toc', 'qini']"
+        self.assertEqual(
+            str(exc.exception), "Unsupported metric 'blah' - must be one of ['toc', 'qini']"
         )
 
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -278,5 +278,11 @@ def test_exceptions(self):
         qini_res = my_dr_tester.evaluate_uplift(Xval, Xtrain)
         self.assertLess(qini_res.pvals[0], 0.05)
 
+        with self.assertRaises(Exception) as exc:
+            qini_res.plot_uplift(tmt=1, err_type='blah')
+        self.assertEqual(
+            str(exc.exception), "Invalid error type 'blah'; must be one of [None, 'ucb2', 'ucb1']"
+        )
+
         autoc_res = my_dr_tester.evaluate_uplift(Xval, Xtrain, metric='toc')
         self.assertLess(autoc_res.pvals[0], 0.05)
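A note on what the new 'ucb1'/'ucb2' options compute: a uniform confidence band covers the entire uplift curve simultaneously rather than each point separately, with the critical value obtained by a multiplier bootstrap. The sketch below illustrates the generic sup-t recipe with Gaussian multipliers; it is a conceptual illustration only, not econml's implementation, and the influence-function input and function name are hypothetical stand-ins.

import numpy as np

def multiplier_bootstrap_band(influence, level=0.95, n_boot=1000, seed=0):
    # influence: hypothetical (n_obs, n_points) array of per-observation
    # influence-function values for the curve estimate at each grid point
    # (assumed non-degenerate, so the standard errors below are nonzero).
    rng = np.random.default_rng(seed)
    n_obs = influence.shape[0]
    # Pointwise standard errors of the curve estimate.
    se = influence.std(axis=0, ddof=1) / np.sqrt(n_obs)
    sup_stats = np.empty(n_boot)
    for b in range(n_boot):
        xi = rng.standard_normal(n_obs)           # Gaussian multipliers
        draw = (xi @ influence) / n_obs           # bootstrap perturbation of the curve
        sup_stats[b] = np.max(np.abs(draw) / se)  # sup of |t|-statistics over the grid
    crit = np.quantile(sup_stats, level)          # uniform critical value
    return crit * se                              # band half-width at each grid point

Because the critical value controls the supremum of the t-statistics over the whole grid, the band [curve - half_width, curve + half_width] is wider than pointwise intervals but covers the entire curve at the stated level (asymptotically).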
6 changes: 4 additions & 2 deletions econml/validate/__init__.py
@@ -5,7 +5,9 @@
 A suite of validation methods for CATE models.
 """
 
-from .drtester import DRtester
+from .drtester import DRTester
+from .results import BLPEvaluationResults, CalibrationEvaluationResults, UpliftEvaluationResults, EvaluationResults
 
 
-__all__ = ['DRtester']
+__all__ = ['DRTester',
+           'BLPEvaluationResults', 'CalibrationEvaluationResults', 'UpliftEvaluationResults', 'EvaluationResults']
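With the expanded __all__, the tester and all of its result containers can now be imported directly from the subpackage:

from econml.validate import (
    DRTester,
    BLPEvaluationResults,
    CalibrationEvaluationResults,
    UpliftEvaluationResults,
    EvaluationResults,
)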