Skip to content

Commit

Permalink
Proposal: Finalize functional API refactor - deprecate causal graph (#…
Browse files Browse the repository at this point in the history
…943)

* Deprecate CausalGraph

The effect estimation API is now based on a functional API that expects a networkx graph as input.

- The causal graph should now be provided as a networkx graph. Accordingly, most identification methods now expect an additional "observed_nodes" parameter.
- CausalModel and CausalGraph still exist and should be compatible with the old API.

---------

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
Signed-off-by: Amit Sharma <amit_sharma@live.com>
Co-authored-by: Amit Sharma <amit_sharma@live.com>
  • Loading branch information
bloebp and amit-sharma authored Nov 27, 2023
1 parent 4fd0a92 commit 2a8e49a
Show file tree
Hide file tree
Showing 50 changed files with 1,179 additions and 604 deletions.
16 changes: 10 additions & 6 deletions docs/source/example_notebooks/do_sampler_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"\n",
"## Integration\n",
"\n",
"The do-sampler is built on top of the identification abstraction used throughout do-why. It uses a `dowhy.CausalModel` to perform identification, and builds any models it needs automatically using this identification.\n",
"The do-sampler is built on top of the identification abstraction used throughout do-why. It automatically performs an identification, and builds any models it needs automatically using this identification.\n",
"\n",
"## Specifying Interventions\n",
"\n",
Expand Down Expand Up @@ -128,7 +128,8 @@
"model = CausalModel(df, \n",
" causes,\n",
" outcomes,\n",
" common_causes=common_causes)"
" common_causes=common_causes)\n",
"nx_graph = model._graph._graph"
]
},
{
Expand Down Expand Up @@ -162,8 +163,11 @@
"source": [
"from dowhy.do_samplers.weighting_sampler import WeightingSampler\n",
"\n",
"sampler = WeightingSampler(df,\n",
" causal_model=model,\n",
"sampler = WeightingSampler(graph=nx_graph,\n",
" action_nodes=causes,\n",
" outcome_nodes=outcomes,\n",
" observed_nodes=df.columns.tolist(),\n",
" data=df,\n",
" keep_original_treatment=True,\n",
" variable_types={'D': 'b', 'Z': 'c', 'Y': 'c'}\n",
" )\n",
Expand Down Expand Up @@ -207,7 +211,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -221,7 +225,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
Expand Down
31 changes: 17 additions & 14 deletions docs/source/example_notebooks/dowhy_causal_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"source": [
"import dowhy.datasets\n",
"import dowhy.api\n",
"from dowhy.graph import build_graph_from_str\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand All @@ -36,7 +37,7 @@
" treatment_is_binary=True)\n",
"df = data['df']\n",
"df['y'] = df['y'] + np.random.normal(size=len(df)) # Adding noise to data. Without noise, the variance in Y|X, Z is zero, and mcmc fails.\n",
"#data['dot_graph'] = 'digraph { v ->y;X0-> v;X0-> y;}'\n",
"nx_graph = build_graph_from_str(data[\"dot_graph\"])\n",
"\n",
"treatment= data[\"treatment_name\"][0]\n",
"outcome = data[\"outcome_name\"][0]\n",
Expand All @@ -47,15 +48,17 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# data['df'] is just a regular pandas.DataFrame\n",
"df.causal.do(x=treatment,\n",
" variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'},\n",
" outcome=outcome,\n",
" common_causes=[common_cause],\n",
" proceed_when_unidentifiable=True).groupby(treatment).mean().plot(y=outcome, kind='bar')"
" variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'},\n",
" outcome=outcome,\n",
" common_causes=[common_cause],\n",
" ).groupby(treatment).mean().plot(y=outcome, kind='bar')"
]
},
{
Expand All @@ -68,8 +71,8 @@
" variable_types={treatment:'b', outcome: 'c', common_cause: 'c'}, \n",
" outcome=outcome,\n",
" method='weighting', \n",
" common_causes=[common_cause],\n",
" proceed_when_unidentifiable=True).groupby(treatment).mean().plot(y=outcome, kind='bar')"
" common_causes=[common_cause]\n",
" ).groupby(treatment).mean().plot(y=outcome, kind='bar')"
]
},
{
Expand All @@ -81,14 +84,14 @@
"cdf_1 = df.causal.do(x={treatment: 1}, \n",
" variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'}, \n",
" outcome=outcome, \n",
" dot_graph=data['dot_graph'],\n",
" proceed_when_unidentifiable=True)\n",
" graph=nx_graph\n",
" )\n",
"\n",
"cdf_0 = df.causal.do(x={treatment: 0}, \n",
" variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'}, \n",
" outcome=outcome, \n",
" dot_graph=data['dot_graph'],\n",
" proceed_when_unidentifiable=True)\n"
" graph=nx_graph\n",
" )\n"
]
},
{
Expand Down Expand Up @@ -158,7 +161,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -172,7 +175,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@
"outputs": [],
"source": [
"from dowhy.causal_graph import CausalGraph\n",
"from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType"
"from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType\n",
"from dowhy.graph import build_graph_from_str\n",
"from dowhy.utils.plotting import plot"
]
},
{
Expand Down Expand Up @@ -135,9 +137,7 @@
"]\n",
"treatment_name = \"warm-up\"\n",
"outcome_name = \"injury\"\n",
"G = CausalGraph(\n",
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
")"
"G = build_graph_from_str(graph_str)"
]
},
{
Expand All @@ -153,7 +153,7 @@
"metadata": {},
"outputs": [],
"source": [
"G.view_graph()"
"plot(G)"
]
},
{
Expand Down Expand Up @@ -184,7 +184,11 @@
")\n",
"print(\n",
" ident_eff.identify_effect(\n",
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
" graph=G, \n",
" action_nodes=treatment_name, \n",
" outcome_nodes=outcome_name,\n",
" observed_nodes=observed_node_names,\n",
" conditional_node_names=conditional_node_names\n",
" )\n",
")"
]
Expand Down Expand Up @@ -215,7 +219,11 @@
")\n",
"print(\n",
" ident_minimal_eff.identify_effect(\n",
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
" graph=G, \n",
" action_nodes=treatment_name, \n",
" outcome_nodes=outcome_name, \n",
" observed_nodes=observed_node_names,\n",
" conditional_node_names=conditional_node_names\n",
" )\n",
")"
]
Expand All @@ -239,7 +247,11 @@
")\n",
"print(\n",
" ident_mincost_eff.identify_effect(\n",
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
" graph=G, \n",
" action_nodes=treatment_name, \n",
" outcome_nodes=outcome_name,\n",
" observed_nodes=observed_node_names,\n",
" conditional_node_names=conditional_node_names\n",
" )\n",
")"
]
Expand Down Expand Up @@ -294,9 +306,7 @@
"observed_node_names = [\"X\", \"Y\", \"Z1\", \"Z2\"]\n",
"treatment_name = \"X\"\n",
"outcome_name = \"Y\"\n",
"G = CausalGraph(\n",
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
")"
"G = build_graph_from_str(graph_str)"
]
},
{
Expand All @@ -317,7 +327,10 @@
" backdoor_adjustment=BackdoorAdjustment.BACKDOOR_EFFICIENT,\n",
")\n",
"try:\n",
" results_eff = ident_eff.identify_effect(graph=G, treatment_name=treatment_name, outcome_name=outcome_name)\n",
" results_eff = ident_eff.identify_effect(graph=G, \n",
" action_nodes=treatment_name, \n",
" outcome_nodes=outcome_name,\n",
" observed_nodes=observed_node_names)\n",
"except ValueError as e:\n",
" print(e)"
]
Expand All @@ -335,8 +348,9 @@
"print(\n",
" ident_minimal_eff.identify_effect(\n",
" graph=G,\n",
" treatment_name=treatment_name,\n",
" outcome_name=outcome_name,\n",
" action_nodes=treatment_name,\n",
" outcome_nodes=outcome_name,\n",
" observed_nodes=observed_node_names\n",
" )\n",
")"
]
Expand All @@ -354,8 +368,9 @@
"print(\n",
" ident_mincost_eff.identify_effect(\n",
" graph=G,\n",
" treatment_name=treatment_name,\n",
" outcome_name=outcome_name,\n",
" action_nodes=treatment_name,\n",
" outcome_nodes=outcome_name,\n",
" observed_nodes=observed_node_names\n",
" )\n",
")"
]
Expand Down Expand Up @@ -391,9 +406,7 @@
"observed_node_names = [\"X\", \"Y\"]\n",
"treatment_name = \"X\"\n",
"outcome_name = \"Y\"\n",
"G = CausalGraph(\n",
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
")"
"G = build_graph_from_str(graph_str)"
]
},
{
Expand All @@ -409,8 +422,9 @@
"try:\n",
" results_eff = ident_eff.identify_effect(\n",
" graph=G,\n",
" treatment_name=treatment_name,\n",
" outcome_name=outcome_name,\n",
" action_nodes=treatment_name,\n",
" outcome_nodes=outcome_name,\n",
" observed_nodes=observed_node_names\n",
" )\n",
"except ValueError as e:\n",
" print(e)"
Expand Down Expand Up @@ -475,9 +489,7 @@
" (\"R\", {\"cost\": 2}),\n",
" (\"T\", {\"cost\": 1}),\n",
"]\n",
"G = CausalGraph(\n",
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
")"
"G = build_graph_from_str(graph_str)"
]
},
{
Expand All @@ -504,7 +516,11 @@
")\n",
"print(\n",
" ident_mincost_eff.identify_effect(\n",
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
" graph=G, \n",
" action_nodes=treatment_name, \n",
" outcome_nodes=outcome_name, \n",
" observed_nodes=observed_node_names,\n",
" conditional_node_names=conditional_node_names\n",
" )\n",
")"
]
Expand All @@ -528,22 +544,19 @@
")\n",
"print(\n",
" ident_minimal_eff.identify_effect(\n",
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
" graph=G, \n",
" action_nodes=treatment_name,\n",
" outcome_nodes=outcome_name, \n",
" observed_nodes=observed_node_names,\n",
" conditional_node_names=conditional_node_names\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.10 ('dowhy-_zBapv7Q-py3.8')",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand Down
Loading

0 comments on commit 2a8e49a

Please sign in to comment.