1 parent
98195e8
commit 573341d
Showing
2 changed files
with
225 additions
and
1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "67f5c379-07fa-4d96-9b4c-bcb0acb8d034", | ||
"metadata": {}, | ||
"source": [ | ||
"# Counterfactual explanations\n", | ||
"\n", | ||
"Wildboar supports counterfactual explanations for time series. Currently, we\n", | ||
"implement three methods for computing counterfactuals, as described by Karlsson\n", | ||
"et al. (2019, 2020) and Samsten (2024).\n", | ||
"\n", | ||
"Here we exemplify the use of the nearest neighbors counterfactual method.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "803571a8-456b-4698-a4d7-3a56c7503659", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"from sklearn.model_selection import train_test_split\n", | ||
"from sklearn.neighbors import KNeighborsClassifier as Sklearn_KNeighborsClassifier\n", | ||
"\n", | ||
"from wildboar.metrics import proximity_score\n", | ||
"from wildboar.datasets import load_dataset\n", | ||
"from wildboar.distance import KNeighborsClassifier\n", | ||
"from wildboar.explain.counterfactual import KNeighborsCounterfactual, counterfactuals" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ed2902e6-2b85-47e6-9d1a-6b8190089abd", | ||
"metadata": {}, | ||
"source": [ | ||
"First, we load a dataset.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "69242bbe-6733-4382-bb0c-11e5a3966f68", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"X, y = load_dataset(\"ECG200\")\n", | ||
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b30fe62c", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, we define a classifier. Here we are using the nearest neighbor classifier from `scikit-learn`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "18db6880-f717-4737-8ce1-2e392d457800", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"sk_nn = Sklearn_KNeighborsClassifier(n_neighbors=5, metric=\"euclidean\")\n", | ||
"sk_nn.fit(X_train, y_train)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "72a14ac2", | ||
"metadata": {}, | ||
"source": [ | ||
"We can also use the nearest neighbors classifier from `wildboar` to have support for elastic distance measures." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "16dee2ba-6b14-46da-b4c9-51ebaa08793e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"nn = KNeighborsClassifier(n_neighbors=5, metric=\"dtw\", metric_params={\"r\": 0.5})\n", | ||
"nn.fit(X_train, y_train)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e8a29db8", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, we define a function that, given an `estimator` and a collection of time series `X`, returns a counterfactual for each sample in `X`, the predicted label for each sample, and the estimator's prediction for each counterfactual." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "67fc7fac-e1d8-4934-9a4c-d039169f6520", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def find_counterfactuals(estimator, X):\n", | ||
" y_pred = estimator.predict(X)\n", | ||
" y_desired = np.empty_like(y_pred)\n", | ||
"\n", | ||
" # Store an array of the desired label for each sample.\n", | ||
" # We assume a binary classification task and that the desired\n", | ||
" # label is the inverse of the predicted label.\n", | ||
" a, b = estimator.classes_\n", | ||
" y_desired[y_pred == a] = b\n", | ||
" y_desired[y_pred == b] = a\n", | ||
"\n", | ||
" # Initialize the explainer, using the medoid approach.\n", | ||
" explainer = KNeighborsCounterfactual(random_state=1, method=\"medoid\")\n", | ||
" explainer.fit(estimator)\n", | ||
"\n", | ||
" # Explain each sample in X as the desired label in y_desired\n", | ||
" X_cf = explainer.explain(X, y_desired)\n", | ||
" return X_cf, y_pred, estimator.predict(X_cf)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1694b821", | ||
"metadata": {}, | ||
"source": [ | ||
"Given the test samples, we compute the counterfactuals." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "baa6c0f0-2b81-4851-bff7-e0b35c3cae66", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"X_cf, y_pred, cf_pred = find_counterfactuals(nn, X_test)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a03c21f3", | ||
"metadata": {}, | ||
"source": [ | ||
"We plot the original time series (red) and its counterfactual (blue), together with the mean of the test series whose true label equals the label predicted for the counterfactual (gray, dashed)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9a320991-541f-43a2-9647-217aa5e546fc", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"i = 4\n", | ||
"plt.plot(\n", | ||
" X_test[i],\n", | ||
" color=\"red\",\n", | ||
" label=\"original (y_pred = %d, y_actual = %d)\" % (y_pred[i], y_test[i]),\n", | ||
")\n", | ||
"plt.plot(X_cf[i], color=\"blue\", label=\"counterfactual (y = %d)\" % cf_pred[i])\n", | ||
"plt.plot(\n", | ||
" np.mean(X_test[y_test == cf_pred[i]], axis=0),\n", | ||
" color=\"gray\",\n", | ||
" linestyle=\"dashed\",\n", | ||
" label=\"mean of X with y = %d\" % cf_pred[i],\n", | ||
")\n", | ||
"plt.legend()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "62f5aee4", | ||
"metadata": {}, | ||
"source": [ | ||
"We can evaluate the quality of the counterfactuals using, for example, their proximity to the original samples." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "886e56aa", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"proximity_score(X_test, X_cf)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "cef8d485", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
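
The final notebook cell scores the counterfactuals with `proximity_score(X_test, X_cf)`. As a rough illustration of what a proximity measure captures, the sketch below computes the mean Euclidean distance between each original series and its counterfactual using only the standard library. The helper name `mean_euclidean_proximity` is hypothetical and is not part of wildboar; the metric and any normalization that wildboar's `proximity_score` applies by default may differ, so treat this as a sketch of the idea under that assumption rather than a reimplementation.

```python
import math


def mean_euclidean_proximity(originals, counterfactuals):
    """Mean Euclidean distance between each series and its counterfactual.

    Hypothetical helper for illustration only; wildboar's own
    proximity_score may use a different metric or normalization.
    """
    if len(originals) != len(counterfactuals):
        raise ValueError("expected one counterfactual per original series")
    if not originals:
        raise ValueError("expected at least one series")

    # Pairwise distance between the i-th original and the i-th counterfactual.
    distances = [
        math.dist(x, x_cf) for x, x_cf in zip(originals, counterfactuals)
    ]
    return sum(distances) / len(distances)


# Two toy series of length 3 and their (made-up) counterfactuals.
originals = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
counterfactuals = [[3.0, 4.0, 0.0], [1.0, 2.0, 3.0]]

# Distances are 5.0 and 0.0, so the mean is 2.5.
print(mean_euclidean_proximity(originals, counterfactuals))
```

With a distance-based measure like this one, lower values mean the counterfactuals stay closer to the original samples, which is usually what one wants from a counterfactual that changes the prediction with minimal modification.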