Skip to content

Commit

Permalink
poe format
Browse files Browse the repository at this point in the history
Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
  • Loading branch information
nparent1 committed Dec 30, 2024
1 parent ed55a3e commit 9fdf085
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 30 deletions.
29 changes: 17 additions & 12 deletions dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import copy
import itertools
import logging
from enum import Enum
from typing import Dict, List, Optional, Union
import copy

import networkx as nx
import sympy as sp
Expand All @@ -21,9 +21,9 @@
get_backdoor_paths,
get_descendants,
get_instruments,
has_directed_path,
get_proper_backdoor_graph,
get_proper_causal_path_nodes,
get_proper_backdoor_graph
has_directed_path,
)
from dowhy.utils.api import parse_state

Expand Down Expand Up @@ -51,11 +51,13 @@ class BackdoorAdjustment(Enum):
BACKDOOR_MIN_EFFICIENT = "efficient-minimal-adjustment"
BACKDOOR_MINCOST_EFFICIENT = "efficient-mincost-adjustment"


class CovariateAdjustment(Enum):
# Covariate adjustment method names
COVARIATE_ADJUSTMENT_DEFAULT = "default"
COVARIATE_ADJUSTMENT_EXHAUSTIVE = "exhaustive-search"


MAX_BACKDOOR_ITERATIONS = 100000

