@@ -626,7 +626,7 @@ def to_LRCGeometry(
626
626
rank : int = 0 ,
627
627
tol : float = 1e-2 ,
628
628
rng : Optional [jax .Array ] = None ,
629
- scale : float = 1.
629
+ scale : float = 1.0
630
630
) -> "low_rank.LRCGeometry" :
631
631
r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`.
632
632
@@ -673,8 +673,8 @@ def to_LRCGeometry(
673
673
i_star = jax .random .randint (rng1 , shape = (), minval = 0 , maxval = n )
674
674
j_star = jax .random .randint (rng2 , shape = (), minval = 0 , maxval = m )
675
675
676
- ci_star = self .subset (i_star , None ).cost_matrix .ravel () ** 2 # (m,)
677
- cj_star = self .subset (None , j_star ).cost_matrix .ravel () ** 2 # (n,)
676
+ ci_star = self .subset ([ i_star ] , None ).cost_matrix .ravel () ** 2 # (m,)
677
+ cj_star = self .subset (None , [ j_star ] ).cost_matrix .ravel () ** 2 # (n,)
678
678
679
679
p_row = cj_star + ci_star [j_star ] + jnp .mean (ci_star ) # (n,)
680
680
p_row /= jnp .sum (p_row )
@@ -697,7 +697,7 @@ def to_LRCGeometry(
697
697
_ , d , v = jnp .linalg .svd (U .T @ U ) # (k,), (k, k)
698
698
v = v .T / jnp .sqrt (d )[None , :]
699
699
700
- inv_scale = (1. / jnp .sqrt (n_subset ))
700
+ inv_scale = (1.0 / jnp .sqrt (n_subset ))
701
701
col_ixs = jax .random .choice (rng5 , m , shape = (n_subset ,)) # (n_subset,)
702
702
703
703
# (n, n_subset)
@@ -740,9 +740,9 @@ def subset_fn(
740
740
if arr is None :
741
741
return None
742
742
if src_ixs is not None :
743
- arr = arr [jnp . atleast_1d ( src_ixs ) ]
743
+ arr = arr [src_ixs , ... ]
744
744
if tgt_ixs is not None :
745
- arr = arr [:, jnp . atleast_1d ( tgt_ixs ) ]
745
+ arr = arr [:, tgt_ixs ]
746
746
return arr # noqa: RET504
747
747
748
748
return self ._mask_subset_helper (
@@ -757,7 +757,7 @@ def mask(
757
757
self ,
758
758
src_mask : Optional [jnp .ndarray ],
759
759
tgt_mask : Optional [jnp .ndarray ],
760
- mask_value : float = 0. ,
760
+ mask_value : float = 0.0 ,
761
761
) -> "Geometry" :
762
762
"""Mask rows or columns of a geometry.
763
763
@@ -855,7 +855,7 @@ def dtype(self) -> jnp.dtype:
855
855
self ._kernel_matrix if self ._cost_matrix is None else self ._cost_matrix
856
856
).dtype
857
857
858
- def _masked_geom (self , mask_value : float = 0. ) -> "Geometry" :
858
+ def _masked_geom (self , mask_value : float = 0.0 ) -> "Geometry" :
859
859
"""Mask geometry based on :attr:`src_mask` and :attr:`tgt_mask`."""
860
860
src_mask , tgt_mask = self .src_mask , self .tgt_mask
861
861
if src_mask is None and tgt_mask is None :
@@ -877,12 +877,11 @@ def _m_normed_ones(self) -> jnp.ndarray:
877
877
return arr / jnp .sum (arr )
878
878
879
879
@staticmethod
880
- def _normalize_mask (mask : Optional [Union [ int , jnp .ndarray ] ],
880
+ def _normalize_mask (mask : Optional [jnp .ndarray ],
881
881
size : int ) -> Optional [jnp .ndarray ]:
882
882
"""Convert array of indices to a boolean mask."""
883
883
if mask is None :
884
884
return None
885
- mask = jnp .atleast_1d (mask )
886
885
if not jnp .issubdtype (mask , (bool , jnp .bool_ )):
887
886
mask = jnp .isin (jnp .arange (size ), mask )
888
887
assert mask .shape == (size ,)
0 commit comments