Skip to content

Commit 51a658a

Browse files
authored
Add sparse Chebyshev approximation (#502)
* wip bcoo geodesic * add sparse laplacian * add test & pass test * typo & fix type * appease code formatter * fmt * fix laplacian type * norm lap with elementwise multiplication * unify tests * fix sparse scan & sinkhorn test * fmt * rm mv to sparse since `@jesp.sparsify` * rm sparsify wrapper and fix type * fix type & mv fn & test memory
1 parent da704fc commit 51a658a

File tree

3 files changed

+64
-18
lines changed

3 files changed

+64
-18
lines changed

src/ott/geometry/geodesic.py

+35-10
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Dict, Optional, Sequence, Tuple
14+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
1515

1616
import jax
1717
import jax.experimental.sparse as jesp
@@ -22,10 +22,11 @@
2222
from ott import utils
2323
from ott.geometry import geometry
2424
from ott.math import utils as mu
25-
from ott.types import Array_g
2625

2726
__all__ = ["Geodesic"]
2827

28+
Array_g = Union[jnp.ndarray, jesp.BCOO]
29+
2930

3031
@jax.tree_util.register_pytree_node_class
3132
class Geodesic(geometry.Geometry):
@@ -106,13 +107,10 @@ def from_graph(
106107
if t is None:
107108
t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2
108109

109-
degree = jnp.sum(G, axis=1)
110-
laplacian = jnp.diag(degree) - G
111-
if normalize:
112-
inv_sqrt_deg = jnp.diag(
113-
jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
114-
)
115-
laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg
110+
if isinstance(G, jesp.BCOO):
111+
laplacian = compute_sparse_laplacian(G, normalize)
112+
else:
113+
laplacian = compute_dense_laplacian(G, normalize)
116114

117115
if eigval is None:
118116
eigval = compute_largest_eigenvalue(laplacian, rng)
@@ -220,6 +218,33 @@ def tree_unflatten( # noqa: D102
220218
return cls(*children, **aux_data)
221219

222220

221+
def normalize_laplacian(laplacian: Array_g, degree: jnp.ndarray) -> Array_g:
222+
inv_sqrt_deg = jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
223+
return inv_sqrt_deg[:, None] * laplacian * inv_sqrt_deg[None, :]
224+
225+
226+
def compute_dense_laplacian(
227+
G: jnp.ndarray, normalize: bool = False
228+
) -> jnp.ndarray:
229+
degree = jnp.sum(G, axis=1)
230+
laplacian = jnp.diag(degree) - G
231+
if normalize:
232+
laplacian = normalize_laplacian(laplacian, degree)
233+
return laplacian
234+
235+
236+
def compute_sparse_laplacian(
237+
G: jesp.BCOO, normalize: bool = False
238+
) -> jesp.BCOO:
239+
n, _ = G.shape
240+
data_degree, ixs = G.sum(1).todense(), jnp.arange(n)
241+
degree = jesp.BCOO((data_degree, jnp.c_[ixs, ixs]), shape=(n, n))
242+
laplacian = degree - G
243+
if normalize:
244+
laplacian = normalize_laplacian(laplacian, data_degree)
245+
return laplacian
246+
247+
223248
def compute_largest_eigenvalue(
224249
laplacian_matrix: jnp.ndarray,
225250
rng: jax.Array,
@@ -242,7 +267,7 @@ def compute_largest_eigenvalue(
242267

243268

244269
def expm_multiply(
245-
L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float
270+
L: Array_g, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float
246271
) -> jnp.ndarray:
247272

248273
def body(carry, c):

src/ott/types.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Protocol, Union
14+
from typing import Protocol
1515

16-
import jax.experimental.sparse as jesp
1716
import jax.numpy as jnp
1817

19-
__all__ = ["Transport", "Array_g"]
18+
__all__ = ["Transport"]
2019

2120
# TODO(michalk8): introduce additional types here
2221

23-
# Either a dense or sparse array.
24-
Array_g = Union[jnp.ndarray, jesp.BCOO]
25-
2622

2723
class Transport(Protocol):
2824
"""Interface for the solution of a transport problem.

tests/geometry/geodesic_test.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Optional, Union
1515

1616
import jax
17+
import jax.experimental.sparse as jesp
1718
import jax.numpy as jnp
1819
import networkx as nx
1920
import numpy as np
@@ -29,6 +30,7 @@ def random_graph(
2930
n: int,
3031
p: float = 0.3,
3132
seed: Optional[int] = 0,
33+
is_sparse: bool = False,
3234
*,
3335
return_laplacian: bool = False,
3436
directed: bool = False,
@@ -45,6 +47,8 @@ def random_graph(
4547
G
4648
) if return_laplacian else nx.linalg.adjacency_matrix(G)
4749

50+
if is_sparse:
51+
return jesp.BCOO.from_scipy_sparse(G)
4852
return jnp.asarray(G.toarray())
4953

5054

@@ -196,7 +200,8 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray:
196200
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
197201

198202
@pytest.mark.fast.with_args(jit=[False, True], only_fast=0)
199-
def test_geo_sinkhorn(self, rng: jax.Array, jit: bool):
203+
@pytest.mark.parametrize("is_sparse", [True, False])
204+
def test_geo_sinkhorn(self, rng: jax.Array, jit: bool, is_sparse: bool):
200205

201206
def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
202207
solver = sinkhorn.Sinkhorn(lse_mode=False)
@@ -208,6 +213,8 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
208213
x = jax.random.normal(rng, (n,))
209214

210215
gt_geom = gt_geometry(G, epsilon=eps)
216+
if is_sparse:
217+
G = jesp.BCOO.fromdense(G)
211218
graph_geom = geodesic.Geodesic.from_graph(G, t=eps / 4.0)
212219

213220
fn = jax.jit(callback) if jit else callback
@@ -257,11 +264,29 @@ def callback(geom: geodesic.Geodesic) -> float:
257264
@pytest.mark.parametrize("normalize", [False, True])
258265
@pytest.mark.parametrize("t", [5, 10, 50])
259266
@pytest.mark.parametrize("order", [20, 30, 40])
260-
def test_heat_approx(self, normalize: bool, t: float, order: int):
267+
@pytest.mark.parametrize("is_sparse", [True, False])
268+
def test_heat_approx(
269+
self, normalize: bool, t: float, order: int, is_sparse: bool
270+
):
261271
G = random_graph(20, p=0.5)
262272
exact = exact_heat_kernel(G, normalize=normalize, t=t)
273+
if is_sparse:
274+
G = jesp.BCOO.fromdense(G)
263275
geom = geodesic.Geodesic.from_graph(
264276
G, t=t, order=order, normalize=normalize
265277
)
266278
approx = geom.apply_kernel(jnp.eye(G.shape[0]))
279+
267280
np.testing.assert_allclose(exact, approx, rtol=1e-1, atol=1e-1)
281+
282+
@pytest.mark.limit_memory("150 MB")
283+
def test_sparse_geo_memory(self, rng: jax.Array):
284+
n = 10_000
285+
G = random_graph(n, p=0.001, is_sparse=True)
286+
x = jax.random.normal(rng, (n,))
287+
288+
graph_geom = geodesic.Geodesic.from_graph(G, t=1.0, order=10)
289+
290+
out = jax.jit(graph_geom.apply_kernel)(x)
291+
292+
np.testing.assert_array_equal(jnp.isfinite(out), True)

0 commit comments

Comments
 (0)