Skip to content

Commit

Permalink
SilhouetteVisualizer add support for more estimators (#1294)
Browse files Browse the repository at this point in the history
Signed-off-by: Benjamin Bengfort <benjamin@bengfort.com>
Co-authored-by: Benjamin Bengfort <benjamin@bengfort.com>
Co-authored-by: Larry Gray <lwgray@gmail.com>
  • Loading branch information
3 people authored Jul 5, 2023
1 parent 7a3c94c commit f7a8e95
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 30 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
50 changes: 31 additions & 19 deletions tests/test_cluster/test_silhouette.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import sys
import pytest
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.cluster import SpectralClustering, AgglomerativeClustering

from unittest import mock
from tests.base import VisualTestCase

from yellowbrick.datasets import load_nfl
from yellowbrick.cluster.silhouette import SilhouetteVisualizer, silhouette_visualizer


Expand All @@ -53,7 +54,6 @@ def test_integrated_kmeans_silhouette(self):
n_samples=1000, n_features=12, centers=8, shuffle=False, random_state=0
)


fig = plt.figure()
ax = fig.add_subplot()

Expand All @@ -62,7 +62,6 @@ def test_integrated_kmeans_silhouette(self):
visualizer.finalize()

self.assert_images_similar(visualizer, remove_legend=True)


@pytest.mark.xfail(sys.platform == "win32", reason="images not close on windows")
def test_integrated_mini_batch_kmeans_silhouette(self):
Expand All @@ -84,7 +83,6 @@ def test_integrated_mini_batch_kmeans_silhouette(self):
visualizer.finalize()

self.assert_images_similar(visualizer, remove_legend=True)


@pytest.mark.skip(reason="no negative silhouette example available yet")
def test_negative_silhouette_score(self):
Expand All @@ -103,7 +101,6 @@ def test_colormap_silhouette(self):
n_samples=1000, n_features=12, centers=8, shuffle=False, random_state=0
)


fig = plt.figure()
ax = fig.add_subplot()

Expand Down Expand Up @@ -138,7 +135,7 @@ def test_colors_silhouette(self):
visualizer.finalize()

self.assert_images_similar(visualizer, remove_legend=True)

def test_colormap_as_colors_silhouette(self):
"""
Test no exceptions for modifying the colors in a silhouette visualizer
Expand All @@ -162,7 +159,7 @@ def test_colormap_as_colors_silhouette(self):
3.2 if sys.platform == "win32" else 0.01
) # Fails on AppVeyor with RMS 3.143
self.assert_images_similar(visualizer, remove_legend=True, tol=tol)

def test_quick_method(self):
"""
Test the quick method producing a valid visualization
Expand All @@ -177,29 +174,44 @@ def test_quick_method(self):

self.assert_images_similar(oz)

@pytest.mark.xfail(
reason="""third test fails with AssertionError: Expected fit
to be called once. Called 0 times."""
)
def test_with_fitted(self):
"""
Test that visualizer properly handles an already-fitted model
"""
X, y = load_nfl(return_dataset=True).to_numpy()

model = MiniBatchKMeans().fit(X, y)
X, y = make_blobs(
n_samples=100, n_features=5, centers=3, shuffle=False, random_state=112
)
model = MiniBatchKMeans().fit(X)
labels = model.predict(X)

with mock.patch.object(model, "fit") as mockfit:
oz = SilhouetteVisualizer(model)
oz.fit(X, y)
oz.fit(X)
mockfit.assert_not_called()

with mock.patch.object(model, "fit") as mockfit:
oz = SilhouetteVisualizer(model, is_fitted=True)
oz.fit(X, y)
oz.fit(X)
mockfit.assert_not_called()

with mock.patch.object(model, "fit") as mockfit:
with mock.patch.object(model, "fit_predict", return_value=labels) as mockfit:
oz = SilhouetteVisualizer(model, is_fitted=False)
oz.fit(X, y)
mockfit.assert_called_once_with(X, y)
oz.fit(X)
mockfit.assert_called_once_with(X, None)

@pytest.mark.parametrize(
"model",
[SpectralClustering, AgglomerativeClustering],
)
def test_clusterer_without_predict(self, model):
"""
Assert that clustering estimators that don't implement
a predict() method utilize fit_predict()
"""
X = np.array([[1, 2], [1, 4], [1, 0], [4, 2], [4, 4], [4, 0]])
try:
visualizer = SilhouetteVisualizer(model(n_clusters=2))
visualizer.fit(X)
visualizer.finalize()
except AttributeError:
self.fail("could not use fit or fit_predict methods")
93 changes: 82 additions & 11 deletions yellowbrick/cluster/silhouette.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,35 @@

from sklearn.metrics import silhouette_score, silhouette_samples

