Skip to content

Commit 830c1d6

Browse files
committed
add test & pass test
1 parent 7bcec65 commit 830c1d6

File tree

2 files changed

+32
-30
lines changed

2 files changed

+32
-30
lines changed

src/ott/geometry/geodesic.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def compute_sparse_laplacian(G: Array_g, normalize:bool = False) -> Array_g:
4444
data = jnp.where(data > 0., 1. / jnp.sqrt(data), 0.)
4545
degree = jesp.BCOO((data, jnp.c_[ixs, ixs]), shape=(n, n))
4646
if normalize:
47-
inv_sqrt_deg = degree.data # TODO: is MM still faster here ?
48-
laplacian = inv_sqrt_deg[:, None] * G * inv_sqrt_deg[None, :]
47+
id = jesp.BCOO((jnp.ones(n), jnp.c_[ixs, ixs]), shape=(n, n))
48+
laplacian = id - degree @ G @ degree
4949
else:
5050
laplacian = degree - G
5151
return laplacian
@@ -184,10 +184,14 @@ def apply_kernel(
184184
def kernel_matrix(self) -> jnp.ndarray: # noqa: D102
185185
n, _ = self.shape
186186
kernel = self.apply_kernel(jnp.eye(n))
187-
return jax.lax.cond(
188-
jnp.allclose(kernel, kernel.T, atol=1e-8, rtol=1e-8), lambda x: x,
189-
lambda x: (x + x.T) / 2.0, kernel
190-
)
187+
if isinstance(kernel, jesp.BCOO):
188+
# we symmetrize sparse kernel by default
189+
return (kernel + kernel.T) * 0.5
190+
else:
191+
return jax.lax.cond(
192+
jnp.allclose(kernel, kernel.T, atol=1e-8, rtol=1e-8), lambda x: x,
193+
lambda x: (x + x.T) / 2.0, kernel
194+
)
191195

192196
@property
193197
def cost_matrix(self) -> jnp.ndarray: # noqa: D102
@@ -266,31 +270,33 @@ def compute_largest_eigenvalue(
266270
)
267271
return eigvals[0]
268272

269-
270273
def expm_multiply(
271274
L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float
272275
) -> jnp.ndarray:
273-
274276
# move to sparse matrix
275-
L = jax.experimental.sparse.BCOO.fromdense(L, nse=10)
276-
X = jax.experimental.sparse.BCOO.fromdense(X, nse=1)
277+
is_sparse = isinstance(L, jesp.BCOO)
278+
if is_sparse and not isinstance(X, jesp.BCOO):
279+
X = jax.experimental.sparse.BCOO.fromdense(X, nse=X.shape[0])
280+
277281
def body(carry, c):
278282
T0, T1, Y = carry
279-
T1, Y = T1.sort_indices(), Y.sort_indices()
280283
T2 = (2.0 / eigval) * L @ T1 - 2.0 * T1 - T0
281284
Y = Y + c * T2
282285
return (T1, T2, Y), None
283286

284287
T0 = X
285-
T0.unique_indices = False # FIXME: rm change attribute
286288
Y = 0.5 * coeff[0] * T0
287289
T1 = (1.0 / eigval) * L @ X - T0
288290
Y = Y + coeff[1] * T1
289291

292+
290293
initial_state = (T0, T1, Y)
291-
(_, _, Y), _ = jax.lax.scan(body, initial_state, coeff[2:]) # FIXME scan does not work
292-
# because metadata unique_indices is changing
293-
# for 1st position of carry.
294+
if not is_sparse:
295+
(_, _, Y), _ = jax.lax.scan(body, initial_state, coeff[2:])
296+
else:
297+
# NOTE: the scan not working for this type of scan
298+
for c in coeff[2:]:
299+
(T0, T1, Y), _ = body((T0, T1, Y), c)
294300
return Y
295301

296302

tests/geometry/geodesic_test.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from ott.problems.linear import linear_problem
2626
from ott.solvers.linear import sinkhorn
2727

28-
2928
def random_graph(
3029
n: int,
3130
p: float = 0.3,
@@ -267,24 +266,21 @@ def test_heat_approx(self, normalize: bool, t: float, order: int):
267266
approx = geom.apply_kernel(jnp.eye(G.shape[0]))
268267
np.testing.assert_allclose(exact, approx, rtol=1e-1, atol=1e-1)
269268

270-
def test_sparse_geodesic(self, rng: jax.Array):
271-
G = random_graph(20, p=0.5)
272-
# G = sparse.BCOO.fromdense(G)
273-
geom = geodesic.Geodesic.from_graph(G, t=1.0)
269+
@pytest.mark.parametrize("normalize", [True, False])
270+
def test_sparse_geodesic(self, normalize):
271+
n = 20
272+
G = random_graph(n, p=0.5)
273+
G_sparse = sparse.BCOO.fromdense(G)
274+
geom = geodesic.Geodesic.from_graph(G_sparse, t=5.0, order=10, normalize=normalize)
274275
kernel_matrix = geom.kernel_matrix
275276

276-
v = jax.random.normal(rng, (n,))
277-
v = v / jnp.linalg.norm(v)
278-
279-
v = jax.device_put(v)
280-
v = sparse.COO.from_numpy(v)
277+
gh_heat_kernel = exact_heat_kernel(G, normalize=normalize, t=5.0)
281278

282-
w = geom.apply_kernel(v, axis=0)
283-
w = w.todense()
284-
w = jnp.asarray(w)
279+
assert isinstance(kernel_matrix, sparse.BCOO)
285280

286-
np.testing.assert_allclose(w, geom.kernel_matrix @ v.todense())
281+
kernel_matrix = kernel_matrix.todense()
282+
np.testing.assert_allclose(kernel_matrix, gh_heat_kernel, rtol=1e-1, atol=1e-1)
287283

288284

289285
if __name__ == "__main__":
290-
pytest.main([__file__])
286+
TestGeodesic().test_sparse_geodesic(jax.random.PRNGKey(0))

0 commit comments

Comments
 (0)