From 7563635d9b6eea4d0db26cd58e93ec2298f6faac Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Thu, 28 Mar 2024 10:54:18 +0100 Subject: [PATCH 1/2] fix test --- tests/unit_tests/explainer/test_smart_plotter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/explainer/test_smart_plotter.py b/tests/unit_tests/explainer/test_smart_plotter.py index 57f279a0..9a7133e3 100644 --- a/tests/unit_tests/explainer/test_smart_plotter.py +++ b/tests/unit_tests/explainer/test_smart_plotter.py @@ -108,6 +108,7 @@ def setUp(self): self.smart_explainer._case, self.smart_explainer._classes = check_model(model) self.smart_explainer.state = MultiDecorator(SmartState()) self.smart_explainer.y_pred = None + self.smart_explainer.proba_values = None self.smart_explainer.features_desc = dict(self.x_init.nunique()) self.smart_explainer.features_compacity = self.features_compacity @@ -863,7 +864,7 @@ def test_contribution_plot_8(self): xpl.model = model np_hv = [f"Id: {x}
Predict: {y}" for x, y in zip(xpl.x_init.index, xpl.y_pred.iloc[:, 0].tolist())] np_hv.sort() - output = xpl.plot.contribution_plot(col) + output = xpl.plot.contribution_plot(col, proba=False) annot_list = [] for data_plot in output.data: annot_list.extend(data_plot.hovertext.tolist()) @@ -895,7 +896,7 @@ def test_contribution_plot_9(self): model = lambda: None model.classes_ = np.array([0, 1]) xpl.model = model - output = xpl.plot.contribution_plot(col, max_points=39) + output = xpl.plot.contribution_plot(col, max_points=39, proba=False) assert len(output.data) == 4 for elem in output.data: assert elem.type == "violin" From ad0d94178ccf3a94ea1728ac99b0f52b806e027f Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Thu, 28 Mar 2024 13:28:31 +0100 Subject: [PATCH 2/2] fix2 test --- tests/unit_tests/explainer/test_smart_plotter.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit_tests/explainer/test_smart_plotter.py b/tests/unit_tests/explainer/test_smart_plotter.py index 9a7133e3..afd19220 100644 --- a/tests/unit_tests/explainer/test_smart_plotter.py +++ b/tests/unit_tests/explainer/test_smart_plotter.py @@ -1267,6 +1267,9 @@ def test_features_importance_4(self): def test_local_pred_1(self): xpl = self.smart_explainer + xpl.proba_values = pd.DataFrame( + data=np.array([[0.4, 0.6], [0.3, 0.7]]), columns=["class_1", "class_2"], index=xpl.x_encoded.index.values + ) output = xpl.plot.local_pred("person_A", label=0) assert isinstance(output, float)