Skip to content

Commit

Permalink
adding minimal test
Browse files Browse the repository at this point in the history
Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
  • Loading branch information
nparent1 committed Dec 30, 2024
1 parent 5f3bc5b commit ed55a3e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
9 changes: 4 additions & 5 deletions dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,13 @@ def __init__(
backdoor_adjustment: BackdoorAdjustment = BackdoorAdjustment.BACKDOOR_DEFAULT,
optimize_backdoor: bool = False,
costs: Optional[List] = None,
# By default, we will just compute a minimal adjustment set
covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT,
):
self.estimand_type = estimand_type
self.backdoor_adjustment = backdoor_adjustment
self.optimize_backdoor = optimize_backdoor
self.costs = costs
# By default, we will just compute a minimal adjustment set (since it can be
# quite lengthy to compute an exhaustive set)
self.covariate_adjustment = covariate_adjustment
self.logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -895,10 +894,10 @@ def identify_complete_adjustment_set(
# In default case, we don't find all exhaustive adjustment sets
adjustment_set = nx.algorithms.find_minimal_d_separator(
graph_pbd,
action_nodes,
outcome_nodes,
set(action_nodes),
set(outcome_nodes),
# Require the adjustment set to consist only of observed nodes
restricted=((set(graph.nodes) - set(pcp_nodes)) & set(observed_nodes))
restricted=((set(graph_pbd.nodes) - set(pcp_nodes)) & set(observed_nodes))
)
if adjustment_set is None:
logger.info("No adjustment sets found.")
Expand Down
21 changes: 20 additions & 1 deletion tests/causal_identifiers/test_complete_adjustment_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class TestGeneralAdjustmentIdentification(object):
def test_identify_general_adjustment(self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution):
def test_identify_exhaustive_adjustment(self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution):
graph = example_complete_adjustment_graph_solution.graph
expected_sets = example_complete_adjustment_graph_solution.exhaustive_adjustment_sets
adjustment_set_results = identify_complete_adjustment_set(
Expand All @@ -30,4 +30,23 @@ def test_identify_general_adjustment(self, example_complete_adjustment_graph_sol
for s in expected_sets
)

def test_identify_minimal_adjustment(self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution):
graph = example_complete_adjustment_graph_solution.graph
expected_set = example_complete_adjustment_graph_solution.minimal_adjustment_sets[0]
adjustment_set_results = identify_complete_adjustment_set(
graph,
action_nodes=["X"],
outcome_nodes=["Y"],
observed_nodes=example_complete_adjustment_graph_solution.observed_nodes,
covariate_adjustment=CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT,
)
adjustment_sets = [
set(adjustment_set.get_variables())
for adjustment_set in adjustment_set_results
if len(adjustment_set.get_variables()) > 0
]

assert (len(expected_set) == 0 and len(adjustment_sets) == 0) or set(expected_set) in adjustment_sets



0 comments on commit ed55a3e

Please sign in to comment.