Skip to content

Commit 996d227

Browse files
authored
[FEA] Support Generalized Adjustment Criterion for Estimation plus Add Example Notebook (#1297)
* first commit Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * adding default case Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * adding minimal test Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * poe format Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * adding test, throwing on unsupported Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * tweaks Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * dependency bump Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * delete misc files Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * fix dictionary mapping Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * make test check python version Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * adding another test Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * adding docs Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * restore notebooks I dont want to change Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * remove extraneous comment Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * remove comment and print statement from example notebook Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * adding estimation stage Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * styling Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * fixes to true ate estimation Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * cleaning up test code a bit Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * fix failing test Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * fix broken test Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * format Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * more removing get_backdoor_variables Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * fix bug Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * final tweaks Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * poe format Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * Fix NoneType exception Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * fix todo Signed-off-by: Nicholas Parente <parentenickj@gmail.com> * change comment Signed-off-by: Nicholas Parente <parentenickj@gmail.com> --------- Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
1 parent 3114151 commit 996d227

34 files changed

+1071
-143
lines changed

docs/source/example_notebooks/dowhy_generalized_covariate_adjustment_estimation_example.ipynb

+523
Large diffs are not rendered by default.

dowhy/causal_estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,7 @@ def __init__(self, identified_estimand, estimator_name):
960960
self.treatment_variable = identified_estimand.treatment_variable
961961
self.outcome_variable = identified_estimand.outcome_variable
962962
self.backdoor_variables = identified_estimand.get_backdoor_variables()
963+
self.general_adjustment_variables = identified_estimand.get_general_adjustment_variables()
963964
self.instrumental_variables = identified_estimand.instrumental_variables
964965
self.estimand_type = identified_estimand.estimand_type
965966
self.estimand_expression = None

dowhy/causal_estimators/causalml.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ def fit(
118118
self._set_effect_modifiers(data, effect_modifier_names)
119119

120120
# Check the backdoor variables being used
121-
self.logger.debug("Back-door variables used:" + ",".join(self._target_estimand.get_backdoor_variables()))
121+
self.logger.debug("Adjustment set variables used:" + ",".join(self._target_estimand.get_adjustment_set()))
122122

123123
# Add the observed confounders and one hot encode the categorical variables
124-
self._observed_common_causes_names = self._target_estimand.get_backdoor_variables()
124+
self._observed_common_causes_names = self._target_estimand.get_adjustment_set()
125125
if self._observed_common_causes_names:
126126
# Get the data of the unobserved confounders
127127
self._observed_common_causes = data[self._observed_common_causes_names]
@@ -220,6 +220,6 @@ def construct_symbolic_estimator(self, estimand):
220220
expr = "b: " + ",".join(estimand.outcome_variable) + "~"
221221
# TODO we are conditioning on a postive treatment
222222
# TODO create an expression corresponding to each estimator used
223-
var_list = estimand.treatment_variable + estimand.get_backdoor_variables()
223+
var_list = estimand.treatment_variable + estimand.get_adjustment_set()
224224
expr += "+".join(var_list)
225225
return expr

dowhy/causal_estimators/distance_matching_estimator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ def fit(self, data: pd.DataFrame, effect_modifier_names: Optional[List[str]] = N
130130
self.logger.error(error_msg)
131131
raise Exception(error_msg)
132132

133-
self.logger.debug("Back-door variables used:" + ",".join(self._target_estimand.get_backdoor_variables()))
133+
self.logger.debug("Adjustment set variables used:" + ",".join(self._target_estimand.get_adjustment_set()))
134134

135-
self._observed_common_causes_names = self._target_estimand.get_backdoor_variables()
135+
self._observed_common_causes_names = self._target_estimand.get_adjustment_set()
136136
if self._observed_common_causes_names:
137137
if self.exact_match_cols is not None:
138138
self._observed_common_causes_names = [
@@ -307,6 +307,6 @@ def estimate_effect(
307307

308308
def construct_symbolic_estimator(self, estimand):
309309
expr = "b: " + ", ".join(estimand.outcome_variable) + "~"
310-
var_list = estimand.treatment_variable + estimand.get_backdoor_variables()
310+
var_list = estimand.treatment_variable + estimand.get_adjustment_set()
311311
expr += "+".join(var_list)
312312
return expr

dowhy/causal_estimators/econml.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def fit(
120120
self._econml_fit_params = kwargs
121121
self._fit_params = kwargs
122122

123-
self._observed_common_causes_names = self._target_estimand.get_backdoor_variables().copy()
123+
self._observed_common_causes_names = self._target_estimand.get_adjustment_set().copy()
124124

125125
# Enforcing this ordering is necessary to feed through the propensity values from dataset
126126
self._observed_common_causes_names = [

dowhy/causal_estimators/generalized_linear_model_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def predict_fn(self, data: pd.DataFrame, model, features):
127127

128128
def construct_symbolic_estimator(self, estimand):
129129
expr = "b: " + ",".join(estimand.outcome_variable) + "~" + "Sigmoid("
130-
var_list = estimand.treatment_variable + estimand.get_backdoor_variables()
130+
var_list = estimand.treatment_variable + estimand.get_adjustment_set()
131131
expr += "+".join(var_list)
132132
if self._effect_modifier_names:
133133
interaction_terms = [

dowhy/causal_estimators/linear_regression_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def fit(
8787

8888
def construct_symbolic_estimator(self, estimand):
8989
expr = "b: " + ",".join(estimand.outcome_variable) + "~"
90-
var_list = estimand.treatment_variable + estimand.get_backdoor_variables()
90+
var_list = estimand.treatment_variable + estimand.get_adjustment_set()
9191
expr += "+".join(var_list)
9292
if self._effect_modifier_names:
9393
interaction_terms = [

dowhy/causal_estimators/propensity_score_estimator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def fit(
9696
self.reset_encoders() # Forget any existing encoders
9797
self._set_effect_modifiers(data, effect_modifier_names)
9898

99-
self.logger.debug("Back-door variables used:" + ",".join(self._target_estimand.get_backdoor_variables()))
100-
self._observed_common_causes_names = self._target_estimand.get_backdoor_variables()
99+
self.logger.debug("Adjustment set variables used:" + ",".join(self._target_estimand.get_adjustment_set()))
100+
self._observed_common_causes_names = self._target_estimand.get_adjustment_set()
101101

102102
if self._observed_common_causes_names:
103103
self._observed_common_causes = data[self._observed_common_causes_names]

dowhy/causal_estimators/propensity_score_matching_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,6 @@ def estimate_effect(
180180
def construct_symbolic_estimator(self, estimand):
181181
expr = "b: " + ", ".join(estimand.outcome_variable) + "~"
182182
# TODO -- fix: we are actually conditioning on positive treatment (d=1)
183-
var_list = estimand.treatment_variable + estimand.get_backdoor_variables()
183+
var_list = estimand.treatment_variable + estimand.get_adjustment_set()
184184
expr += "+".join(var_list)
185185
return expr

dowhy/causal_estimators/propensity_score_stratification_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,6 @@ def _get_strata(self, data: pd.DataFrame, num_strata, clipping_threshold):
264264
def construct_symbolic_estimator(self, estimand):
265265
expr = "b: " + ",".join(estimand.outcome_variable) + "~"
266266
# TODO -- fix: we are actually conditioning on positive treatment (d=1)
267-
var_list = estimand.treatment_variable + estimand.get_backdoor_variables()
267+
var_list = estimand.treatment_variable + estimand.get_adjustment_set()
268268
expr += "+".join(var_list)
269269
return expr

dowhy/causal_estimators/propensity_score_weighting_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,6 @@ def estimate_effect(
263263
def construct_symbolic_estimator(self, estimand):
264264
expr = "b: " + ",".join(estimand.outcome_variable) + "~"
265265
# TODO -- fix: we are actually conditioning on positive treatment (d=1)
266-
var_list = estimand.treatment_variable + estimand.get_backdoor_variables()
266+
var_list = estimand.treatment_variable + estimand.get_adjustment_set()
267267
expr += "+".join(var_list)
268268
return expr

dowhy/causal_estimators/regression_estimator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def fit(
8787
self.reset_encoders() # Forget any existing encoders
8888
self._set_effect_modifiers(data, effect_modifier_names)
8989

90-
self.logger.debug("Back-door variables used:" + ",".join(self._target_estimand.get_backdoor_variables()))
91-
self._observed_common_causes_names = self._target_estimand.get_backdoor_variables()
90+
self.logger.debug("Adjustment set variables used:" + ",".join(self._target_estimand.get_adjustment_set()))
91+
self._observed_common_causes_names = self._target_estimand.get_adjustment_set()
9292
if len(self._observed_common_causes_names) > 0:
9393
self._observed_common_causes = data[self._observed_common_causes_names]
9494
self._observed_common_causes = self._encode(self._observed_common_causes, "observed_common_causes")

dowhy/causal_identifier/identified_estimand.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,23 @@ def get_instrumental_variables(self):
8484

8585
def get_general_adjustment_variables(self, key: Optional[str] = None):
8686
"""Return a list containing general adjustment variables."""
87+
gav = self.general_adjustment_variables or {}
88+
return gav.get(self.default_adjustment_set_id if key is None else key, None)
89+
90+
def set_general_adjustment_variables(self, variables_arr: List, key: Optional[str] = None):
8791
if key is None:
88-
return self.general_adjustment_variables[self.default_adjustment_set_id]
89-
else:
90-
return self.general_adjustment_variables[key]
92+
key = self.identifier_method
93+
self.general_adjustment_variables[key] = variables_arr
94+
95+
def get_adjustment_set(self, key: Optional[str] = None):
96+
if self.identifier_method == "general_adjustment":
97+
return self.get_general_adjustment_variables(key)
98+
return self.get_backdoor_variables(key)
99+
100+
def set_adjustment_set(self, variables_arr: List, key: Optional[str] = None):
101+
if self.identifier_method == "general_adjustment":
102+
return self.set_general_adjustment_variables(variables_arr, key)
103+
return self.set_backdoor_variables(variables_arr, key)
91104

92105
def __deepcopy__(self, memo):
93106
return IdentifiedEstimand(

dowhy/causal_refuter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, data, identified_estimand, estimate, **kwargs):
5454
# Concatenate the confounders, instruments and effect modifiers
5555
try:
5656
self._variables_of_interest = (
57-
self._target_estimand.get_backdoor_variables()
57+
self._target_estimand.get_adjustment_set()
5858
+ self._target_estimand.instrumental_variables
5959
+ self._estimate.estimator._effect_modifier_names
6060
)

dowhy/causal_refuters/add_unobserved_common_cause.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ def preprocess_observed_common_causes(
207207
no_common_causes_error_message: str,
208208
):
209209
"""
210-
Preprocesses backdoor variables (observed common causes) and returns the pre-processed matrix.
210+
Preprocesses adjustment variables (observed common causes) and returns the pre-processed matrix.
211211
212-
At least one backdoor (common cause) variable is required. Raises an exception if none present.
212+
At least one covariate (common cause) variable is required. Raises an exception if none present.
213213
214214
Preprocessing has two steps:
215215
1. Categorical encoding.
@@ -222,7 +222,7 @@ def preprocess_observed_common_causes(
222222
"""
223223

224224
# 1. Categorical encoding of relevant variables
225-
observed_common_causes_names = target_estimand.get_backdoor_variables()
225+
observed_common_causes_names = target_estimand.get_adjustment_set()
226226
if len(observed_common_causes_names) > 0:
227227
# The encoded data is only used to calculate a parameter, so the encoder can be discarded.
228228
observed_common_causes = data[observed_common_causes_names]

dowhy/causal_refuters/assess_overlap.py

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(self, *args, **kwargs):
4141
"""
4242
super().__init__(*args, **kwargs)
4343
# TODO: Check that the target estimand has backdoor variables?
44+
# TODO: Add support for the general adjustment criterion.
4445
self._backdoor_vars = self._target_estimand.get_backdoor_variables()
4546
self._cat_feats = kwargs.pop("cat_feats", [])
4647
self._support_config = kwargs.pop("support_config", None)

dowhy/causal_refuters/bootstrap_refuter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def refute_bootstrap(
191191

192192
chosen_variables = choose_variables(
193193
required_variables,
194-
target_estimand.get_backdoor_variables()
194+
target_estimand.get_adjustment_set()
195195
+ target_estimand.instrumental_variables
196196
+ estimate.estimator._effect_modifier_names,
197197
)

dowhy/causal_refuters/dummy_outcome_refuter.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class DummyOutcomeRefuter(CausalRefuter):
5555
then we can add an arbitrary function h(t) to the dummy outcome's
5656
generation process and then the causal effect becomes h(t=1)-h(t=0).
5757
58-
Note that this general procedure only works for the backdoor criterion.
58+
Note that this general procedure only works for covariate adjustment.
5959
6060
1. We find f(W) for a each value of treatment. That is, keeping the treatment
6161
constant, we fit a predictor to estimate the effect of confounders W on
@@ -108,7 +108,7 @@ class DummyOutcomeRefuter(CausalRefuter):
108108
* function argument: function ``pd.Dataframe -> np.ndarray``
109109
110110
It takes in a function that takes the input data frame as the input and outputs the outcome
111-
variable. This allows us to create an output varable that only depends on the covariates and does not depend
111+
variable. This allows us to create an output variable that only depends on the covariates and does not depend
112112
on the treatment variable.
113113
114114
* string argument
@@ -271,7 +271,7 @@ def refute_dummy_outcome(
271271
then we can add an arbitrary function h(t) to the dummy outcome's
272272
generation process and then the causal effect becomes h(t=1)-h(t=0).
273273
274-
Note that this general procedure only works for the backdoor criterion.
274+
Note that this general procedure only works for covariate adjustment.
275275
276276
1. We find f(W) for a each value of treatment. That is, keeping the treatment
277277
constant, we fit a predictor to estimate the effect of confounders W on
@@ -438,7 +438,7 @@ def refute_dummy_outcome(
438438
estimator_present = _has_estimator(transformation_list)
439439
chosen_variables = choose_variables(
440440
required_variables,
441-
target_estimand.get_backdoor_variables()
441+
target_estimand.get_adjustment_set()
442442
+ target_estimand.instrumental_variables
443443
+ estimate.estimator._effect_modifier_names,
444444
)

dowhy/causal_refuters/evalue_sensitivity_analyzer.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,13 @@ def benchmark(self, data: pd.DataFrame):
251251
new_lo = []
252252
new_hi = []
253253
observed_covariate_e_values = []
254-
backdoor_vars = self.estimand.get_backdoor_variables()
255-
for drop_var in backdoor_vars:
254+
covariates = self.estimand.get_adjustment_set()
255+
for drop_var in covariates:
256256

257257
# new estimator
258-
new_backdoor_vars = [var for var in backdoor_vars if var != drop_var]
258+
new_covariate_vars = [var for var in covariates if var != drop_var]
259259
new_estimand = copy.deepcopy(self.estimand)
260-
new_estimand.set_backdoor_variables(new_backdoor_vars)
260+
new_estimand.set_adjustment_set(new_covariate_vars)
261261
new_estimator = self.estimate.estimator.get_new_estimator_object(new_estimand)
262262
new_estimator.fit(
263263
self.data,
@@ -296,7 +296,7 @@ def benchmark(self, data: pd.DataFrame):
296296

297297
self.benchmarking_results = pd.DataFrame(
298298
{
299-
"dropped_covariate": backdoor_vars,
299+
"dropped_covariate": covariates,
300300
"converted_est": new_ests,
301301
"converted_lower_ci": new_lo,
302302
"converted_upper_ci": new_hi,

dowhy/causal_refuters/random_common_cause.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def refute_random_common_cause(
107107
"""
108108
logger.info("Refutation over {} simulated datasets, each with a random common cause added".format(num_simulations))
109109

110-
new_backdoor_variables = target_estimand.get_backdoor_variables() + ["w_random"]
110+
new_adjustment_variables = target_estimand.get_adjustment_set() + ["w_random"]
111111
identified_estimand = copy.deepcopy(target_estimand)
112112
# Adding a new backdoor variable to the identified estimand
113-
identified_estimand.set_backdoor_variables(new_backdoor_variables)
113+
identified_estimand.set_adjustment_set(new_adjustment_variables)
114114

115115
if isinstance(random_state, int):
116116
random_state = np.random.RandomState(seed=random_state)

0 commit comments

Comments
 (0)