Add notebook replication measure of global surrogate model (#506)
Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"\n",
"Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Explain regression model predictions using MimicExplainer\n", | ||
"_**This notebook showcases how to use the interpret-community repo to help interpret and inspect replication metrics for predictions given by the MimicExplainer's surrogate model for a regression model.**_\n", | ||
"\n", | ||
"\n", | ||
"## Table of Contents\n", | ||
"\n", | ||
"1. [Introduction](#Introduction)\n", | ||
"1. [Project](#Project)\n", | ||
"1. [Setup](#Setup)\n", | ||
"1. [Run model explainer locally at training time](#Explain)\n", | ||
" 1. Load the housing house price data\n", | ||
" 1. Train a GradientBoosting regression model, which you want to explain\n", | ||
" 1. Initialize the `MimicExplainer`\n", | ||
" 1. Inspect replication metrics to see how close is the global surrogate model to the trained model\n", | ||
" 1. Try different surrogate models to see if the replication measure changes\n", | ||
" 1. Generate global explanations\n", | ||
" 1. Generate local explanations\n", | ||
"1. [Next steps](#Next)" | ||
] | ||
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id='Introduction'></a>\n",
"## 1. Introduction\n",
"\n",
"This notebook illustrates how to use interpret-community to help interpret regression model predictions at training time using `MimicExplainer`. It demonstrates how to determine which global surrogate model best replicates the trained model by using metrics like `r2_score`. It also demonstrates the API calls needed to obtain the global and local interpretations.\n",
"\n",
"One tabular data explainer is demonstrated:\n",
"- MimicExplainer (global surrogate)\n",
"\n",
"| ![Interpretability Toolkit Architecture](./img/interpretability-architecture.png) |\n",
"|:--:|\n",
"| *Interpretability Toolkit Architecture* |\n",
"\n",
"<a id='Project'></a>\n",
"## 2. Project\n",
"\n",
"The goal of this project is to predict California housing prices by using scikit-learn and locally running the model explainer:\n",
"\n",
"1. Train a GradientBoosting regression model using scikit-learn\n",
"2. Use `get_surrogate_model_replication_measure()` to determine how close the predictions of the global surrogate model are to those of the regression model trained in step 1.\n",
"3. Run `explain_global()` and `explain_local()` with the full dataset in local mode, which doesn't contact any Azure services.\n",
"\n",
"<a id='Setup'></a>\n",
"## 3. Setup\n",
"\n",
"If you are using Jupyter notebooks, the extensions should be installed automatically with the package.\n",
"If you are using JupyterLab, run the following command:\n",
"```\n",
"(myenv) $ jupyter labextension install @jupyter-widgets/jupyterlab-manager\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id='Explain'></a>\n",
"## 4. Run the model explainer locally at training time"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"\n",
"from interpret.ext.blackbox import MimicExplainer\n",
"# You can use one of the following four interpretable models as a global surrogate for the black-box model\n",
"from interpret.ext.glassbox import LGBMExplainableModel\n",
"from interpret.ext.glassbox import LinearExplainableModel\n",
"from interpret.ext.glassbox import SGDExplainableModel\n",
"from interpret.ext.glassbox import DecisionTreeExplainableModel\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the California housing price data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"housing_data = datasets.fetch_california_housing()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Split the data into train and test sets\n",
"from sklearn.model_selection import train_test_split\n",
"x_train, x_test, y_train, y_test = train_test_split(housing_data.data, housing_data.target, test_size=0.2, random_state=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train a GradientBoosting regression model that you want to explain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reg = GradientBoostingRegressor(n_estimators=100, max_depth=4,\n",
"                                learning_rate=0.1, loss='huber',\n",
"                                random_state=1)\n",
"model = reg.fit(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize the `MimicExplainer`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# augment_data is optional; if True, it oversamples the initialization examples to improve the surrogate model's fit to the original model. This is useful for high-dimensional data where the number of rows is less than the number of columns.\n",
"# max_num_of_augmentations is optional and defines the maximum number of times the input data size can be increased.\n",
"# LGBMExplainableModel can be replaced with LinearExplainableModel, SGDExplainableModel, or DecisionTreeExplainableModel\n",
"explainer = MimicExplainer(model,\n",
"                           x_train,\n",
"                           LGBMExplainableModel,\n",
"                           augment_data=True,\n",
"                           max_num_of_augmentations=10,\n",
"                           features=housing_data.feature_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Inspect replication metrics to see how close the global surrogate model is to the trained model\n",
"For regression models, the function `get_surrogate_model_replication_measure()` returns the `r2_score` between the predictions of the trained model and the predictions of the global surrogate model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"explainer.get_surrogate_model_replication_measure(training_data=x_train)"
]
},
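{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for what this measure captures, the following cell is a rough, illustrative sketch (not the library's internal implementation): fit a stand-alone surrogate to the trained model's predictions and score it with scikit-learn's `r2_score`. The `DecisionTreeRegressor` stand-in and its hyperparameters are arbitrary choices for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only -- MimicExplainer computes its measure internally,\n",
"# and its exact procedure may differ from this simplified version\n",
"from sklearn.metrics import r2_score\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"teacher_predictions = model.predict(x_train)\n",
"# Fit a simple stand-in surrogate to the original model's predictions\n",
"sketch_surrogate = DecisionTreeRegressor(max_depth=5, random_state=0)\n",
"sketch_surrogate.fit(x_train, teacher_predictions)\n",
"# An r2_score close to 1.0 means the surrogate replicates the original model well\n",
"print('sketch replication measure: {}'.format(\n",
"    r2_score(teacher_predictions, sketch_surrogate.predict(x_train))))"
]
},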
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Try different surrogate models to see if the replication measure changes\n",
"Let's choose a different global surrogate (out of `LinearExplainableModel`, `SGDExplainableModel`, or `DecisionTreeExplainableModel`) and see how the replication metrics change. A sketch that compares all four surrogates follows the next cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"explainer = MimicExplainer(model,\n",
"                           x_train,\n",
"                           LinearExplainableModel,\n",
"                           augment_data=True,\n",
"                           max_num_of_augmentations=10,\n",
"                           features=housing_data.feature_names)\n",
"\n",
"explainer.get_surrogate_model_replication_measure(training_data=x_train)"
]
},
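{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below is one convenient way to organize that comparison: loop over all four surrogate classes imported earlier and print each replication measure side by side, then pick the surrogate that best replicates the trained model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the replication measure of each available surrogate model (sketch)\n",
"surrogate_classes = [LGBMExplainableModel, LinearExplainableModel,\n",
"                     SGDExplainableModel, DecisionTreeExplainableModel]\n",
"for surrogate_class in surrogate_classes:\n",
"    candidate_explainer = MimicExplainer(model,\n",
"                                         x_train,\n",
"                                         surrogate_class,\n",
"                                         augment_data=True,\n",
"                                         max_num_of_augmentations=10,\n",
"                                         features=housing_data.feature_names)\n",
"    measure = candidate_explainer.get_surrogate_model_replication_measure(training_data=x_train)\n",
"    print('{}: {}'.format(surrogate_class.__name__, measure))"
]
},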
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate global explanations\n",
"Explain overall model predictions (global explanation)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Passing in the test dataset for evaluation examples - note it must be a representative sample of the original data\n",
"# x_train can be passed as well, but with more examples the explanations will take longer, although they may be more accurate\n",
"global_explanation = explainer.explain_global(x_test)\n",
"\n",
"# Note: if you use the PFIExplainer instead, pass the true labels as well:\n",
"# global_explanation = explainer.explain_global(x_test, true_labels=y_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sorted SHAP values\n",
"print('ranked global importance values: {}'.format(global_explanation.get_ranked_global_values()))\n",
"# Corresponding feature names\n",
"print('ranked global importance names: {}'.format(global_explanation.get_ranked_global_names()))\n",
"# Feature ranks (based on the original order of features)\n",
"print('global importance rank: {}'.format(global_explanation.global_importance_rank))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print out a dictionary that holds the sorted feature importance names and values\n",
"print('global importance dict: {}'.format(global_explanation.get_feature_importance_dict()))"
]
},
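{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a more readable view, the optional sketch below loads the importance dictionary into a pandas `Series` and prints it sorted; it assumes pandas is available in your environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: a more readable, tabular view of the global importances (sketch; assumes pandas is installed)\n",
"import pandas as pd\n",
"\n",
"importances = pd.Series(global_explanation.get_feature_importance_dict())\n",
"print(importances.sort_values(ascending=False))"
]
},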
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Explain overall model predictions as a collection of local (instance-level) explanations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Per-feature SHAP values for all features and all data points in the evaluation (test) data\n",
"print('local importance values: {}'.format(global_explanation.local_importance_values))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate local explanations\n",
"Explain local data points (individual instances)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# E.g., explain the first data point in the test set\n",
"local_explanation = explainer.explain_local(x_test[0,:])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sorted local feature importance information: names and values ranked from most to least important\n",
"sorted_local_importance_names = local_explanation.get_ranked_local_names()\n",
"sorted_local_importance_values = local_explanation.get_ranked_local_values()\n",
"\n",
"print('sorted local importance names: {}'.format(sorted_local_importance_names))\n",
"print('sorted local importance values: {}'.format(sorted_local_importance_values))"
]
},
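{
"cell_type": "markdown",
"metadata": {},
"source": [
"To read the local explanation at a glance, the sketch below pairs the ranked names with their values and prints the top three. Depending on the library version, the ranked results may be nested one level per explained instance, so the sketch unwraps a single leading level if present (an assumption worth verifying in your environment)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pair ranked feature names with their importance values (sketch)\n",
"names = sorted_local_importance_names\n",
"values = sorted_local_importance_values\n",
"# The results may be nested one level per explained instance; unwrap if so (assumption)\n",
"if len(names) > 0 and isinstance(names[0], list):\n",
"    names, values = names[0], values[0]\n",
"for name, value in list(zip(names, values))[:3]:\n",
"    print('{}: {}'.format(name, value))"
]
},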
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id='Next'></a>\n",
"## 5. Next steps\n",
"Learn about other use cases of the interpret-community package:\n",
"\n",
"1. [Training time: binary classification problem](./explain-binary-classification-local.ipynb)\n",
"1. [Training time: multiclass classification problem](./explain-multiclass-classification-local.ipynb)\n",
"1. Explain models with engineered features:\n",
" 1. [Simple feature transformations](./simple-feature-transformations-explain-local.ipynb)\n",
" 1. [Advanced feature transformations](./advanced-feature-transformations-explain-local.ipynb)"
]
}
],
"metadata": {
"authors": [
{
"name": "mesameki"
}
],
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}