Skip to content

Commit

Permalink
Use argmin_distance in KMeans.
Browse files Browse the repository at this point in the history
  • Loading branch information
isaksamsten committed Oct 19, 2023
1 parent 049cde6 commit 0befdac
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions src/wildboar/distance/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,25 @@ def cost(self, x):
)

def assign(self, x):
self.distance_ = pairwise_distance(
x,
self.centroids_,
dim="mean",
metric=self.metric,
metric_params=self.metric_params,
)
self.assigned_ = self.distance_.argmin(axis=1)
if x.ndim == 2:
self.assigned_, self.distance_ = argmin_distance(
x,
self.centroids_,
k=1,
metric=self.metric,
metric_params=self.metric_params,
return_distance=True,
)
self.assigned_ = np.ravel(self.assigned_)
else:
self.distance_ = pairwise_distance(
x,
self.centroids_,
dim="mean",
metric=self.metric,
metric_params=self.metric_params,
)
self.assigned_ = self.distance_.argmin(axis=1)

def update(self, x):
for c in range(self.centroids_.shape[0]):
Expand Down Expand Up @@ -472,7 +483,6 @@ def update(self):
for idx in range(self._cluster_idx.shape[0]):
cluster_idx = np.where(self.labels_ == idx)[0]
if cluster_idx.shape[0] == 0:
print(f"empty cluster {idx}")
continue

cost = self._dist[cluster_idx, cluster_idx.reshape(-1, 1)].sum(axis=1)
Expand Down

0 comments on commit 0befdac

Please sign in to comment.