From f43f2797c5ba92910c6573db87ab966eb998c731 Mon Sep 17 00:00:00 2001 From: Nicholas Parente Date: Mon, 30 Dec 2024 12:22:59 -0500 Subject: [PATCH] adding test, throwing on unsupported Signed-off-by: Nicholas Parente --- dowhy/causal_identifier/auto_identifier.py | 2 +- dowhy/graph.py | 12 ++++++---- tests/causal_identifiers/base.py | 6 +++-- tests/causal_identifiers/example_graphs.py | 14 ++++++++++- .../test_complete_adjustment_identifier.py | 24 ++----------------- 5 files changed, 27 insertions(+), 31 deletions(-) diff --git a/dowhy/causal_identifier/auto_identifier.py b/dowhy/causal_identifier/auto_identifier.py index ca1e2aae7..49ab9349e 100644 --- a/dowhy/causal_identifier/auto_identifier.py +++ b/dowhy/causal_identifier/auto_identifier.py @@ -909,7 +909,7 @@ def identify_complete_adjustment_set( return [] return [AdjustmentSet(AdjustmentSet.GENERAL, adjustment_set)] - return [AdjustmentSet(AdjustmentSet.GENERAL, [])] + raise ValueError("Exhaustive identification of general adjustment sets is not yet supported.") def identify_mediation(graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str]): diff --git a/dowhy/graph.py b/dowhy/graph.py index 0569a6451..c3829e79d 100644 --- a/dowhy/graph.py +++ b/dowhy/graph.py @@ -204,7 +204,9 @@ def get_descendants(graph: nx.DiGraph, nodes): def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes): # Process is described in Van Der Zander et al. "Constructing Separators and - # Adjustment Sets in Ancestral Graphs", Section 4.1 + # Adjustment Sets in Ancestral Graphs", Section 4.1. + + # We cannot user do_surgery() since we require deep copies of the given graph. # 1) Create modified graphs removing inbound and outbound arrows from the action nodes, respectively. graph_post_interv = copy.deepcopy(graph) # remove incoming arrows to our action nodes @@ -216,17 +218,17 @@ def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes) # 2) Use the modified graphs to identify the nodes which lie on proper causal paths from the # action nodes to the outcome nodes. - de_x = get_descendants(graph_post_interv, action_nodes) - an_y = get_ancestors(graph_with_action_nodes_as_sinks, outcome_nodes) + de_x = get_descendants(graph_post_interv, action_nodes).union(action_nodes) + an_y = get_ancestors(graph_with_action_nodes_as_sinks, outcome_nodes).union(outcome_nodes) return (set(de_x) - set(action_nodes)) & an_y def get_proper_backdoor_graph(graph: nx.DiGraph, action_nodes, outcome_nodes): # Process is described in Van Der Zander et al. "Constructing Separators and - # Adjustment Sets in Ancestral Graphs", Section 4.1 + # Adjustment Sets in Ancestral Graphs", Section 4.1. # First we can just call get_proper_causal_path_nodes, then - # we remove edges from the action_nodes to the proper causal path nodes + # we remove edges from the action_nodes to the proper causal path nodes. graph_pbd = copy.deepcopy(graph) graph_pbd.remove_edges_from( [(u, v) for u in action_nodes for v in get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes)] diff --git a/tests/causal_identifiers/base.py b/tests/causal_identifiers/base.py index 4f3b94376..3298f99a9 100644 --- a/tests/causal_identifiers/base.py +++ b/tests/causal_identifiers/base.py @@ -52,12 +52,14 @@ def __init__( self, graph_str, observed_variables, + action_nodes, + outcome_nodes, minimal_adjustment_sets, exhaustive_adjustment_sets, ): self.graph = build_graph_from_str(graph_str) - self.action_nodes = ["X"] - self.outcome_nodes = ["Y"] + self.action_nodes = action_nodes + self.outcome_nodes = outcome_nodes self.observed_nodes = observed_variables self.minimal_adjustment_sets = minimal_adjustment_sets self.exhaustive_adjustment_sets = exhaustive_adjustment_sets diff --git a/tests/causal_identifiers/example_graphs.py b/tests/causal_identifiers/example_graphs.py index 7e2c26f80..bc51b3edc 100644 --- a/tests/causal_identifiers/example_graphs.py +++ b/tests/causal_identifiers/example_graphs.py @@ -394,9 +394,21 @@ "shpitser_simple_non_backdoor_adjustment_set": dict( graph_str="digraph{Z;X;Y; X->Z;X->Y}", observed_variables=["Z", "X", "Y"], + action_nodes=["X"], + outcome_nodes=["Y"], minimal_adjustment_sets=[{}], exhaustive_adjustment_sets=[{"Z"}, {}], - ) + ), + # Example is selected from van der Zander et al. "Constructing Separators and Adjustment Sets in Ancestral + # Graphs", figure 2. + "van_der_zander_minimal_non_backdoor_adjustment_set": dict( + graph_str="digraph{Z1;Z2;X1;X2;Y1;Y2; X1->Y1;X1->Z1;Z1->Z2;Z2->X2;Y2->Z2}", + observed_variables=["Z1", "Z2", "X1", "X2", "Y1", "Y2"], + action_nodes=["X1", "X2"], + outcome_nodes=["Y1", "Y2"], + minimal_adjustment_sets=[{"Z1", "Z2"}], + exhaustive_adjustment_sets=["Z1", "Z2"], + ), } diff --git a/tests/causal_identifiers/test_complete_adjustment_identifier.py b/tests/causal_identifiers/test_complete_adjustment_identifier.py index f21600bec..033dba8dc 100644 --- a/tests/causal_identifiers/test_complete_adjustment_identifier.py +++ b/tests/causal_identifiers/test_complete_adjustment_identifier.py @@ -8,26 +8,6 @@ class TestGeneralAdjustmentIdentification(object): - def test_identify_exhaustive_adjustment( - self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution - ): - graph = example_complete_adjustment_graph_solution.graph - expected_sets = example_complete_adjustment_graph_solution.exhaustive_adjustment_sets - adjustment_set_results = identify_complete_adjustment_set( - graph, - action_nodes=["X"], - outcome_nodes=["Y"], - observed_nodes=example_complete_adjustment_graph_solution.observed_nodes, - covariate_adjustment=CovariateAdjustment.COVARIATE_ADJUSTMENT_EXHAUSTIVE, - ) - adjustment_sets = [ - set(adjustment_set.get_variables()) - for adjustment_set in adjustment_set_results - if len(adjustment_set.get_variables()) > 0 - ] - - assert all((len(s) == 0 and len(adjustment_sets) == 0) or set(s) in adjustment_sets for s in expected_sets) - def test_identify_minimal_adjustment( self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution ): @@ -35,8 +15,8 @@ def test_identify_minimal_adjustment( expected_set = example_complete_adjustment_graph_solution.minimal_adjustment_sets[0] adjustment_set_results = identify_complete_adjustment_set( graph, - action_nodes=["X"], - outcome_nodes=["Y"], + action_nodes=example_complete_adjustment_graph_solution.action_nodes, + outcome_nodes=example_complete_adjustment_graph_solution.outcome_nodes, observed_nodes=example_complete_adjustment_graph_solution.observed_nodes, covariate_adjustment=CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT, )