From 3594f71c809225300679722b054f48846cfae948 Mon Sep 17 00:00:00 2001 From: RobertSamoilescu Date: Wed, 1 Mar 2023 14:02:06 +0000 Subject: [PATCH] Updated Shap saving. (#881) * Updated Shap saving. Fixes savings for distributed KernelShap and avoids saving internal explainer. * Fixed typo for the TreeShap saving function. --- alibi/saving.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/alibi/saving.py b/alibi/saving.py index f11f9cafd..439777527 100644 --- a/alibi/saving.py +++ b/alibi/saving.py @@ -232,14 +232,25 @@ def _save_AnchorText(explainer: 'AnchorText', path: Union[str, os.PathLike]) -> explainer.perturbation = perturbation -def _save_KernelShap(explainer: 'KernelShap', path: Union[str, os.PathLike]) -> None: - # TODO: save internal shap objects using native pickle? +def _save_Shap(explainer: Union['KernelShap', 'TreeShap'], path: Union[str, os.PathLike]) -> None: + # set the internal explainer object to avoid saving it. The internal explainer + # object is recreated when in the `reset_predictor` function call. + _explainer = explainer._explainer + explainer._explainer = None + + # simple save which does not save the predictor _simple_save(explainer, path) + # reset the internal explainer object + explainer._explainer = _explainer -def _save_TreelShap(explainer: 'TreeShap', path: Union[str, os.PathLike]) -> None: - # TODO: save internal shap objects using native pickle? - _simple_save(explainer, path) + +def _save_KernelShap(explainer: 'KernelShap', path: Union[str, os.PathLike]) -> None: + _save_Shap(explainer, path) + + +def _save_TreeShap(explainer: 'TreeShap', path: Union[str, os.PathLike]) -> None: + _save_Shap(explainer, path) def _save_CounterfactualRL(explainer: 'CounterfactualRL', path: Union[str, os.PathLike]) -> None: