14
14
from typing import Optional , Union
15
15
16
16
import jax
17
+ import jax .experimental .sparse as jesp
17
18
import jax .numpy as jnp
18
19
import networkx as nx
19
20
import numpy as np
@@ -29,6 +30,7 @@ def random_graph(
29
30
n : int ,
30
31
p : float = 0.3 ,
31
32
seed : Optional [int ] = 0 ,
33
+ is_sparse : bool = False ,
32
34
* ,
33
35
return_laplacian : bool = False ,
34
36
directed : bool = False ,
@@ -45,6 +47,8 @@ def random_graph(
45
47
G
46
48
) if return_laplacian else nx .linalg .adjacency_matrix (G )
47
49
50
+ if is_sparse :
51
+ return jesp .BCOO .from_scipy_sparse (G )
48
52
return jnp .asarray (G .toarray ())
49
53
50
54
@@ -196,7 +200,8 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray:
196
200
np .testing .assert_allclose (actual , expected , rtol = 1e-5 , atol = 1e-5 )
197
201
198
202
@pytest .mark .fast .with_args (jit = [False , True ], only_fast = 0 )
199
- def test_geo_sinkhorn (self , rng : jax .Array , jit : bool ):
203
+ @pytest .mark .parametrize ("is_sparse" , [True , False ])
204
+ def test_geo_sinkhorn (self , rng : jax .Array , jit : bool , is_sparse : bool ):
200
205
201
206
def callback (geom : geometry .Geometry ) -> sinkhorn .SinkhornOutput :
202
207
solver = sinkhorn .Sinkhorn (lse_mode = False )
@@ -208,6 +213,8 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
208
213
x = jax .random .normal (rng , (n ,))
209
214
210
215
gt_geom = gt_geometry (G , epsilon = eps )
216
+ if is_sparse :
217
+ G = jesp .BCOO .fromdense (G )
211
218
graph_geom = geodesic .Geodesic .from_graph (G , t = eps / 4.0 )
212
219
213
220
fn = jax .jit (callback ) if jit else callback
@@ -257,11 +264,29 @@ def callback(geom: geodesic.Geodesic) -> float:
257
264
@pytest .mark .parametrize ("normalize" , [False , True ])
258
265
@pytest .mark .parametrize ("t" , [5 , 10 , 50 ])
259
266
@pytest .mark .parametrize ("order" , [20 , 30 , 40 ])
260
- def test_heat_approx (self , normalize : bool , t : float , order : int ):
267
+ @pytest .mark .parametrize ("is_sparse" , [True , False ])
268
+ def test_heat_approx (
269
+ self , normalize : bool , t : float , order : int , is_sparse : bool
270
+ ):
261
271
G = random_graph (20 , p = 0.5 )
262
272
exact = exact_heat_kernel (G , normalize = normalize , t = t )
273
+ if is_sparse :
274
+ G = jesp .BCOO .fromdense (G )
263
275
geom = geodesic .Geodesic .from_graph (
264
276
G , t = t , order = order , normalize = normalize
265
277
)
266
278
approx = geom .apply_kernel (jnp .eye (G .shape [0 ]))
279
+
267
280
np .testing .assert_allclose (exact , approx , rtol = 1e-1 , atol = 1e-1 )
281
+
282
+ @pytest .mark .limit_memory ("150 MB" )
283
+ def test_sparse_geo_memory (self , rng : jax .Array ):
284
+ n = 10_000
285
+ G = random_graph (n , p = 0.001 , is_sparse = True )
286
+ x = jax .random .normal (rng , (n ,))
287
+
288
+ graph_geom = geodesic .Geodesic .from_graph (G , t = 1.0 , order = 10 )
289
+
290
+ out = jax .jit (graph_geom .apply_kernel )(x )
291
+
292
+ np .testing .assert_array_equal (jnp .isfinite (out ), True )
0 commit comments