Skip to content

Commit

Permalink
Bug fix: Linear regression CATE estimates were not shown even when ne…
Browse files Browse the repository at this point in the history
…ed_conditional_estimates is True (#1092)

* fixed bug where CATE is not returned by lr

Signed-off-by: Amit Sharma <amit_sharma@live.com>

* added test

Signed-off-by: Amit Sharma <amit_sharma@live.com>

* formatted file

Signed-off-by: Amit Sharma <amit_sharma@live.com>

---------

Signed-off-by: Amit Sharma <amit_sharma@live.com>
  • Loading branch information
amit-sharma authored Dec 3, 2023
1 parent 918efc6 commit 1d050f0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dowhy/causal_estimators/regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def estimate_effect(
self._target_units = target_units
self._treatment_value = treatment_value
self._control_value = control_value
if need_conditional_estimates is None:
need_conditional_estimates = self.need_conditional_estimates
# TODO make treatment_value and control value also as local parameters
# All treatments are set to the same constant value
effect_estimate = self._do(data, treatment_value) - self._do(data, control_value)
Expand Down
32 changes: 32 additions & 0 deletions tests/test_causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,38 @@ def test_graph_input_nx(self, beta, num_instruments, num_samples, num_treatments
all_nodes = model._graph.get_all_nodes(include_unobserved=False)
assert "Unobserved Confounders" not in all_nodes

@mark.parametrize(
["beta", "num_effect_modifiers", "num_samples"],
[
(10, 0, 100),
(10, 1, 100),
],
)
def test_cate_estimates_regression(self, beta, num_effect_modifiers, num_samples):
data = dowhy.datasets.linear_dataset(
beta=beta,
num_common_causes=2,
num_samples=num_samples,
num_treatments=1,
treatment_is_binary=True,
num_effect_modifiers=num_effect_modifiers,
)
model = CausalModel(
data=data["df"],
treatment=data["treatment_name"],
outcome=data["outcome_name"],
graph=data["gml_graph"],
test_significance=None,
)
identified_estimand = model.identify_effect()
linear_estimate = model.estimate_effect(
identified_estimand, method_name="backdoor.linear_regression", control_value=0, treatment_value=1
)
if num_effect_modifiers == 0:
assert linear_estimate.conditional_estimates is None
else:
assert linear_estimate.conditional_estimates is not None

@mark.parametrize(
["num_variables", "num_samples"],
[
Expand Down

0 comments on commit 1d050f0

Please sign in to comment.