Skip to content

Commit

Permalink
Improve discreteness handling, allow binary outcomes (py-why#816)
Browse files Browse the repository at this point in the history
Adds a discrete_outcome keyword arg to most estimators; if True, the outcome nuisance model will be a classifier.

Additionally add constraints to ensure nuisance model discreteness is handled appropriately by the user. 
If a nuisance model has a continuous target but a classifier is passed, then an AttributeError is raised.
Conversely, if a nuisance model has a discrete target but a regressor is passed, then a warning is issued.
  • Loading branch information
fverac authored and kbattocchi committed Jan 23, 2024
1 parent c9c522d commit 1943507
Show file tree
Hide file tree
Showing 21 changed files with 1,422 additions and 431 deletions.
66 changes: 46 additions & 20 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class in this module implements the general logic in a very versatile way
TreatmentExpansionMixin)
from .inference import BootstrapInference
from .utilities import (_deprecate_positional, check_input_arrays,
cross_product, filter_none_kwargs,
cross_product, filter_none_kwargs, strata_from_discrete_arrays,
inverse_onehot, jacify_featurizer, ndim, reshape, shape, transpose)
from .sklearn_extensions.model_selection import ModelSelector

Expand Down Expand Up @@ -327,6 +327,9 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
Parameters
----------
discrete_outcome: bool
Whether the outcome should be treated as binary
discrete_treatment: bool
Whether the treatment values should be treated as categorical, rather than continuous, quantities
Expand Down Expand Up @@ -426,7 +429,7 @@ def _gen_ortho_learner_model_final(self):
np.random.seed(123)
X = np.random.normal(size=(100, 3))
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.1, size=(100,))
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None,
est = OrthoLearner(cv=2, discrete_outcome=False, discrete_treatment=False, treatment_featurizer=None,
discrete_instrument=False, categories='auto', random_state=None)
est.fit(y, X[:, 0], W=X[:, 1:])
Expand Down Expand Up @@ -484,7 +487,7 @@ def _gen_ortho_learner_model_final(self):
import scipy.special
T = np.random.binomial(1, scipy.special.expit(W[:, 0]))
y = T + W[:, 0] + np.random.normal(0, 0.01, size=(100,))
est = OrthoLearner(cv=2, discrete_treatment=True, discrete_instrument=False,
est = OrthoLearner(cv=2, discrete_outcome=False, discrete_treatment=True, discrete_instrument=False,
treatment_featurizer=None, categories='auto', random_state=None)
est.fit(y, T, W=W)
Expand Down Expand Up @@ -516,11 +519,20 @@ def _gen_ortho_learner_model_final(self):
"""

def __init__(self, *,
discrete_treatment, treatment_featurizer,
discrete_instrument, categories, cv, random_state,
mc_iters=None, mc_agg='mean', allow_missing=False, use_ray=False, ray_remote_func_options=None):
self.actors = []
discrete_outcome,
discrete_treatment,
treatment_featurizer,
discrete_instrument,
categories,
cv,
random_state,
mc_iters=None,
mc_agg='mean',
allow_missing=False,
use_ray=False,
ray_remote_func_options=None):
self.cv = cv
self.discrete_outcome = discrete_outcome
self.discrete_treatment = discrete_treatment
self.treatment_featurizer = treatment_featurizer
self.discrete_instrument = discrete_instrument
Expand Down Expand Up @@ -616,20 +628,15 @@ def _subinds_check_none(self, var, inds):
def _strata(self, Y, T, X=None, W=None, Z=None,
sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, only_final=False, check_input=True):
arrs = []
if self.discrete_outcome:
arrs.append(Y)
if self.discrete_treatment:
arrs.append(T)
if self.discrete_instrument:
Z = LabelEncoder().fit_transform(np.ravel(Z))
arrs.append(Z)

if self.discrete_treatment:
enc = LabelEncoder()
T = enc.fit_transform(np.ravel(T))
if self.discrete_instrument:
return T + Z * len(enc.classes_)
else:
return T
elif self.discrete_instrument:
return Z
else:
return None
return strata_from_discrete_arrays(arrs)

def _prefit(self, Y, T, *args, only_final=False, **kwargs):

Expand Down Expand Up @@ -706,6 +713,20 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N

if not only_final:

if self.discrete_outcome:
self.outcome_transformer = LabelEncoder()
self.outcome_transformer.fit(Y)
if Y.shape[1:] and Y.shape[1] > 1:
raise ValueError(
f"Only one outcome variable is supported when discrete_outcome=True. Got Y of shape {Y.shape}")
if len(self.outcome_transformer.classes_) > 2:
raise AttributeError(
f"({len(self.outcome_transformer.classes_)} outcome classes detected. "
"Currently, only 2 outcome classes are allowed when discrete_outcome=True. "
f"Classes provided include {self.outcome_transformer.classes_[:5]}")
else:
self.outcome_transformer = None

if self.discrete_treatment:
categories = self.categories
if categories != 'auto':
Expand Down Expand Up @@ -865,7 +886,7 @@ def refit_final(self, inference=None):
def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):

# use a binary array to get stratified split in case of discrete treatment
stratify = self.discrete_treatment or self.discrete_instrument
stratify = self.discrete_treatment or self.discrete_instrument or self.discrete_outcome
strata = self._strata(Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight, groups=groups)
if strata is None:
strata = T # always safe to pass T as second arg to split even if we're not actually stratifying
Expand All @@ -878,6 +899,9 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, group
if self.discrete_instrument:
Z = self.z_transformer.transform(reshape(Z, (-1, 1)))

if self.discrete_outcome:
Y = self.outcome_transformer.transform(Y).reshape(-1, 1)

if self.cv == 1: # special case, no cross validation
folds = None
else:
Expand Down Expand Up @@ -1008,6 +1032,8 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
X, T = self._expand_treatments(X, T)
if self.z_transformer is not None:
Z = self.z_transformer.transform(reshape(Z, (-1, 1)))
if self.discrete_outcome:
Y = self.outcome_transformer.transform(Y).reshape(-1, 1)
n_iters = len(self._models_nuisance)
n_splits = len(self._models_nuisance[0])

Expand Down
24 changes: 19 additions & 5 deletions econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ class _RLearner(_OrthoLearner):
Parameters
----------
discrete_outcome: bool
Whether the outcome should be treated as binary
discrete_treatment: bool
Whether the treatment values should be treated as categorical, rather than continuous, quantities
Expand Down Expand Up @@ -249,7 +252,7 @@ def _gen_rlearner_model_final(self):
np.random.seed(123)
X = np.random.normal(size=(1000, 3))
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.01, size=(1000,))
est = RLearner(cv=2, discrete_treatment=False,
est = RLearner(cv=2, discrete_outcome=False, discrete_treatment=False,
treatment_featurizer=None, categories='auto', random_state=None)
est.fit(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])
Expand Down Expand Up @@ -297,10 +300,21 @@ def _gen_rlearner_model_final(self):
is multidimensional, then the average of the MSEs for each dimension of Y is returned.
"""

def __init__(self, *, discrete_treatment, treatment_featurizer, categories,
cv, random_state, mc_iters=None, mc_agg='mean', allow_missing=False,
use_ray=False, ray_remote_func_options=None):
super().__init__(discrete_treatment=discrete_treatment,
def __init__(self,
*,
discrete_outcome,
discrete_treatment,
treatment_featurizer,
categories,
cv,
random_state,
mc_iters=None,
mc_agg='mean',
allow_missing=False,
use_ray=False,
ray_remote_func_options=None):
super().__init__(discrete_outcome=discrete_outcome,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
discrete_instrument=False, # no instrument, so doesn't matter
categories=categories,
Expand Down
49 changes: 36 additions & 13 deletions econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,17 +268,35 @@ class CausalForestDML(_BaseDML):
Parameters
----------
model_y: estimator or 'auto', default 'auto'
The estimator for fitting the response to the features. Must implement
`fit` and `predict` methods.
If 'auto' :class:`.WeightedLassoCV`/:class:`.WeightedMultiTaskLassoCV` will be chosen.
model_t: estimator or 'auto', default 'auto'
The estimator for fitting the treatment to the features.
If estimator, it must implement `fit` and `predict` methods;
If 'auto', :class:`~sklearn.linear_model.LogisticRegressionCV` will be applied for discrete treatment,
and :class:`.WeightedLassoCV`/:class:`.WeightedMultiTaskLassoCV`
will be applied for continuous treatment.
model_y: estimator, {'linear', 'forest'}, list of str/estimator, or 'auto', default 'auto'
Determines how to fit the outcome to the features.
- If an estimator, will use the model as is for fitting.
- If str, will use model associated with the keyword.
- 'linear' - LogisticRegressionCV if discrete_outcome=True else WeightedLassoCVWrapper
- 'forest' - RandomForestClassifier if discrete_outcome=True else RandomForestRegressor
- If list, will perform model selection on the supplied list, which can be a mix of str and estimators, \
and then use the best estimator for fitting.
- If 'auto', model will select over linear and forest models
User-supplied estimators should support 'fit' and 'predict' methods,
and additionally 'predict_proba' if discrete_outcome=True.
model_t: estimator, {'linear', 'forest'}, list of str/estimator, or 'auto', default 'auto'
Determines how to fit the treatment to the features.
- If an estimator, will use the model as is for fitting.
- If str, will use model associated with the keyword.
- 'linear' - LogisticRegressionCV if discrete_treatment=True else WeightedLassoCVWrapper
- 'forest' - RandomForestClassifier if discrete_treatment=True else RandomForestRegressor
- If list, will perform model selection on the supplied list, which can be a mix of str and estimators, \
and then use the best estimator for fitting.
- If 'auto', model will select over linear and forest models
User-supplied estimators should support 'fit' and 'predict' methods,
and additionally 'predict_proba' if discrete_treatment=True.
featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite features in the final CATE regression.
Expand All @@ -290,6 +308,9 @@ class CausalForestDML(_BaseDML):
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.
discrete_outcome: bool, default ``False``
Whether the outcome should be treated as binary
discrete_treatment: bool, default ``False``
Whether the treatment values should be treated as categorical, rather than continuous, quantities
Expand Down Expand Up @@ -588,6 +609,7 @@ def __init__(self, *,
model_t='auto',
featurizer=None,
treatment_featurizer=None,
discrete_outcome=False,
discrete_treatment=False,
categories='auto',
cv=2,
Expand Down Expand Up @@ -644,7 +666,8 @@ def __init__(self, *,
self.subforest_size = subforest_size
self.n_jobs = n_jobs
self.verbose = verbose
super().__init__(discrete_treatment=discrete_treatment,
super().__init__(discrete_outcome=discrete_outcome,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
Expand All @@ -668,7 +691,7 @@ def _gen_featurizer(self):
return clone(self.featurizer, safe=False)

def _gen_model_y(self):
return _make_first_stage_selector(self.model_y, False, self.random_state)
return _make_first_stage_selector(self.model_y, self.discrete_outcome, self.random_state)

def _gen_model_t(self):
return _make_first_stage_selector(self.model_t, self.discrete_treatment, self.random_state)
Expand Down
Loading

0 comments on commit 1943507

Please sign in to comment.