Skip to content

Commit

Permalink
update test cases + validate example notebook
Browse files Browse the repository at this point in the history
Signed-off-by: amarv <amarvenu@stanford.edu>
  • Loading branch information
amarvenu committed Jan 18, 2024
1 parent db8bc86 commit d24a4ac
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 12 deletions.
4 changes: 2 additions & 2 deletions econml/tests/test_drtester.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import scipy.stats as st
from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor

from econml.validate.drtester import DRtester
from validate.drtester import DRtester
from econml.dml import DML


Expand Down Expand Up @@ -279,7 +279,7 @@ def test_exceptions(self):
self.assertLess(qini_res.pvals[0], 0.05)

with self.assertRaises(Exception) as exc:
qini_res.plot_uplift(metric='blah')
qini_res.plot_uplift(tmt=1, err_type='blah')
self.assertTrue(
str(exc.exception) == "Invalid error type; must be one of [None, 'ucb2', 'ucb1']"
)
Expand Down
13 changes: 4 additions & 9 deletions econml/validate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def calc_uplift(
cate_preds_val: np.array,
dr_val: np.array,
percentiles: np.array,
metric: str
metric: str,

) -> Tuple[float, float, pd.DataFrame]:
"""
Helper function for uplift curve generation and coefficient calculation.
Expand Down Expand Up @@ -100,14 +101,8 @@ def calc_uplift(

toc_std[it] = np.sqrt(np.mean(toc_psi[it] ** 2) / n) # standard error of tau(q)

if dr_val.shape[0] > 1e6: # avoid computational issues if dataset too large
mboot = np.zeros((len(qs), 1000))
for it in range(1000):
w = np.random.normal(0, 1, size=(n,))
mboot[:, it] = (toc_psi / toc_std.reshape(-1, 1)) @ w / n
else:
w = np.random.normal(0, 1, size=(n, 1000))
mboot = (toc_psi / toc_std.reshape(-1, 1)) @ w / n
w = np.random.normal(0, 1, size=(n, 1000))
mboot = (toc_psi / toc_std.reshape(-1, 1)) @ w / n

max_mboot = np.max(np.abs(mboot), axis=0)
uniform_critical_value = np.percentile(max_mboot, 95)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/CATE validation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
{
"data": {
"text/plain": [
"<econml.dml.dml.DML at 0x7f7d1bcdb700>"
"<econml.dml.dml.DML at 0x7f7b4836cd00>"
]
},
"execution_count": 4,
Expand Down

0 comments on commit d24a4ac

Please sign in to comment.