Skip to content

Commit

Permalink
Expose individual interventional outcomes from RegressionEstimator _d…
Browse files Browse the repository at this point in the history
…o operator (#1011) (#1018)

Explose individual interventional outcomes from RegressionEstimator _do operator.

Original implementation of _do for RegressionEstimators calculated individual outcomes for all dataframe rows, and then returned the mean(). However, for many use-cases it is helpful to have the individual outcomes for further analysis. This commit moves the existing implementation (without changes) to a new function interventional_outcomes() which returns all outcomes. The implementation of _do() now only calls mean() on the returned values. The behaviour of _do() is unchanged.

Signed-off-by: drawlinson <dave@agi.io>
Co-authored-by: drawlinson <dave@agi.io>
  • Loading branch information
drawlinson and drawlinson authored Sep 4, 2023
1 parent 7841719 commit 1ec93b8
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion dowhy/causal_estimators/regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,17 @@ def _build_features(self, data_df: pd.DataFrame, treatment_values=None):
features = sm.add_constant(features, has_constant="add") # to add an intercept term
return features

def _do(self, data_df: pd.DataFrame, treatment_val):
def interventional_outcomes(self, data_df: pd.DataFrame, treatment_val):
"""
Applies an intervention treatment_val to all rows in data_df, then uses self.model
to predict outcomes. If data_df is None, will use self._data instead.
If no model exists, one will be created. The outcomes of all samples are returned,
allowing analysis of individual predictions in counterfactual treatment scenarios.
:param data_df: data frame containing the data
:param treatment_val: value for the treatment variable
:returns: A list of outcome predictions.
"""

if data_df is None:
data_df = self._data
if not self.model:
Expand Down Expand Up @@ -210,4 +220,8 @@ def _do(self, data_df: pd.DataFrame, treatment_val):

new_features = self._build_features(data_df, treatment_values=interventional_treatment_2d)
interventional_outcomes = self.predict_fn(data_df, self.model, new_features)
return interventional_outcomes

def _do(self, data_df: pd.DataFrame, treatment_val):
interventional_outcomes = self.interventional_outcomes(data_df, treatment_val)
return interventional_outcomes.mean()

0 comments on commit 1ec93b8

Please sign in to comment.