-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcluster.py
216 lines (176 loc) · 7.22 KB
/
cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import jax
import jax.numpy as jnp
import flax.struct
import numpy as np
from dataclasses import dataclass
import logging
from functools import partial
from typing import Literal, Optional, TypeVar
from .treax import numpy as tjnp
from .treax import functional as F
from .utils import RNGKey, jax_jit, lax_scan
_log = logging.getLogger(__name__)
class KMeansState(flax.struct.PyTreeNode):
cluster_centers_: jax.Array
labels_: jax.Array
diff_: jax.Array
n_iter_: jax.Array
inertia_: jax.Array
random_key_: RNGKey
_KMeansStateT = TypeVar('_KMeansStateT', bound=KMeansState)
@dataclass
class KMeans:
k: int
init: str = 'k-means++'
n_init: int | Literal['auto'] = 10
max_iter: int = 300
tol: float = 0.0001
n_local_trials: Optional[int] = None
StateType: type[KMeansState] = KMeansState
@staticmethod
@partial(jax_jit, donate_argnums=0)
def _step(state: _KMeansStateT, x: jax.Array, mask: jax.Array) -> _KMeansStateT:
dist = F.distance(x, state.cluster_centers_)
labels_ = jnp.argmin(dist, axis=1)
def calc_cluster_centers(cluster_idx: jax.Array, random_key: RNGKey) -> jax.Array:
cluster_mask = jnp.logical_and(labels_ == cluster_idx, mask)
cluster_mask_sum = jnp.sum(cluster_mask)
cluster_center = jnp.where(
cluster_mask_sum > 0,
jnp.sum(x * cluster_mask[:, jnp.newaxis], axis=0) / cluster_mask_sum,
jax.random.choice(random_key, x),
)
return cluster_center
random_key_, subkey = jax.random.split(state.random_key_)
keys = jax.random.split(subkey, state.cluster_centers_.shape[0])
calc_cluster_centers_fn = jax.vmap(calc_cluster_centers)
cluster_centers_ = calc_cluster_centers_fn(
jnp.arange(state.cluster_centers_.shape[0]), keys
)
return state.replace(
cluster_centers_=cluster_centers_,
labels_=labels_,
diff_=jnp.sum(jnp.square(cluster_centers_ - state.cluster_centers_)),
n_iter_=state.n_iter_ + 1,
random_key_=random_key_,
)
def _kmeans_single(self, x: jax.Array, mask: jax.Array, state: _KMeansStateT) -> _KMeansStateT:
@jax_jit
def cond_fn(state: _KMeansStateT) -> jax.Array:
_log.debug('Compiling cond_fn...')
return jnp.logical_and(state.n_iter_ < self.max_iter, state.diff_ > self.tol)
@partial(jax_jit, donate_argnums=0)
def body_fn(state: _KMeansStateT) -> _KMeansStateT:
_log.debug('Compiling body_fn...')
return self._step(state, x, mask)
state = jax.lax.while_loop(
cond_fn,
body_fn,
state,
)
state = state.replace(
inertia_=jnp.dot(
jnp.sum(jnp.square(x - state.cluster_centers_[state.labels_]), axis=-1),
mask,
)
)
return state
def _kmeans_random_single(self, random_key: RNGKey, x: jax.Array, mask: jax.Array):
p = mask / jnp.sum(mask)
random_key, subkey = jax.random.split(random_key)
idx = jax.random.choice(subkey, x.shape[0], (self.k,), replace=False, p=p)
cluster_centers_ = x[idx]
return self._kmeans_single(x, mask, self.StateType(
cluster_centers_=cluster_centers_,
labels_=jnp.zeros(x.shape[0], dtype=jnp.int32),
diff_=jnp.asarray(jnp.inf, dtype=jnp.float32),
n_iter_=jnp.asarray(0, dtype=jnp.int32),
inertia_=jnp.asarray(0, dtype=jnp.float32),
random_key_=random_key,
))
@staticmethod
@partial(jax_jit, static_argnames='n_local_trials', donate_argnums=(0, 1, 2))
def _kmeans_pp_step(
closest_dist_sq: jax.Array,
current_pot: jax.Array,
x: jax.Array,
n_local_trials: int,
random_key: RNGKey,
) -> tuple[tuple[jax.Array, jax.Array], jax.Array]:
_log.debug('Compiling _kmeans_pp_step...')
candidate_ids = jax.random.choice(
random_key, x.shape[0], (n_local_trials,), p=closest_dist_sq / current_pot
)
distance_to_candidates: jax.Array = F.distance(x[candidate_ids], x)
distance_to_candidates = jnp.minimum(closest_dist_sq, distance_to_candidates)
candidates_pot = distance_to_candidates.sum(axis=1)
best_candidate = candidates_pot.argmin()
current_pot = candidates_pot[best_candidate]
closest_dist_sq = distance_to_candidates[best_candidate]
best_candidate = candidate_ids[best_candidate]
return (closest_dist_sq, current_pot), x[best_candidate]
def _kmeans_pp_single(
self,
random_key: RNGKey,
x: jax.Array,
mask: jax.Array,
) -> KMeansState:
random_key, first_key, others_key = jax.random.split(random_key, 3)
keys = jax.random.split(others_key, self.k - 1)
p = mask / jnp.sum(mask)
idx = jax.random.choice(first_key, x.shape[0], p=p)
closest_dist_sq: jax.Array = F.distance(x, x[idx]) * mask
@partial(jax_jit, donate_argnums=0)
def body_fn(
carry: tuple[jax.Array, jax.Array],
random_key: RNGKey,
) -> tuple[tuple[jax.Array, jax.Array], jax.Array]:
_log.debug('Compiling body_fn...')
assert self.n_local_trials is not None
return self._kmeans_pp_step(*carry, x, self.n_local_trials, random_key)
_, cluster_centers_ = lax_scan(
body_fn,
(closest_dist_sq, closest_dist_sq.sum()),
keys,
)
cluster_centers_ = jnp.concatenate((x[0][jnp.newaxis], cluster_centers_))
return self._kmeans_single(x, mask, self.StateType(
cluster_centers_=cluster_centers_,
labels_=jnp.zeros(x.shape[0], dtype=jnp.int32),
diff_=jnp.asarray(jnp.inf, dtype=jnp.float32),
n_iter_=jnp.asarray(0, dtype=jnp.int32),
inertia_=jnp.asarray(0, dtype=jnp.float32),
random_key_=random_key,
))
def fit(
self,
random_key: RNGKey,
x: jax.Array,
mask: Optional[jax.Array] = None,
) -> KMeansState:
assert x.shape[0] >= self.k
n_init = self.n_init
match self.init:
case 'random':
if n_init == 'auto':
n_init = 10
kmeans_fn = self._kmeans_random_single
case 'k-means++':
if n_init == 'auto':
n_init = 1
if self.n_local_trials is None:
self.n_local_trials = 2 + int(np.log(self.k))
kmeans_fn = self._kmeans_pp_single
case _:
raise NotImplementedError(f"No such method of initialization: '{self.init}'")
assert isinstance(n_init, int)
if mask is None:
mask = jnp.ones(x.shape[0])
kmeans_fn = partial(kmeans_fn, x=x, mask=mask)
keys = jax.random.split(random_key, n_init)
states = jax.vmap(kmeans_fn)(keys)
_log.debug('shape(states) = %s', tjnp.shape(states))
idx = jnp.argmin(states.inertia_)
state = tjnp.getitem(states, idx)
_log.debug('shape(state) = %s', tjnp.shape(state))
return state