Skip to content

Commit

Permalink
fix doc test
Browse files Browse the repository at this point in the history
  • Loading branch information
heimengqi committed Sep 8, 2021
1 parent 920952f commit 18b3f64
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 363 deletions.
4 changes: 2 additions & 2 deletions doc/spec/estimation/orthoiv.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ For instance call:
import numpy as np
X = np.random.normal(size=(100, 3))
y = np.random.normal(size=(100,))
T = np.random.binomial(1, 0.5, size=(n,))
Z = np.random.binomial(1, 0.5, size=(n,))
T = np.random.binomial(1, 0.5, size=(100,))
Z = np.random.binomial(1, 0.5, size=(100,))
W = np.random.normal(size=(100, 10))

.. testcode::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,34 +61,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" setTimeout(function() {\n",
" var nbb_cell_id = 2;\n",
" var nbb_formatted_code = \"# imports\\nfrom econml.data.dynamic_panel_dgp import SemiSynthetic\\nfrom sklearn.linear_model import LassoCV, MultiTaskLassoCV\\nimport numpy as np\\nimport matplotlib.pyplot as plt\";\n",
" var nbb_cells = Jupyter.notebook.get_cells();\n",
" for (var i = 0; i < nbb_cells.length; ++i) {\n",
" if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n",
" nbb_cells[i].set_text(nbb_formatted_code);\n",
" break;\n",
" }\n",
" }\n",
" }, 500);\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"# imports\n",
"from econml.data.dynamic_panel_dgp import SemiSynthetic\n",
Expand Down Expand Up @@ -139,34 +114,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" setTimeout(function() {\n",
" var nbb_cell_id = 3;\n",
" var nbb_formatted_code = \"# generate historical dataset (training purpose)\\nnp.random.seed(43)\\ndgp = SemiSynthetic()\\ndgp.create_instance()\\nn_periods = 4\\nn_units = 5000\\nn_treatments = dgp.n_treatments\\nrandom_seed = 43\\nthetas = np.random.uniform(0, 2, size=(dgp.n_proxies, n_treatments))\\n\\npanelX, panelT, panelY, panelGroups, true_effect = dgp.gen_data(\\n n_units, n_periods, thetas, random_seed\\n)\";\n",
" var nbb_cells = Jupyter.notebook.get_cells();\n",
" for (var i = 0; i < nbb_cells.length; ++i) {\n",
" if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n",
" nbb_cells[i].set_text(nbb_formatted_code);\n",
" break;\n",
" }\n",
" }\n",
" }, 500);\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"# generate historical dataset (training purpose)\n",
"np.random.seed(43)\n",
Expand All @@ -185,7 +135,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -196,30 +146,6 @@
"Treatment shape: (5000, 4, 3)\n",
"Controls shape: (5000, 4, 71)\n"
]
},
{
"data": {
"application/javascript": [
"\n",
" setTimeout(function() {\n",
" var nbb_cell_id = 4;\n",
" var nbb_formatted_code = \"# print panel data shape\\nprint(\\\"Outcome shape: \\\", panelY.shape)\\nprint(\\\"Treatment shape: \\\", panelT.shape)\\nprint(\\\"Controls shape: \\\", panelX.shape)\";\n",
" var nbb_cells = Jupyter.notebook.get_cells();\n",
" for (var i = 0; i < nbb_cells.length; ++i) {\n",
" if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n",
" nbb_cells[i].set_text(nbb_formatted_code);\n",
" break;\n",
" }\n",
" }\n",
" }, 500);\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand All @@ -231,34 +157,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" setTimeout(function() {\n",
" var nbb_cell_id = 5;\n",
" var nbb_formatted_code = \"# generate new dataset (testing purpose)\\nthetas_new = np.random.uniform(0, 2, size=(dgp.n_proxies, n_treatments))\\npanelXnew, panelTnew, panelYnew, panelGroupsnew, true_effect_new = dgp.gen_data(\\n n_units, n_periods, thetas_new, random_seed\\n)\";\n",
" var nbb_cells = Jupyter.notebook.get_cells();\n",
" for (var i = 0; i < nbb_cells.length; ++i) {\n",
" if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n",
" nbb_cells[i].set_text(nbb_formatted_code);\n",
" break;\n",
" }\n",
" }\n",
" }, 500);\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"# generate new dataset (testing purpose)\n",
"thetas_new = np.random.uniform(0, 2, size=(dgp.n_proxies, n_treatments))\n",
Expand All @@ -269,7 +170,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -278,30 +179,6 @@
"text": [
"True Long-term Effect for each investment: [0.90994672 0.709811 2.45310877]\n"
]
},
{
"data": {
"application/javascript": [
"\n",
" setTimeout(function() {\n",
" var nbb_cell_id = 6;\n",
" var nbb_formatted_code = \"# print true long term effect\\ntrue_longterm_effect = np.sum(true_effect_new, axis=0)\\nprint(\\\"True Long-term Effect for each investment: \\\", true_longterm_effect)\";\n",
" var nbb_cells = Jupyter.notebook.get_cells();\n",
" for (var i = 0; i < nbb_cells.length; ++i) {\n",
" if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n",
" nbb_cells[i].set_text(nbb_formatted_code);\n",
" break;\n",
" }\n",
" }\n",
" }, 500);\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand All @@ -324,34 +201,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" setTimeout(function() {\n",
" var nbb_cell_id = 7;\n",
" var nbb_formatted_code = \"# Helper function to reshape the panel data\\ndef long(x): # reshape the panel data to (n_units * n_periods, -1)\\n n_units = x.shape[0]\\n n_periods = x.shape[1]\\n return (\\n x.reshape(n_units * n_periods)\\n if np.ndim(x) == 2\\n else x.reshape(n_units * n_periods, -1)\\n )\\n\\n\\ndef wide(x): # reshape the panel data to (n_units, n_periods * d_x)\\n n_units = x.shape[0]\\n return x.reshape(n_units, -1)\";\n",
" var nbb_cells = Jupyter.notebook.get_cells();\n",
" for (var i = 0; i < nbb_cells.length; ++i) {\n",
" if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n",
" nbb_cells[i].set_text(nbb_formatted_code);\n",
" break;\n",
" }\n",
" }\n",
" }, 500);\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"# Helper function to reshape the panel data\n",
"def long(x): # reshape the panel data to (n_units * n_periods, -1)\n",
Expand All @@ -371,32 +223,44 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" setTimeout(function() {\n",
" var nbb_cell_id = 8;\n",
" var nbb_formatted_code = \"# on historical data construct adjusted outcomes\\nfrom econml.dynamic.dml import DynamicDML\\n\\npanelYadj = panelY.copy()\\n\\nest = DynamicDML(\\n model_y=LassoCV(max_iter=2000), model_t=MultiTaskLassoCV(max_iter=2000), cv=2\\n)\\nfor t in range(1, n_periods): # for each target period 1...m\\n # learn period effect for each period treatment on target period t\\n est.fit(\\n long(panelY[:, 1 : t + 1]),\\n long(panelT[:, 1 : t + 1, :]), # reshape data to long format\\n X=None,\\n W=long(panelX[:, 1 : t + 1, :]),\\n groups=long(panelGroups[:, 1 : t + 1]),\\n )\\n # remove effect of observed treatments\\n T1 = wide(panelT[:, 1 : t + 1, :])\\n panelYadj[:, t] = panelY[:, t] - est.effect(\\n T0=np.zeros_like(T1), T1=T1\\n ) # reshape data to wide format\";\n",
" var nbb_cells = Jupyter.notebook.get_cells();\n",
" for (var i = 0; i < nbb_cells.length; ++i) {\n",
" if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n",
" nbb_cells[i].set_text(nbb_formatted_code);\n",
" break;\n",
" }\n",
" }\n",
" }, 500);\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
"array([[0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],\n",
" [1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00],\n",
" [2.000e+00, 2.000e+00, 2.000e+00, 2.000e+00],\n",
" ...,\n",
" [4.997e+03, 4.997e+03, 4.997e+03, 4.997e+03],\n",
" [4.998e+03, 4.998e+03, 4.998e+03, 4.998e+03],\n",
" [4.999e+03, 4.999e+03, 4.999e+03, 4.999e+03]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
}
],
"source": [
"panelGroups"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1.35952949 1.92451605 0.34684417]\n",
"[0.74662029 1.13138969 0.25069193 1.30585143 1.79051531 0.34597602]\n",
"[0.46734394 0.74952179 0.16292026 0.67056612 0.92299133 0.23686006\n",
" 1.36311063 1.91659314 0.34728767]\n"
]
}
],
"source": [
Expand All @@ -417,6 +281,7 @@
" W=long(panelX[:, 1 : t + 1, :]),\n",
" groups=long(panelGroups[:, 1 : t + 1]),\n",
" )\n",
" print(est.intercept_)\n",
" # remove effect of observed treatments\n",
" T1 = wide(panelT[:, 1 : t + 1, :])\n",
" panelYadj[:, t] = panelY[:, t] - est.effect(\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/Double Machine Learning Examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2468,7 +2468,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
"version": "3.7.4"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion notebooks/Doubly Robust Learner and Interpretability.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.4"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion notebooks/Dynamic Double Machine Learning Examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.7.4"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 18b3f64

Please sign in to comment.