Skip to content

Commit

Permalink
Release 0.3.3 (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
confoundry authored Jun 27, 2023
1 parent 8e9cc2d commit 9fcb726
Show file tree
Hide file tree
Showing 38 changed files with 1,191 additions and 1,050 deletions.
Empty file removed examples/__init__.py
Empty file.
458 changes: 229 additions & 229 deletions examples/csuite_example.ipynb

Large diffs are not rendered by default.

51 changes: 36 additions & 15 deletions examples/multi_investment_sales_attribution.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -20,6 +21,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -55,21 +57,19 @@
"from pytorch_lightning.callbacks import TQDMProgressBar\n",
"from tensordict import TensorDict\n",
"\n",
"from causica.distributions import (\n",
" ContinuousNoiseDist,\n",
" SEMDistributionModule,\n",
")\n",
"from causica.distributions import ContinuousNoiseDist\n",
"from causica.lightning.data_modules.basic_data_module import BasicDECIDataModule\n",
"from causica.lightning.modules.deci_module import DECIModule\n",
"from causica.sem.distribution_parameters_sem import DistributionParametersSEM\n",
"from causica.sem.sem_distribution import SEMDistributionModule\n",
"from causica.sem.structural_equation_model import ite\n",
"from causica.training.auglag import AugLagLRConfig\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"%matplotlib inline"
"test_run = bool(os.environ.get(\"TEST_RUN\", False)) # used by testing to run the notebook as a script"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -110,6 +110,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -384,6 +385,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -439,21 +441,27 @@
"fig, axis = plt.subplots(1, 1, figsize=(8, 8))\n",
"labels = {node: i for i, node in enumerate(true_adj.nodes)}\n",
"\n",
"layout = nx.nx_agraph.graphviz_layout(true_adj, prog=\"dot\")\n",
"try:\n",
" layout = nx.nx_agraph.graphviz_layout(true_adj, prog=\"dot\")\n",
"except (ModuleNotFoundError, ImportError):\n",
" layout = nx.layout.spring_layout(true_adj)\n",
"\n",
"for node, i in labels.items():\n",
" axis.scatter(layout[node][0], layout[node][1], label=f\"{i}: {node}\")\n",
"axis.legend()\n",
"nx.draw_networkx(true_adj, pos=layout, with_labels=True, arrows=True, labels=labels, ax=axis)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Discover the Causal Graph"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -478,6 +486,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -497,6 +506,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -523,6 +533,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -542,6 +553,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -553,6 +565,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -596,7 +609,8 @@
"\n",
"trainer = pl.Trainer(\n",
" accelerator=\"auto\",\n",
" max_epochs=int(os.environ.get(\"MAX_EPOCH\", 2000)), # used by testing to run the notebook as a script\n",
" max_epochs=2000,\n",
" fast_dev_run=test_run,\n",
" callbacks=[TQDMProgressBar(refresh_rate=19)],\n",
" enable_checkpointing=False,\n",
")"
Expand Down Expand Up @@ -668,6 +682,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -725,6 +740,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -733,6 +749,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -742,6 +759,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -769,7 +787,7 @@
],
"source": [
"revenue_estimated_ate = {}\n",
"num_samples = 20000\n",
"num_samples = 10 if test_run else 20000\n",
"sample_shape = torch.Size([num_samples])\n",
"transform = data_module.normalizer.transform_modules[outcome]().inv\n",
"\n",
Expand All @@ -791,6 +809,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -825,6 +844,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -857,13 +877,14 @@
"source": [
"revenue_estimated_ite = {}\n",
"\n",
"base_noise = sem.sample_to_noise(data_module.dataset_train)\n",
"\n",
"for treatment in treatment_columns:\n",
" base_noise = sem.sample_to_noise(data_module.dataset_train)\n",
" intervention_a = TensorDict({treatment: torch.tensor([1.0])}, batch_size=tuple())\n",
" do_a_cfs = transform(sem.do(interventions=intervention_a).noise_to_sample(base_noise)[outcome])\n",
" intervention_b = TensorDict({treatment: torch.tensor([0.0])}, batch_size=tuple())\n",
" do_b_cfs = transform(sem.do(interventions=intervention_b).noise_to_sample(base_noise)[outcome])\n",
" revenue_estimated_ite[treatment] = (do_a_cfs - do_b_cfs).cpu().detach().numpy()[:, 0]\n",
" do_sem = sem.do(interventions=TensorDict({treatment: torch.tensor([1.0])}, batch_size=tuple()))\n",
" do_a_cfs = transform(do_sem.noise_to_sample(base_noise)[outcome]).cpu().detach().numpy()[:, 0]\n",
" do_sem = sem.do(interventions=TensorDict({treatment: torch.tensor([0.0])}, batch_size=tuple()))\n",
" do_b_cfs = transform(do_sem.noise_to_sample(base_noise)[outcome]).cpu().detach().numpy()[:, 0]\n",
" revenue_estimated_ite[treatment] = do_a_cfs - do_b_cfs\n",
"\n",
"revenue_estimated_ite"
]
Expand Down
Loading

0 comments on commit 9fcb726

Please sign in to comment.