Skip to content

Commit

Permalink
improve input processing for StatsModelsLinearRegression (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
fverac authored Apr 6, 2022
1 parent 3228508 commit 493f7cd
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
4 changes: 3 additions & 1 deletion econml/sklearn_extensions/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Iterable
from scipy.stats import norm
from econml.sklearn_extensions.model_selection import WeightedKFold, WeightedStratifiedKFold
from econml.utilities import ndim, shape, reshape, _safe_norm_ppf
from econml.utilities import ndim, shape, reshape, _safe_norm_ppf, check_input_arrays
from sklearn import clone
from sklearn.linear_model import LinearRegression, LassoCV, MultiTaskLassoCV, Lasso, MultiTaskLasso
from sklearn.metrics import r2_score
Expand Down Expand Up @@ -1683,6 +1683,8 @@ def __init__(self, fit_intercept=True, cov_type="HC0"):
def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
"""Check dimensions and other assertions."""

X, y, sample_weight, freq_weight, sample_var = check_input_arrays(
X, y, sample_weight, freq_weight, sample_var, dtype='numeric')
if X is None:
X = np.empty((y.shape[0], 0))
if self.fit_intercept:
Expand Down
19 changes: 19 additions & 0 deletions econml/tests/test_statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,25 @@ def true_effect(x):
assert np.all(np.abs(est.intercept_ - lr.intercept_) <
1e-12), "{}, {}".format(est.intercept_, lr.intercept_)

def test_o_dtype(self):
""" Testing that the models still work when the np arrays are of O dtype """
np.random.seed(123)
n = 1000
d = 3

X = np.random.normal(size=(n, d)).astype('O')
y = np.random.normal(size=n).astype('O')

est = OLS().fit(X, y)
lr = LinearRegression().fit(X, y)
assert np.all(np.abs(est.coef_ - lr.coef_) < 1e-12), "{}, {}".format(est.coef_, lr.coef_)
assert np.all(np.abs(est.intercept_ - lr.intercept_) < 1e-12), "{}, {}".format(est.coef_, lr.intercept_)

est = OLS(fit_intercept=False).fit(X, y)
lr = LinearRegression(fit_intercept=False).fit(X, y)
assert np.all(np.abs(est.coef_ - lr.coef_) < 1e-12), "{}, {}".format(est.coef_, lr.coef_)
assert np.all(np.abs(est.intercept_ - lr.intercept_) < 1e-12), "{}, {}".format(est.coef_, lr.intercept_)

def test_inference(self):
""" Testing that we recover the expected standard errors and confidence intervals in a known example """

Expand Down
11 changes: 9 additions & 2 deletions econml/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def check_inputs(Y, T, X, W=None, multi_output_T=True, multi_output_Y=True):
return Y, T, X, W


def check_input_arrays(*args, validate_len=True, force_all_finite=True):
def check_input_arrays(*args, validate_len=True, force_all_finite=True, dtype=None):
"""Cast input sequences into numpy arrays.
Only inputs that are sequence-like will be converted, all other inputs will be left as is.
Expand All @@ -531,6 +531,13 @@ def check_input_arrays(*args, validate_len=True, force_all_finite=True):
force_all_finite : bool (default=True)
Whether to allow inf and nan in input arrays.
dtype : 'numeric', type, list of type or None (default=None)
Argument passed to sklearn.utils.check_array.
Specifies data type of result. If None, the dtype of the input is preserved.
If "numeric", dtype is preserved unless array.dtype is object.
If dtype is a list of types, conversion on the first type is only
performed if the dtype of the input is not in the list.
Returns
-------
args: array-like
Expand All @@ -541,7 +548,7 @@ def check_input_arrays(*args, validate_len=True, force_all_finite=True):
args = list(args)
for i, arg in enumerate(args):
if np.ndim(arg) > 0:
new_arg = check_array(arg, dtype=None, ensure_2d=False, accept_sparse=True,
new_arg = check_array(arg, dtype=dtype, ensure_2d=False, accept_sparse=True,
force_all_finite=force_all_finite)
if not force_all_finite:
# For when checking input values is disabled
Expand Down

0 comments on commit 493f7cd

Please sign in to comment.