METHOD_NAMES = {
Expand Down Expand Up @@ -119,7 +121,7 @@ def identify_effect(
self.backdoor_adjustment,
self.optimize_backdoor,
self.costs,
self.covariate_adjustment
self.covariate_adjustment,
)

estimand.identifier = self
Expand Down Expand Up @@ -158,7 +160,7 @@ def identify_effect_auto(
backdoor_adjustment: BackdoorAdjustment = BackdoorAdjustment.BACKDOOR_DEFAULT,
optimize_backdoor: bool = False,
costs: Optional[List] = None,
covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT
covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT,
) -> IdentifiedEstimand:
"""Main method that returns an identified estimand (if one exists).
Expand Down Expand Up @@ -198,7 +200,7 @@ def identify_effect_auto(
estimand_type,
costs,
conditional_node_names,
covariate_adjustment
covariate_adjustment,
)
elif estimand_type == EstimandType.NONPARAMETRIC_NDE:
return identify_nde_effect(
Expand Down Expand Up @@ -233,7 +235,7 @@ def identify_ate_effect(
estimand_type: EstimandType,
costs: List,
conditional_node_names: List[str] = None,
covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT
covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT,
):
estimands_dict = {}
mediation_first_stage_confounders = None
Expand Down Expand Up @@ -308,7 +310,9 @@ def identify_ate_effect(
### 4. GENERAL ADJUSTMENT IDENTIFICATION
# This generalizes the backdoor criterion, identifying other valid covariate adjustment sets that might not
# satisfy the backdoor criterion.
adjustment_sets = identify_complete_adjustment_set(graph, action_nodes, outcome_nodes, observed_nodes, covariate_adjustment)
adjustment_sets = identify_complete_adjustment_set(
graph, action_nodes, outcome_nodes, observed_nodes, covariate_adjustment
)
logger.info("Number of general adjustment sets found: " + str(len(adjustment_sets)))
estimands_dict, adjusment_variables_dict = build_adjustment_set_estimands_dict(
action_nodes, outcome_nodes, observed_nodes, adjustment_sets, estimands_dict
Expand Down Expand Up @@ -762,7 +766,8 @@ def get_default_adjustment_set_id(
# Default set contains minimum possible number of instrumental variables, to prevent lowering variance in the treatment variable.
instrument_names = set(get_instruments(graph, action_nodes, outcome_nodes))
iv_count_dict = {
key: len(set(adjustment_set).intersection(instrument_names)) for key, adjustment_set in adjustment_sets_dict.items()
key: len(set(adjustment_set).intersection(instrument_names))
for key, adjustment_set in adjustment_sets_dict.items()
}
min_iv_count = min(iv_count_dict.values())
min_iv_keys = {key for key, iv_count in iv_count_dict.items() if iv_count == min_iv_count}
Expand All @@ -783,7 +788,7 @@ def build_adjustment_set_estimands_dict(
outcome_names: List[str],
observed_nodes: List[str],
adjustment_sets: List[AdjustmentSet],
estimands_dict: Dict
estimands_dict: Dict,
):
"""Build the final dict for adjustment sets by filtering unobserved variables if needed."""
adjustment_variables_dict = {}
Expand Down Expand Up @@ -884,7 +889,7 @@ def identify_complete_adjustment_set(
action_nodes: List[str],
outcome_nodes: List[str],
observed_nodes: List[str],
covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT
covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT,
) -> List[AdjustmentSet]:

graph_pbd = get_proper_backdoor_graph(graph, action_nodes, outcome_nodes)
Expand All @@ -897,7 +902,7 @@ def identify_complete_adjustment_set(
set(action_nodes),
set(outcome_nodes),
# Require the adjustment set to consist only of observed nodes
restricted=((set(graph_pbd.nodes) - set(pcp_nodes)) & set(observed_nodes))
restricted=((set(graph_pbd.nodes) - set(pcp_nodes)) & set(observed_nodes)),
)
if adjustment_set is None:
logger.info("No adjustment sets found.")
Expand Down
2 changes: 1 addition & 1 deletion dowhy/causal_identifier/backdoor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_backdoor_vars(self):
AdjustmentSet(
_type=AdjustmentSet.BACKDOOR,
variables=tuple(obj.find_set()),
num_paths_blocked_by_observed_nodes=obj.num_sets()
num_paths_blocked_by_observed_nodes=obj.num_sets(),
)
)

Expand Down
2 changes: 1 addition & 1 deletion dowhy/causal_identifier/identified_estimand.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
default_backdoor_id=None,
identifier_method=None,
no_directed_path=False,
default_adjustment_set_id=None
default_adjustment_set_id=None,
):
self.identifier = identifier
self.treatment_variable = parse_state(treatment_variable)
Expand Down
3 changes: 1 addition & 2 deletions dowhy/graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""This module defines the fundamental interfaces and functions related to causal graphs."""

import copy
import itertools
import logging
import re
from abc import abstractmethod
from typing import Any, List, Protocol
import copy

import networkx as nx
from networkx.algorithms.dag import has_cycle
Expand Down Expand Up @@ -234,7 +234,6 @@ def get_proper_backdoor_graph(graph: nx.DiGraph, action_nodes, outcome_nodes):
return graph_pbd



def check_dseparation(graph: nx.DiGraph, nodes1, nodes2, nodes3, new_graph=None, dseparation_algo="default"):
if dseparation_algo == "default":
if new_graph is None:
Expand Down
12 changes: 9 additions & 3 deletions tests/causal_identifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from dowhy.graph import build_graph_from_str

from .example_graphs import TEST_FRONTDOOR_GRAPH_SOLUTIONS, TEST_GRAPH_SOLUTIONS, TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT
from .example_graphs import (
TEST_FRONTDOOR_GRAPH_SOLUTIONS,
TEST_GRAPH_SOLUTIONS,
TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT,
)


class IdentificationTestGraphSolution(object):
Expand Down Expand Up @@ -68,7 +72,9 @@ def example_graph_solution(request):
def example_frontdoor_graph_solution(request):
return IdentificationTestFrontdoorGraphSolution(**TEST_FRONTDOOR_GRAPH_SOLUTIONS[request.param])


@pytest.fixture(params=TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT.keys())
def example_complete_adjustment_graph_solution(request):
return IdentificationTestGeneralCovariateAdjustmentGraphSolution(**TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT[request.param])

return IdentificationTestGeneralCovariateAdjustmentGraphSolution(
**TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT[request.param]
)
19 changes: 8 additions & 11 deletions tests/causal_identifiers/test_complete_adjustment_identifier.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import pytest

from dowhy.causal_identifier import AutoIdentifier, CovariateAdjustment
from dowhy.causal_identifier.identify_effect import EstimandType

from dowhy.causal_identifier.auto_identifier import identify_complete_adjustment_set
from dowhy.causal_identifier.identify_effect import EstimandType

from .base import IdentificationTestGeneralCovariateAdjustmentGraphSolution, example_complete_adjustment_graph_solution


class TestGeneralAdjustmentIdentification(object):
def test_identify_exhaustive_adjustment(self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution):
def test_identify_exhaustive_adjustment(
self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution
):
graph = example_complete_adjustment_graph_solution.graph
expected_sets = example_complete_adjustment_graph_solution.exhaustive_adjustment_sets
adjustment_set_results = identify_complete_adjustment_set(
Expand All @@ -25,12 +26,11 @@ def test_identify_exhaustive_adjustment(self, example_complete_adjustment_graph_
if len(adjustment_set.get_variables()) > 0
]

assert all(
(len(s) == 0 and len(adjustment_sets) == 0) or set(s) in adjustment_sets
for s in expected_sets
)
assert all((len(s) == 0 and len(adjustment_sets) == 0) or set(s) in adjustment_sets for s in expected_sets)

def test_identify_minimal_adjustment(self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution):
def test_identify_minimal_adjustment(
self, example_complete_adjustment_graph_solution: IdentificationTestGeneralCovariateAdjustmentGraphSolution
):
graph = example_complete_adjustment_graph_solution.graph
expected_set = example_complete_adjustment_graph_solution.minimal_adjustment_sets[0]
adjustment_set_results = identify_complete_adjustment_set(
Expand All @@ -47,6 +47,3 @@ def test_identify_minimal_adjustment(self, example_complete_adjustment_graph_sol
]

assert (len(expected_set) == 0 and len(adjustment_sets) == 0) or set(expected_set) in adjustment_sets



0 comments on commit 9fdf085

Please sign in to comment.