Skip to content

Commit

Permalink
Remove deprecated positional arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Aug 14, 2021
1 parent b374abd commit cc8f0f9
Show file tree
Hide file tree
Showing 15 changed files with 80 additions and 123 deletions.
4 changes: 1 addition & 3 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,8 @@ def _prefit(self, Y, T, *args, only_final=False, **kwargs):

super()._prefit(Y, T, *args, **kwargs)

@_deprecate_positional("X, W, and Z should be passed by keyword only. In a future release "
"we will disallow passing X, W, and Z by position.", ['X', 'W', 'Z'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference=None, only_final=False, check_input=True):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand Down
4 changes: 1 addition & 3 deletions econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,7 @@ def _gen_ortho_learner_model_nuisance(self):
def _gen_ortho_learner_model_final(self):
return _ModelFinal(self._gen_rlearner_model_final())

@_deprecate_positional("X, and should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference=None):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand Down
5 changes: 1 addition & 4 deletions econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,10 +704,7 @@ def tune(self, Y, T, *, X=None, W=None,
return self

# override only so that we can update the docstring to indicate support for `blb`

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand Down
16 changes: 4 additions & 12 deletions econml/dml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,9 +461,7 @@ def _gen_rlearner_model_final(self):
return _FinalWrapper(self._gen_model_final(), self.fit_cate_intercept, self._gen_featurizer(), False)

# override only so that we can update the docstring to indicate support for `LinearModelFinalInference`
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand Down Expand Up @@ -615,9 +613,7 @@ def _gen_model_final(self):
return StatsModelsLinearRegression(fit_intercept=False)

# override only so that we can update the docstring to indicate support for `StatsModelsInference`
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand Down Expand Up @@ -826,9 +822,7 @@ def _gen_model_final(self):
n_jobs=self.n_jobs,
random_state=self.random_state)

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand Down Expand Up @@ -1157,9 +1151,7 @@ def _gen_rlearner_model_final(self):

# override only so that we can update the docstring to indicate
# support for `GenericSingleTreatmentModelFinalInference`
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand Down
16 changes: 4 additions & 12 deletions econml/dr/_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,7 @@ def _gen_model_final(self):
def _gen_ortho_learner_model_final(self):
return _ModelFinal(self._gen_model_final(), self._gen_featurizer(), self.multitask_model_final)

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand Down Expand Up @@ -849,9 +847,7 @@ def _gen_model_final(self):
def _gen_ortho_learner_model_final(self):
return _ModelFinal(self._gen_model_final(), self._gen_featurizer(), False)

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand Down Expand Up @@ -1144,9 +1140,7 @@ def _gen_model_final(self):
def _gen_ortho_learner_model_final(self):
return _ModelFinal(self._gen_model_final(), self._gen_featurizer(), False)

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand Down Expand Up @@ -1440,9 +1434,7 @@ def _gen_model_final(self):
def _gen_ortho_learner_model_final(self):
return _ModelFinal(self._gen_model_final(), self._gen_featurizer(), False)

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, groups=None,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand Down
4 changes: 1 addition & 3 deletions econml/dynamic/dml/_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,9 +540,7 @@ def _strata(self, Y, T, X=None, W=None, Z=None,
# Required for bootstrap inference
return groups

@_deprecate_positional("X, and should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups,
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, sample_var=None, groups,
cache_values=False, inference='auto'):
"""Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand Down
4 changes: 1 addition & 3 deletions econml/iv/nnet/_deepiv.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,8 @@ def __init__(self, *,
self._second_stage_options = second_stage_options
super().__init__()

@_deprecate_positional("X and Z should be passed by keyword only. In a future release "
"we will disallow passing X and Z by position.", ['X', 'Z'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X, Z, *, inference=None):
def fit(self, Y, T, *, X, Z, inference=None):
"""Estimate the counterfactual model from data.
That is, estimate functions τ(·, ·, ·), ∂τ(·, ·).
Expand Down
4 changes: 1 addition & 3 deletions econml/iv/sieve/_tsls.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,8 @@ def __init__(self, *,
self._model_Y = LinearRegression(fit_intercept=False)
super().__init__()

@_deprecate_positional("X, W, and Z should be passed by keyword only. In a future release "
"we will disallow passing X, W, and Z by position.", ['X', 'W', 'Z'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, Z, X=None, W=None, *, inference=None):
def fit(self, Y, T, *, Z, X=None, W=None, inference=None):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·, ·, ·), ∂τ(·, ·).
Expand Down
16 changes: 4 additions & 12 deletions econml/metalearners/_metalearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ def __init__(self, *,
self.categories = categories
super().__init__()

@_deprecate_positional("X should be passed by keyword only. In a future release "
"we will disallow passing X by position.", ['X'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X, *, inference=None):
def fit(self, Y, T, *, X, inference=None):
"""Build an instance of TLearner.
Parameters
Expand Down Expand Up @@ -131,10 +129,8 @@ def __init__(self, *,
self.categories = categories
super().__init__()

@_deprecate_positional("X should be passed by keyword only. In a future release "
"we will disallow passing X by position.", ['X'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X=None, *, inference=None):
def fit(self, Y, T, *, X=None, inference=None):
"""Build an instance of SLearner.
Parameters
Expand Down Expand Up @@ -240,10 +236,8 @@ def __init__(self, *,
self.categories = categories
super().__init__()

@_deprecate_positional("X should be passed by keyword only. In a future release "
"we will disallow passing X by position.", ['X'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X, *, inference=None):
def fit(self, Y, T, *, X, inference=None):
"""Build an instance of XLearner.
Parameters
Expand Down Expand Up @@ -367,10 +361,8 @@ def __init__(self, *,
self.categories = categories
super().__init__()

@_deprecate_positional("X should be passed by keyword only. In a future release "
"we will disallow passing X by position.", ['X'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X, *, inference=None):
def fit(self, Y, T, *, X, inference=None):
"""Build an instance of DomainAdaptationLearner.
Parameters
Expand Down
12 changes: 3 additions & 9 deletions econml/orf/_ortho_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,8 @@ def __init__(self,
self.categories = categories
super().__init__()

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X, W=None, *, inference='auto'):
def fit(self, Y, T, *, X, W=None, inference='auto'):
"""Build an orthogonal random forest from a training set (Y, T, X, W).
Parameters
Expand Down Expand Up @@ -610,9 +608,7 @@ def _combine(self, X, W):

# Need to redefine fit here for auto inference to work due to a quirk in how
# wrap_fit is defined
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X, W=None, *, inference='auto'):
def fit(self, Y, T, *, X, W=None, inference='auto'):
"""Build an orthogonal random forest from a training set (Y, T, X, W).
Parameters
Expand Down Expand Up @@ -949,9 +945,7 @@ def __init__(self, *,
batch_size=batch_size,
random_state=self.random_state)

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X, W=None, *, inference='auto'):
def fit(self, Y, T, *, X, W=None, inference='auto'):
"""Build an orthogonal random forest from a training set (Y, T, X, W).
Parameters
Expand Down
56 changes: 28 additions & 28 deletions econml/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def test_summary(self):
cate_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
summary_results = cate_est.summary()
Expand All @@ -64,8 +64,8 @@ def test_summary(self):
cate_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
fnames = ['Q' + str(i) for i in range(TestInference.d_x)]
Expand All @@ -78,8 +78,8 @@ def test_summary(self):
cate_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
summary_results = cate_est.summary()
Expand All @@ -90,8 +90,8 @@ def test_summary(self):
cate_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
fnames = ['Q' + str(i) for i in range(TestInference.d_x)]
Expand All @@ -104,8 +104,8 @@ def test_summary(self):
wrapped_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
summary_results = wrapped_est.summary()
Expand All @@ -117,8 +117,8 @@ def test_summary(self):
wrapped_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
fnames = ['Q' + str(i) for i in range(TestInference.d_x)]
Expand All @@ -138,8 +138,8 @@ def test_summary_discrete(self):
cate_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
summary_results = cate_est.summary(T=1)
Expand All @@ -159,8 +159,8 @@ def test_summary_discrete(self):
cate_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
fnames = ['Q' + str(i) for i in range(TestInference.d_x)]
Expand All @@ -174,8 +174,8 @@ def test_summary_discrete(self):
cate_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
summary_results = cate_est.summary(T=1)
Expand All @@ -187,8 +187,8 @@ def test_summary_discrete(self):
cate_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
fnames = ['Q' + str(i) for i in range(TestInference.d_x)]
Expand All @@ -202,8 +202,8 @@ def test_summary_discrete(self):
wrapped_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
summary_results = wrapped_est.summary(T=1)
Expand All @@ -216,8 +216,8 @@ def test_summary_discrete(self):
wrapped_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=inference
)
fnames = ['Q' + str(i) for i in range(TestInference.d_x)]
Expand Down Expand Up @@ -275,16 +275,16 @@ def test_can_summarize(self):
LinearDML(model_t=LinearRegression(), model_y=LinearRegression()).fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W
X=TestInference.X,
W=TestInference.W
).summary()

LinearDRLearner(model_regression=LinearRegression(),
model_propensity=LogisticRegression(), fit_cate_intercept=False).fit(
TestInference.Y,
TestInference.T > 0,
TestInference.X,
TestInference.W,
X=TestInference.X,
W=TestInference.W,
inference=BootstrapInference(5)
).summary(1)

Expand Down
Loading

0 comments on commit cc8f0f9

Please sign in to comment.