Skip to content

Commit c8edbd3

Browse files
committed
Fix scalar subsetting
1 parent b0ed993 commit c8edbd3

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/ott/geometry/geometry.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -673,8 +673,8 @@ def to_LRCGeometry(
673673
i_star = jax.random.randint(rng1, shape=(), minval=0, maxval=n)
674674
j_star = jax.random.randint(rng2, shape=(), minval=0, maxval=m)
675675

676-
ci_star = self.subset(i_star, None).cost_matrix.ravel() ** 2 # (m,)
677-
cj_star = self.subset(None, j_star).cost_matrix.ravel() ** 2 # (n,)
676+
ci_star = self.subset([i_star], None).cost_matrix.ravel() ** 2 # (m,)
677+
cj_star = self.subset(None, [j_star]).cost_matrix.ravel() ** 2 # (n,)
678678

679679
p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
680680
p_row /= jnp.sum(p_row)

src/ott/tools/k_means.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def init_fn(geom: pointcloud.PointCloud, rng: jax.Array) -> KPPState:
127127
rng, next_rng = jax.random.split(rng, 2)
128128
ix = jax.random.choice(rng, jnp.arange(geom.shape[0]), shape=())
129129
centroids = jnp.full((k, geom.cost_rank), jnp.inf).at[0].set(geom.x[ix])
130-
dists = geom.subset(ix, None).cost_matrix[0]
130+
dists = geom.subset([ix], None).cost_matrix[0]
131131
return KPPState(rng=next_rng, centroids=centroids, centroid_dists=dists)
132132

133133
def body_fn(

0 commit comments

Comments
 (0)