Skip to content

Commit

Permalink
adding test, throwing on unsupported
Browse files Browse the repository at this point in the history
Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
  • Loading branch information
nparent1 committed Dec 30, 2024
1 parent 9fdf085 commit f43f279
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 31 deletions.
2 changes: 1 addition & 1 deletion dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def identify_complete_adjustment_set(
return []
return [AdjustmentSet(AdjustmentSet.GENERAL, adjustment_set)]

return [AdjustmentSet(AdjustmentSet.GENERAL, [])]
raise ValueError("Exhaustive identification of general adjustment sets is not yet supported.")


def identify_mediation(graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str]):
Expand Down
12 changes: 7 additions & 5 deletions dowhy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ def get_descendants(graph: nx.DiGraph, nodes):

def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes):
# Process is described in Van Der Zander et al. "Constructing Separators and
# Adjustment Sets in Ancestral Graphs", Section 4.1
# Adjustment Sets in Ancestral Graphs", Section 4.1.

# We cannot user do_surgery() since we require deep copies of the given graph.

# 1) Create modified graphs removing inbound and outbound arrows from the action nodes, respectively.
graph_post_interv = copy.deepcopy(graph) # remove incoming arrows to our action nodes
Expand All @@ -216,17 +218,17 @@ def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes)

# 2) Use the modified graphs to identify the nodes which lie on proper causal paths from the
# action nodes to the outcome nodes.
de_x = get_descendants(graph_post_interv, action_nodes)
an_y = get_ancestors(graph_with_action_nodes_as_sinks, outcome_nodes)
de_x = get_descendants(graph_post_interv, action_nodes).union(action_nodes)
an_y = get_ancestors(graph_with_action_nodes_as_sinks, outcome_nodes).union(outcome_nodes)
return (set(de_x) - set(action_nodes)) & an_y


def get_proper_backdoor_graph(graph: nx.DiGraph, action_nodes, outcome_nodes):
# Process is described in Van Der Zander et al. "Constructing Separators and
# Adjustment Sets in Ancestral Graphs", Section 4.1
# Adjustment Sets in Ancestral Graphs", Section 4.1.

# First we can just call get_proper_causal_path_nodes, then
# we remove edges from the action_nodes to the proper causal path nodes
# we remove edges from the action_nodes to the proper causal path nodes.
graph_pbd = copy.deepcopy(graph)
graph_pbd.remove_edges_from(
[(u, v) for u in action_nodes for v in get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes)]
Expand Down
6 changes: 4 additions & 2 deletions tests/causal_identifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ def __init__(
self,
graph_str,
observed_variables,
action_nodes,
outcome_nodes,
minimal_adjustment_sets,
exhaustive_adjustment_sets,
):
self.graph = build_graph_from_str(graph_str)
self.action_nodes = ["X"]
self.outcome_nodes = ["Y"]
self.action_nodes = action_nodes
self.outcome_nodes = outcome_nodes
self.observed_nodes = observed_variables
self.minimal_adjustment_sets = minimal_adjustment_sets
self.exhaustive_adjustment_sets = exhaustive_adjustment_sets
Expand Down
14 changes: 13 additions & 1 deletion tests/causal_identifiers/example_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,21 @@
"shpitser_simple_non_backdoor_adjustment_set": dict(
graph_str="digraph{Z;X;Y; X->Z;X->Y}",
observed_variables=["Z", "X", "Y"],
action_nodes=["X"],
outcome_nodes=["Y"],
minimal_adjustment_sets=[{}],
exhaustive_adjustment_sets=[{"Z"}, {}],
)
),
# Example is selected from van der Zander et al. "Constructing Separators and Adjustment Sets in Ancestral
# Graphs", figure 2.
"van_der_zander_minimal_non_backdoor_adjustment_set": dict(
graph_str="digraph{Z1;Z2;X1;X2;Y1;Y2; X1->Y1;X1->Z1;Z1->Z2;Z2->X2;Y2->Z2}",
observed_variables=["Z1", "Z2", "X1", "X2", "Y1", "Y2"],
action_nodes=["X1", "X2"],
outcome_nodes=["Y1", "Y2"],
minimal_adjustment_sets=[{"Z1", "Z2"}],
exhaustive_adjustment_sets=["Z1", "Z2"],
),
}


Expand Down
24 changes: 2 additions & 22 deletions tests/causal_identifiers/test_complete_adjustment_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,15 @@


class TestGeneralAdjustmentIdentification(object):
def test_identify_exhaustive_adjustment(
self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution
):
graph = example_complete_adjustment_graph_solution.graph
expected_sets = example_complete_adjustment_graph_solution.exhaustive_adjustment_sets
adjustment_set_results = identify_complete_adjustment_set(
graph,
action_nodes=["X"],
outcome_nodes=["Y"],
observed_nodes=example_complete_adjustment_graph_solution.observed_nodes,
covariate_adjustment=CovariateAdjustment.COVARIATE_ADJUSTMENT_EXHAUSTIVE,
)
adjustment_sets = [
set(adjustment_set.get_variables())
for adjustment_set in adjustment_set_results
if len(adjustment_set.get_variables()) > 0
]

assert all((len(s) == 0 and len(adjustment_sets) == 0) or set(s) in adjustment_sets for s in expected_sets)

def test_identify_minimal_adjustment(
self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution
):
graph = example_complete_adjustment_graph_solution.graph
expected_set = example_complete_adjustment_graph_solution.minimal_adjustment_sets[0]
adjustment_set_results = identify_complete_adjustment_set(
graph,
action_nodes=["X"],
outcome_nodes=["Y"],
action_nodes=example_complete_adjustment_graph_solution.action_nodes,
outcome_nodes=example_complete_adjustment_graph_solution.outcome_nodes,
observed_nodes=example_complete_adjustment_graph_solution.observed_nodes,
covariate_adjustment=CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT,
)
Expand Down

0 comments on commit f43f279

Please sign in to comment.