try:
from sklearn.metrics.pairwise import _VALID_METRICS
except ImportError:
_VALID_METRICS = [
"cityblock",
"cosine",
"euclidean",
"l1",
"l2",
"manhattan",
"braycurtis",
"canberra",
"chebyshev",
"correlation",
"dice",
"hamming",
"jaccard",
"kulsinski",
"mahalanobis",
"minkowski",
"rogerstanimoto",
"russellrao",
"seuclidean",
"sokalmichener",
"sokalsneath",
"sqeuclidean",
"yule",
]

from yellowbrick.utils import check_fitted
from yellowbrick.style import resolve_colors
from yellowbrick.cluster.base import ClusteringScoreVisualizer
Expand Down Expand Up @@ -113,7 +142,6 @@ class SilhouetteVisualizer(ClusteringScoreVisualizer):
"""

def __init__(self, estimator, ax=None, colors=None, is_fitted="auto", **kwargs):

# Initialize the visualizer bases
super(SilhouetteVisualizer, self).__init__(
estimator, ax=ax, is_fitted=is_fitted, **kwargs
Expand All @@ -130,23 +158,47 @@ def __init__(self, estimator, ax=None, colors=None, is_fitted="auto", **kwargs):
def fit(self, X, y=None, **kwargs):
"""
Fits the model and generates the silhouette visualization.
Unlike other visualizers that use the score() method to draw the results, this
visualizer errs on visualizing on fit since this is when the clusters are
computed. This means that a predict call is required in fit (or a fit_predict)
in order to produce the visualization.
"""
# TODO: decide to use this method or the score method to draw.
# NOTE: Probably this would be better in score, but the standard score
# is a little different and I'm not sure how it's used.

# If the estimator is not fitted, fit it; then call predict to get the labels
# for computing the silhoutte score on. If the estimator is already fitted, then
# attempt to predict the labels, but if the estimator is stateless, fit and
# predict on the data specified. At the end of this block, no matter the fitted
# state of the estimator and the method, we should have cluster labels for X.
if not check_fitted(self.estimator, is_fitted_by=self.is_fitted):
# Fit the wrapped estimator
self.estimator.fit(X, y, **kwargs)
if hasattr(self.estimator, "fit_predict"):
labels = self.estimator.fit_predict(X, y, **kwargs)
else:
self.estimator.fit(X, y, **kwargs)
labels = self.estimator.predict(X)
else:
if hasattr(self.estimator, "predict"):
labels = self.estimator.predict(X)
else:
labels = self.estimator.fit_predict(X, y, **kwargs)

# Get the properties of the dataset
self.n_samples_ = X.shape[0]
self.n_clusters_ = self.estimator.n_clusters

# Compute the number of available clusters from the estimator
if hasattr(self.estimator, "n_clusters"):
self.n_clusters_ = self.estimator.n_clusters
else:
unique_labels = set(labels)
n_noise_clusters = 1 if -1 in unique_labels else 0
self.n_clusters_ = len(unique_labels) - n_noise_clusters

# Identify the distance metric to use for silhouette scoring
metric = self._identify_silhouette_metric()

# Compute the scores of the cluster
labels = self.estimator.predict(X)
self.silhouette_score_ = silhouette_score(X, labels)
self.silhouette_samples_ = silhouette_samples(X, labels)
self.silhouette_score_ = silhouette_score(X, labels, metric=metric)
self.silhouette_samples_ = silhouette_samples(X, labels, metric=metric)

# Draw the silhouette figure
self.draw(labels)
Expand Down Expand Up @@ -185,7 +237,6 @@ def draw(self, labels):
# For each cluster, plot the silhouette scores
self.y_tick_pos_ = []
for idx in range(self.n_clusters_):

# Collect silhouette scores for samples in the current cluster .
values = self.silhouette_samples_[labels == idx]
values.sort()
Expand Down Expand Up @@ -260,6 +311,26 @@ def finalize(self):
# Show legend (Average Silhouette Score axis)
self.ax.legend(loc="best")

def _identify_silhouette_metric(self):
"""
The Silhouette metric must be one of the distance options allowed by
metrics.pairwise.pairwise_distances or a callable. This method attempts to
discover a valid distance metric from the underlying estimator or returns
"euclidean" by default.
"""
if hasattr(self.estimator, "metric"):
if callable(self.estimator.metric):
return self.estimator.metric

if self.estimator.metric in _VALID_METRICS:
return self.estimator.metric

if hasattr(self.estimator, "affinity"):
if self.estimator.affinity in _VALID_METRICS:
return self.estimator.affinity

return "euclidean"


##########################################################################
## Quick Method
Expand Down

0 comments on commit f7a8e95

Please sign in to comment.