From 0befdac7d5710c0dfce5708d10e1d2dc33e9ecff Mon Sep 17 00:00:00 2001 From: Isak Samsten Date: Thu, 19 Oct 2023 21:23:07 +0200 Subject: [PATCH] Use argmin_distance in KMeans. --- src/wildboar/distance/_neighbors.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/wildboar/distance/_neighbors.py b/src/wildboar/distance/_neighbors.py index 3f88e0fdb6..06a4caf9de 100644 --- a/src/wildboar/distance/_neighbors.py +++ b/src/wildboar/distance/_neighbors.py @@ -174,14 +174,25 @@ def cost(self, x): ) def assign(self, x): - self.distance_ = pairwise_distance( - x, - self.centroids_, - dim="mean", - metric=self.metric, - metric_params=self.metric_params, - ) - self.assigned_ = self.distance_.argmin(axis=1) + if x.ndim == 2: + self.assigned_, self.distance_ = argmin_distance( + x, + self.centroids_, + k=1, + metric=self.metric, + metric_params=self.metric_params, + return_distance=True, + ) + self.assigned_ = np.ravel(self.assigned_) + else: + self.distance_ = pairwise_distance( + x, + self.centroids_, + dim="mean", + metric=self.metric, + metric_params=self.metric_params, + ) + self.assigned_ = self.distance_.argmin(axis=1) def update(self, x): for c in range(self.centroids_.shape[0]): @@ -472,7 +483,6 @@ def update(self): for idx in range(self._cluster_idx.shape[0]): cluster_idx = np.where(self.labels_ == idx)[0] if cluster_idx.shape[0] == 0: - print(f"empty cluster {idx}") continue cost = self._dist[cluster_idx, cluster_idx.reshape(-1, 1)].sum(axis=1)