Add new example to documentation
isaksamsten committed Feb 2, 2024
1 parent 98195e8 commit 573341d
Showing 2 changed files with 225 additions and 1 deletion.
224 changes: 224 additions & 0 deletions docs/examples/counterfactuals.ipynb
@@ -0,0 +1,224 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "67f5c379-07fa-4d96-9b4c-bcb0acb8d034",
"metadata": {},
"source": [
"# Counterfactual explanations\n",
"\n",
"Wildboar supports counterfactual explanations for time series. Currently, we\n",
"implement three methods for computing counterfactuals, as described by Karlsson\n",
"et al. (2019, 2020) and Samsten (2024).\n",
"\n",
"Here we exemplify the use of the nearest neighbors counterfactual method.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "803571a8-456b-4698-a4d7-3a56c7503659",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neighbors import KNeighborsClassifier as Sklearn_KNeighborsClassifier\n",
"\n",
"from wildboar.metrics import proximity_score\n",
"from wildboar.datasets import load_dataset\n",
"from wildboar.distance import KNeighborsClassifier\n",
"from wildboar.explain.counterfactual import KNeighborsCounterfactual, counterfactuals"
]
},
{
"cell_type": "markdown",
"id": "ed2902e6-2b85-47e6-9d1a-6b8190089abd",
"metadata": {},
"source": [
"First, we load a dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69242bbe-6733-4382-bb0c-11e5a3966f68",
"metadata": {},
"outputs": [],
"source": [
"X, y = load_dataset(\"ECG200\")\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)"
]
},
{
"cell_type": "markdown",
"id": "b30fe62c",
"metadata": {},
"source": [
"Next, we define a classifier. Here we are using the nearest neighbor classifier from `scikit-learn`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18db6880-f717-4737-8ce1-2e392d457800",
"metadata": {},
"outputs": [],
"source": [
"sk_nn = Sklearn_KNeighborsClassifier(n_neighbors=5, metric=\"euclidean\")\n",
"sk_nn.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"id": "72a14ac2",
"metadata": {},
"source": [
"We can also use the nearest neighbors classifier from `wildboar`, which supports elastic distance measures such as dynamic time warping."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16dee2ba-6b14-46da-b4c9-51ebaa08793e",
"metadata": {},
"outputs": [],
"source": [
"nn = KNeighborsClassifier(n_neighbors=5, metric=\"dtw\", metric_params={\"r\": 0.5})\n",
"nn.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"id": "e8a29db8",
"metadata": {},
"source": [
"Next, we define a function that, given an `estimator` and a collection of time series `X`, returns three arrays: a counterfactual for each sample in `X`, the label predicted for each original sample, and the label predicted for each counterfactual."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67fc7fac-e1d8-4934-9a4c-d039169f6520",
"metadata": {},
"outputs": [],
"source": [
"def find_counterfactuals(estimator, X):\n",
"    y_pred = estimator.predict(X)\n",
"    y_desired = np.empty_like(y_pred)\n",
"\n",
"    # Store an array of the desired label for each sample.\n",
"    # We assume a binary classification task and that the desired\n",
"    # label is the inverse of the predicted label.\n",
"    a, b = estimator.classes_\n",
"    y_desired[y_pred == a] = b\n",
"    y_desired[y_pred == b] = a\n",
"\n",
"    # Initialize the explainer, using the medoid approach.\n",
"    explainer = KNeighborsCounterfactual(random_state=1, method=\"medoid\")\n",
"    explainer.fit(estimator)\n",
"\n",
"    # Explain each sample in X as the desired label in y_desired.\n",
"    X_cf = explainer.explain(X, y_desired)\n",
"    return X_cf, y_pred, estimator.predict(X_cf)"
]
},
{
"cell_type": "markdown",
"id": "1694b821",
"metadata": {},
"source": [
"Given the test samples, we compute the counterfactuals."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "baa6c0f0-2b81-4851-bff7-e0b35c3cae66",
"metadata": {},
"outputs": [],
"source": [
"X_cf, y_pred, cf_pred = find_counterfactuals(nn, X_test)"
]
},
{
"cell_type": "markdown",
"id": "a03c21f3",
"metadata": {},
"source": [
"We plot the original time series (red) and its counterfactual (blue), together with the mean (dashed gray) of the test samples whose true label equals the label predicted for the counterfactual."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a320991-541f-43a2-9647-217aa5e546fc",
"metadata": {},
"outputs": [],
"source": [
"i = 4\n",
"plt.plot(\n",
" X_test[i],\n",
" color=\"red\",\n",
" label=\"original (y_pred = %d, y_actual = %d)\" % (y_pred[i], y_test[i]),\n",
")\n",
"plt.plot(X_cf[i], color=\"blue\", label=\"counterfactual (y = %d)\" % cf_pred[i])\n",
"plt.plot(\n",
" np.mean(X_test[y_test == cf_pred[i]], axis=0),\n",
" color=\"gray\",\n",
" linestyle=\"dashed\",\n",
" label=\"mean of X with y = %d\" % cf_pred[i],\n",
")\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"id": "62f5aee4",
"metadata": {},
"source": [
"We can evaluate the quality of the counterfactuals using, for example, proximity: the distance between each original sample and its counterfactual."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "886e56aa",
"metadata": {},
"outputs": [],
"source": [
"proximity_score(X_test, X_cf)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cef8d485",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
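
Note on the notebook above: the two steps that are easiest to misread are the label inversion inside `find_counterfactuals` and the idea behind a proximity measure. The following is a minimal, dependency-free sketch of both, not the code in the notebook and not wildboar's implementation. The helper names `invert_labels` and `mean_euclidean_proximity`, the example labels, and the use of plain (unnormalized) Euclidean distance are assumptions made for illustration; `proximity_score` in wildboar may use a different metric or normalization by default.

```python
import math


def invert_labels(y_pred, classes):
    """Return the opposite label for each prediction in a binary task.

    Hypothetical helper mirroring the `y_desired` logic in the notebook.
    """
    if len(classes) != 2:
        raise ValueError("expected exactly two classes")
    a, b = classes
    flip = {a: b, b: a}
    return [flip[label] for label in y_pred]


def mean_euclidean_proximity(X, X_cf):
    """Average Euclidean distance between each sample and its counterfactual.

    Hypothetical helper; an assumption about what a proximity measure
    captures, not wildboar's `proximity_score`.
    """
    if len(X) != len(X_cf):
        raise ValueError("X and X_cf must contain the same number of samples")
    return sum(math.dist(x, x_cf) for x, x_cf in zip(X, X_cf)) / len(X)


# Illustrative labels only; any two distinct class labels work.
y_desired = invert_labels([1, -1, 1], classes=(-1, 1))

# Two toy "time series" of length 2 and their counterfactuals.
proximity = mean_euclidean_proximity(
    [[0.0, 0.0], [1.0, 1.0]],
    [[3.0, 4.0], [1.0, 1.0]],
)
```

A lower value means the counterfactuals stay closer to the original samples. A counterfactual identical to its original contributes a distance of zero, but it would also fail to change the prediction, so proximity is best read alongside the predicted labels of the counterfactuals.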
2 changes: 1 addition & 1 deletion docs/examples/hydra.ipynb
@@ -259,7 +259,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.5"
+ "version": "3.11.7"
}
},
"nbformat": 4,