diff --git a/dowhy/causal_identifier/auto_identifier.py b/dowhy/causal_identifier/auto_identifier.py index 7d9622fdc..12d913a02 100644 --- a/dowhy/causal_identifier/auto_identifier.py +++ b/dowhy/causal_identifier/auto_identifier.py @@ -91,14 +91,13 @@ def __init__( backdoor_adjustment: BackdoorAdjustment = BackdoorAdjustment.BACKDOOR_DEFAULT, optimize_backdoor: bool = False, costs: Optional[List] = None, + # By default, we will just compute a minimal adjustment set covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT, ): self.estimand_type = estimand_type self.backdoor_adjustment = backdoor_adjustment self.optimize_backdoor = optimize_backdoor self.costs = costs - # By default, we will just compute a minimal adjustment set (since it can be - # quite lengthy to compute an exhaustive set) self.covariate_adjustment = covariate_adjustment self.logger = logging.getLogger(__name__) @@ -895,10 +894,10 @@ def identify_complete_adjustment_set( # In default case, we don't find all exhaustive adjustment sets adjustment_set = nx.algorithms.find_minimal_d_separator( graph_pbd, - action_nodes, - outcome_nodes, + set(action_nodes), + set(outcome_nodes), # Require the adjustment set to consist only of observed nodes - restricted=((set(graph.nodes) - set(pcp_nodes)) & set(observed_nodes)) + restricted=((set(graph_pbd.nodes) - set(pcp_nodes)) & set(observed_nodes)) ) if adjustment_set is None: logger.info("No adjustment sets found.") diff --git a/tests/causal_identifiers/test_complete_adjustment_identifier.py b/tests/causal_identifiers/test_complete_adjustment_identifier.py index b00057082..696e6a0f4 100644 --- a/tests/causal_identifiers/test_complete_adjustment_identifier.py +++ b/tests/causal_identifiers/test_complete_adjustment_identifier.py @@ -9,7 +9,7 @@ class TestGeneralAdjustmentIdentification(object): - def test_identify_general_adjustment(self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution): + def test_identify_exhaustive_adjustment(self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution): graph = example_complete_adjustment_graph_solution.graph expected_sets = example_complete_adjustment_graph_solution.exhaustive_adjustment_sets adjustment_set_results = identify_complete_adjustment_set( @@ -30,4 +30,23 @@ def test_identify_general_adjustment(self, example_complete_adjustment_graph_sol for s in expected_sets ) + def test_identify_minimal_adjustment(self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution): + graph = example_complete_adjustment_graph_solution.graph + expected_set = example_complete_adjustment_graph_solution.minimal_adjustment_sets[0] + adjustment_set_results = identify_complete_adjustment_set( + graph, + action_nodes=["X"], + outcome_nodes=["Y"], + observed_nodes=example_complete_adjustment_graph_solution.observed_nodes, + covariate_adjustment=CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT, + ) + adjustment_sets = [ + set(adjustment_set.get_variables()) + for adjustment_set in adjustment_set_results + if len(adjustment_set.get_variables()) > 0 + ] + + assert (len(expected_set) == 0 and len(adjustment_sets) == 0) or set(expected_set) in adjustment_sets + +