Skip to content

Commit

Permalink
Fix orthoforest impl; change notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed May 2, 2019
1 parent 126da6a commit 180c9ae
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
8 changes: 4 additions & 4 deletions econml/ortho_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,10 +701,10 @@ def effect(self, X=None, T0=0, T1=1):
Theta : matrix , shape (n, d_y)
CATE on each outcome for each sample.
"""
if ndim(T0) == 0:
T0 = np.repeat(T0, 1 if X is None else shape(X)[0])
if ndim(T1) == 0:
T1 = np.repeat(T1, 1 if X is None else shape(X)[0])
if np.ndim(T0) == 0:
T0 = np.repeat(T0, 1 if X is None else np.shape(X)[0])
if np.ndim(T1) == 0:
T1 = np.repeat(T1, 1 if X is None else np.shape(X)[0])
T0 = self._check_treatment(T0)
T1 = self._check_treatment(T1)
T0_encoded = self._label_encoder.transform(T0)
Expand Down
12 changes: 6 additions & 6 deletions notebooks/Double Machine Learning Examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
"source": [
"est = DMLCateEstimator(model_y=RandomForestRegressor(),model_t=RandomForestRegressor())\n",
"est.fit(Y, T, X, W)\n",
"te_pred = est.const_marginal_effect(X_test)"
"te_pred = est.effect(X_test)"
]
},
{
Expand All @@ -194,7 +194,7 @@
"source": [
"est = DMLCateEstimator(model_y=RandomForestRegressor(),model_t=RandomForestRegressor(),featurizer=PolynomialFeatures(degree=2))\n",
"est.fit(Y, T, X, W)\n",
"te_pred1=est.const_marginal_effect(X_test)"
"te_pred1=est.effect(X_test)"
]
},
{
Expand All @@ -212,7 +212,7 @@
"source": [
"est = DMLCateEstimator(model_y=RandomForestRegressor(),model_t=RandomForestRegressor(),model_final=LassoCV(),featurizer=PolynomialFeatures(degree=10))\n",
"est.fit(Y, T, X, W)\n",
"te_pred2=est.const_marginal_effect(X_test)"
"te_pred2=est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -348,7 +348,7 @@
"metadata": {},
"outputs": [],
"source": [
"te_pred = est.const_marginal_effect(X_test)"
"te_pred = est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -666,7 +666,7 @@
"source": [
"est = DMLCateEstimator(model_y=RandomForestRegressor(),model_t=RandomForestRegressor())\n",
"est.fit(Y, T, X, W)\n",
"te_pred=est.const_marginal_effect(X_test)"
"te_pred=est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -846,7 +846,7 @@
"source": [
"est = DMLCateEstimator(model_y=MultiTaskElasticNetCV(cv=3),model_t=MultiTaskElasticNetCV(cv=3))\n",
"est.fit(Y, T, X, W)\n",
"te_pred=est.const_marginal_effect(X_test)"
"te_pred=est.effect(X_test)"
]
},
{
Expand Down
8 changes: 4 additions & 4 deletions notebooks/Orthogonal Random Forest Examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@
}
],
"source": [
"treatment_effects = est.const_marginal_effect(X_test)"
"treatment_effects = est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -418,7 +418,7 @@
}
],
"source": [
"treatment_effects = est.const_marginal_effect(X_test)"
"treatment_effects = est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -603,7 +603,7 @@
}
],
"source": [
"treatment_effects = est.const_marginal_effect(X_test)"
"treatment_effects = est.effect(X_test)"
]
},
{
Expand Down Expand Up @@ -1004,7 +1004,7 @@
"source": [
"import time\n",
"t0 = time.time()\n",
"te_pred = est.const_marginal_effect(X_test)\n",
"te_pred = est.effect(X_test)\n",
"print(time.time() - t0)"
]
},
Expand Down

0 comments on commit 180c9ae

Please sign in to comment.