Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update kernelshap_tabular_land_atmosphere.ipynb: increase timeout #871

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@
},
{
"cell_type": "markdown",
"id": "80411a9d-881e-4196-8559-17aaadd15841",
"id": "ddb1e4f0-2674-4242-bcd8-abf66f97c611",
"metadata": {},
"source": [
"#### 5 - Run the explainer at one location for several data instances (here, as an example, a one-month time series)\n",
Expand Down Expand Up @@ -805,6 +805,24 @@
"background_data = x_train.drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()"
]
},
{
"cell_type": "markdown",
"id": "8b612e55-e1ec-40dc-b189-65d90ffb2b1c",
"metadata": {},
"source": [
"This step takes a few minutes, so it is not suitable for GitHub Actions. If you want to run this step locally, set `locally_run = True`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "59a54eaa-f6f2-42b5-8849-aceb37b06156",
"metadata": {},
"outputs": [],
"source": [
"locally_run = False"
]
},
{
"cell_type": "code",
"execution_count": 14,
Expand All @@ -821,11 +839,12 @@
],
"source": [
"# run explainer over time series, this might take a few minutes\n",
"explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
"\n",
"print(\"Dianna is done!\") "
"if locally_run:\n",
" explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
" \n",
" print(\"Dianna is done!\") "
]
},
{
Expand All @@ -846,30 +865,31 @@
}
],
"source": [
"# create shap_values object\n",
"shap_values = Explanation(explanations[key])\n",
"shap_values.feature_names = features.columns\n",
"\n",
"# create comparison plot: predictions vs test data \n",
"y_predict_time = runner(features.to_numpy())\n",
"y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
"comparison_plot(y_test_time, y_predict_time, show=False) \n",
"comparison_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create summary plot\n",
"shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
"summary_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create heatmap plot\n",
"shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
"heatmap_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# plot all three figures in one cell\n",
"figures = [comparison_img, heatmap_img, summary_img]\n",
"display_figures(figures, captions, 1, 3)"
"if locally_run:\n",
" # create shap_values object\n",
" shap_values = Explanation(explanations[key])\n",
" shap_values.feature_names = features.columns\n",
" \n",
" # create comparison plot: predictions vs test data \n",
" y_predict_time = runner(features.to_numpy())\n",
" y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
" comparison_plot(y_test_time, y_predict_time, show=False) \n",
" comparison_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create summary plot\n",
" shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
" summary_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create heatmap plot\n",
" shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
" heatmap_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # plot all three figures in one cell\n",
" figures = [comparison_img, heatmap_img, summary_img]\n",
" display_figures(figures, captions, 1, 3)"
]
},
{
Expand All @@ -887,9 +907,10 @@
}
],
"source": [
"relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
"cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
"print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
"if locally_run:\n",
" relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
" cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
" print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
]
},
{
Expand Down Expand Up @@ -947,12 +968,13 @@
}
],
"source": [
"# run explainer over time series, this might take a few minutes\n",
"explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
"\n",
"print(\"Dianna is done!\") "
"if locally_run:\n",
" # run explainer over time series, this might take a few minutes\n",
" explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
" \n",
" print(\"Dianna is done!\") "
]
},
{
Expand All @@ -973,30 +995,31 @@
}
],
"source": [
"# create shap_values object\n",
"shap_values = Explanation(explanations[key])\n",
"shap_values.feature_names = features.columns\n",
"\n",
"# create comparison plot: predictions vs test data \n",
"y_predict_time = runner(features.to_numpy())\n",
"y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
"comparison_plot(y_test_time, y_predict_time, show=False) \n",
"comparison_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create summary plot\n",
"shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
"summary_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create heatmap plot\n",
"shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
"heatmap_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# plot all three figures in one cell\n",
"figures = [comparison_img, heatmap_img, summary_img]\n",
"display_figures(figures, captions, 1, 3)"
"if locally_run:\n",
" # create shap_values object\n",
" shap_values = Explanation(explanations[key])\n",
" shap_values.feature_names = features.columns\n",
" \n",
" # create comparison plot: predictions vs test data \n",
" y_predict_time = runner(features.to_numpy())\n",
" y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
" comparison_plot(y_test_time, y_predict_time, show=False) \n",
" comparison_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create summary plot\n",
" shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
" summary_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create heatmap plot\n",
" shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
" heatmap_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # plot all three figures in one cell\n",
" figures = [comparison_img, heatmap_img, summary_img]\n",
" display_figures(figures, captions, 1, 3)"
]
},
{
Expand All @@ -1014,9 +1037,10 @@
}
],
"source": [
"relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
"cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
"print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
"if locally_run:\n",
" relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
" cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
" print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
]
},
{
Expand Down Expand Up @@ -1166,6 +1190,9 @@
}
],
"metadata": {
"execution": {
"timeout": 1800
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
Expand Down
Loading