@@ -44,8 +44,8 @@ def compute_sparse_laplacian(G: Array_g, normalize:bool = False) -> Array_g:
44
44
data = jnp .where (data > 0. , 1. / jnp .sqrt (data ), 0. )
45
45
degree = jesp .BCOO ((data , jnp .c_ [ixs , ixs ]), shape = (n , n ))
46
46
if normalize :
47
- inv_sqrt_deg = degree . data # TODO: is MM still faster here ?
48
- laplacian = inv_sqrt_deg [:, None ] * G * inv_sqrt_deg [ None , :]
47
+ id = jesp . BCOO (( jnp . ones ( n ), jnp . c_ [ ixs , ixs ]), shape = ( n , n ))
48
+ laplacian = id - degree @ G @ degree
49
49
else :
50
50
laplacian = degree - G
51
51
return laplacian
@@ -184,10 +184,14 @@ def apply_kernel(
184
184
def kernel_matrix (self ) -> jnp .ndarray : # noqa: D102
185
185
n , _ = self .shape
186
186
kernel = self .apply_kernel (jnp .eye (n ))
187
- return jax .lax .cond (
188
- jnp .allclose (kernel , kernel .T , atol = 1e-8 , rtol = 1e-8 ), lambda x : x ,
189
- lambda x : (x + x .T ) / 2.0 , kernel
190
- )
187
+ if isinstance (kernel , jesp .BCOO ):
188
+ # we symmetrize sparse kernel by default
189
+ return (kernel + kernel .T ) * 0.5
190
+ else :
191
+ return jax .lax .cond (
192
+ jnp .allclose (kernel , kernel .T , atol = 1e-8 , rtol = 1e-8 ), lambda x : x ,
193
+ lambda x : (x + x .T ) / 2.0 , kernel
194
+ )
191
195
192
196
@property
193
197
def cost_matrix (self ) -> jnp .ndarray : # noqa: D102
@@ -266,31 +270,33 @@ def compute_largest_eigenvalue(
266
270
)
267
271
return eigvals [0 ]
268
272
269
-
270
273
def expm_multiply (
271
274
L : jnp .ndarray , X : jnp .ndarray , coeff : jnp .ndarray , eigval : float
272
275
) -> jnp .ndarray :
273
-
274
276
# move to sparse matrix
275
- L = jax .experimental .sparse .BCOO .fromdense (L , nse = 10 )
276
- X = jax .experimental .sparse .BCOO .fromdense (X , nse = 1 )
277
+ is_sparse = isinstance (L , jesp .BCOO )
278
+ if is_sparse and not isinstance (X , jesp .BCOO ):
279
+ X = jax .experimental .sparse .BCOO .fromdense (X , nse = X .shape [0 ])
280
+
277
281
def body (carry , c ):
278
282
T0 , T1 , Y = carry
279
- T1 , Y = T1 .sort_indices (), Y .sort_indices ()
280
283
T2 = (2.0 / eigval ) * L @ T1 - 2.0 * T1 - T0
281
284
Y = Y + c * T2
282
285
return (T1 , T2 , Y ), None
283
286
284
287
T0 = X
285
- T0 .unique_indices = False # FIXME: rm change attribute
286
288
Y = 0.5 * coeff [0 ] * T0
287
289
T1 = (1.0 / eigval ) * L @ X - T0
288
290
Y = Y + coeff [1 ] * T1
289
291
292
+
290
293
initial_state = (T0 , T1 , Y )
291
- (_ , _ , Y ), _ = jax .lax .scan (body , initial_state , coeff [2 :]) # FIXME scan does not work
292
- # because metadata unique_indices is changing
293
- # for 1st position of carry.
294
+ if not is_sparse :
295
+ (_ , _ , Y ), _ = jax .lax .scan (body , initial_state , coeff [2 :])
296
+ else :
297
+ # NOTE: the scan not working for this type of scan
298
+ for c in coeff [2 :]:
299
+ (T0 , T1 , Y ), _ = body ((T0 , T1 , Y ), c )
294
300
return Y
295
301
296
302
0 commit comments