Skip to content

Commit 00d5483

Browse files
authored
Potentials as tuple (#488)
* Add support for `.potentials` for `SinkhornOutput`
* Remove explicit `dtype` casting
* Add leading or trailing 0s
* Fix dtype in test
* Fix typo
* Remove `jnp.atleast_1d` in subsetting
* Fix scalar subsetting
* Fix `flax=0.8.1` docs typos in docs linter
* Fix more pet-peeves
1 parent 1e4b0c0 commit 00d5483

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

68 files changed

+446
-464
lines changed

docs/_templates/autosummary/class.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
.. autosummary::
1111
:toctree: .
1212
{% for item in methods %}
13-
{%- if item not in ['__init__', 'tree_flatten', 'tree_unflatten', 'bind', 'tabulate'] %}
13+
{%- if item not in ['__init__', 'tree_flatten', 'tree_unflatten', 'bind', 'tabulate', 'module_paths'] %}
1414
~{{ name }}.{{ item }}
1515
{%- endif %}
1616
{%- endfor %}

docs/conf.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@
104104
spelling_warning = True
105105
spelling_word_list_filename = ["spelling/technical.txt", "spelling/misc.txt"]
106106
spelling_add_pypi_package_names = True
107-
# flax misspelled words; `flax.linen.Module.bind` is ignored in `class.rst`
108-
# because of indentation error that cannot be suppressed
107+
# flax misspelled words; `flax.linen.Module.{bind,module_paths}` is ignored in
108+
# the `class.rst` because of indentation error that cannot be suppressed
109109
spelling_exclude_patterns = [
110110
"bibliography.rst",
111111
"**setup.rst",

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ markers = [
120120
"fast: Mark tests as fast.",
121121
]
122122
filterwarnings = [
123+
"ignore:\\n*.*scipy.sparse array",
123124
"ignore:jax.random.KeyArray is deprecated:DeprecationWarning",
125+
"ignore:.*jax.config:DeprecationWarning",
126+
"ignore:jax.core.Shape is deprecated:DeprecationWarning:chex",
124127
]
125128

126129
[tool.coverage.run]

src/ott/geometry/costs.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:
284284

285285
def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
286286
"""Compute minus twice the dot-product between vectors."""
287-
return -2. * jnp.vdot(x, y)
287+
return -2.0 * jnp.vdot(x, y)
288288

289289
def h(self, z: jnp.ndarray) -> float: # noqa: D102
290290
return jnp.sum(z ** 2)
@@ -806,7 +806,6 @@ def covariance_fixpoint_iter(
806806
min_iterations = kwargs.pop("min_iterations", 1)
807807
max_iterations = kwargs.pop("max_iterations", 100)
808808
inner_iterations = kwargs.pop("inner_iterations", 5)
809-
dtype = covs.dtype
810809

811810
@functools.partial(jax.vmap, in_axes=[None, 0, 0])
812811
def scale_covariances(
@@ -838,10 +837,7 @@ def body_fn(
838837

839838
def init_state() -> Tuple[jnp.ndarray, float]:
840839
cov_init = jnp.eye(self._dimension)
841-
diffs = -jnp.ones(
842-
(np.ceil(max_iterations / inner_iterations).astype(int),),
843-
dtype=dtype
844-
)
840+
diffs = -jnp.ones(math.ceil(max_iterations / inner_iterations))
845841
return cov_init, diffs
846842

847843
cov, diffs = fixed_point_loop.fixpoint_iter(
@@ -990,7 +986,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
990986
diff_means = mean_x - mean_y
991987

992988
# Identity matrix of suitable size
993-
iden = jnp.eye(self._dimension, dtype=x.dtype)
989+
iden = jnp.eye(self._dimension)
994990

995991
# Creates matrices needed in the computation
996992
tilde_a = 0.5 * gam * (iden - lam * jnp.linalg.inv(cov_x + lam * iden))

src/ott/geometry/geodesic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
t: float = 1e-3,
5757
**kwargs: Any
5858
):
59-
super().__init__(epsilon=1., **kwargs)
59+
super().__init__(epsilon=1.0, **kwargs)
6060
self.scaled_laplacian = scaled_laplacian
6161
self.eigval = eigval
6262
self.chebyshev_coeffs = chebyshev_coeffs
@@ -104,7 +104,7 @@ def from_graph(
104104
if directed:
105105
G = G + G.T
106106
if t is None:
107-
t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2.0
107+
t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2
108108

109109
degree = jnp.sum(G, axis=1)
110110
laplacian = jnp.diag(degree) - G

src/ott/geometry/geometry.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def to_LRCGeometry(
626626
rank: int = 0,
627627
tol: float = 1e-2,
628628
rng: Optional[jax.Array] = None,
629-
scale: float = 1.
629+
scale: float = 1.0
630630
) -> "low_rank.LRCGeometry":
631631
r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`.
632632
@@ -673,8 +673,8 @@ def to_LRCGeometry(
673673
i_star = jax.random.randint(rng1, shape=(), minval=0, maxval=n)
674674
j_star = jax.random.randint(rng2, shape=(), minval=0, maxval=m)
675675

676-
ci_star = self.subset(i_star, None).cost_matrix.ravel() ** 2 # (m,)
677-
cj_star = self.subset(None, j_star).cost_matrix.ravel() ** 2 # (n,)
676+
ci_star = self.subset([i_star], None).cost_matrix.ravel() ** 2 # (m,)
677+
cj_star = self.subset(None, [j_star]).cost_matrix.ravel() ** 2 # (n,)
678678

679679
p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
680680
p_row /= jnp.sum(p_row)
@@ -697,7 +697,7 @@ def to_LRCGeometry(
697697
_, d, v = jnp.linalg.svd(U.T @ U) # (k,), (k, k)
698698
v = v.T / jnp.sqrt(d)[None, :]
699699

700-
inv_scale = (1. / jnp.sqrt(n_subset))
700+
inv_scale = (1.0 / jnp.sqrt(n_subset))
701701
col_ixs = jax.random.choice(rng5, m, shape=(n_subset,)) # (n_subset,)
702702

703703
# (n, n_subset)
@@ -740,9 +740,9 @@ def subset_fn(
740740
if arr is None:
741741
return None
742742
if src_ixs is not None:
743-
arr = arr[jnp.atleast_1d(src_ixs)]
743+
arr = arr[src_ixs, ...]
744744
if tgt_ixs is not None:
745-
arr = arr[:, jnp.atleast_1d(tgt_ixs)]
745+
arr = arr[:, tgt_ixs]
746746
return arr # noqa: RET504
747747

748748
return self._mask_subset_helper(
@@ -757,7 +757,7 @@ def mask(
757757
self,
758758
src_mask: Optional[jnp.ndarray],
759759
tgt_mask: Optional[jnp.ndarray],
760-
mask_value: float = 0.,
760+
mask_value: float = 0.0,
761761
) -> "Geometry":
762762
"""Mask rows or columns of a geometry.
763763
@@ -855,7 +855,7 @@ def dtype(self) -> jnp.dtype:
855855
self._kernel_matrix if self._cost_matrix is None else self._cost_matrix
856856
).dtype
857857

858-
def _masked_geom(self, mask_value: float = 0.) -> "Geometry":
858+
def _masked_geom(self, mask_value: float = 0.0) -> "Geometry":
859859
"""Mask geometry based on :attr:`src_mask` and :attr:`tgt_mask`."""
860860
src_mask, tgt_mask = self.src_mask, self.tgt_mask
861861
if src_mask is None and tgt_mask is None:
@@ -877,12 +877,11 @@ def _m_normed_ones(self) -> jnp.ndarray:
877877
return arr / jnp.sum(arr)
878878

879879
@staticmethod
880-
def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]],
880+
def _normalize_mask(mask: Optional[jnp.ndarray],
881881
size: int) -> Optional[jnp.ndarray]:
882882
"""Convert array of indices to a boolean mask."""
883883
if mask is None:
884884
return None
885-
mask = jnp.atleast_1d(mask)
886885
if not jnp.issubdtype(mask, (bool, jnp.bool_)):
887886
mask = jnp.isin(jnp.arange(size), mask)
888887
assert mask.shape == (size,)

src/ott/geometry/graph.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
tol: float = -1.0,
5757
**kwargs: Any
5858
):
59-
super().__init__(epsilon=1., **kwargs)
59+
super().__init__(epsilon=1.0, **kwargs)
6060
self.laplacian = laplacian
6161
self.t = t
6262
self.n_steps = n_steps
@@ -107,7 +107,7 @@ def from_graph(
107107
laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg
108108

109109
if t is None:
110-
t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2
110+
t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2
111111

112112
return cls(laplacian, t=t, **kwargs)
113113

@@ -162,7 +162,7 @@ def body_fn(
162162
# axis we can ignore since the matrix is symmetric
163163
del eps, axis
164164

165-
force_scan = self.tol < 0.
165+
force_scan = self.tol < 0.0
166166
fixpoint_fn = (
167167
fixed_point_loop.fixpoint_iter
168168
if force_scan else fixed_point_loop.fixpoint_iter_backprop
@@ -204,9 +204,9 @@ def cost_matrix(self) -> jnp.ndarray: # noqa: D102
204204
def _scale(self) -> float:
205205
"""Constant used to scale the Laplacian."""
206206
if self.numerical_scheme == "backward_euler":
207-
return self.t / (4. * self.n_steps)
207+
return self.t / (4.0 * self.n_steps)
208208
if self.numerical_scheme == "crank_nicolson":
209-
return self.t / (2. * self.n_steps)
209+
return self.t / (2.0 * self.n_steps)
210210
raise NotImplementedError(
211211
f"Numerical scheme `{self.numerical_scheme}` is not implemented."
212212
)

src/ott/geometry/grid.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def mask(
320320
self,
321321
src_mask: Optional[jnp.ndarray],
322322
tgt_mask: Optional[jnp.ndarray],
323-
mask_value: float = 0.,
323+
mask_value: float = 0.0,
324324
) -> NoReturn:
325325
"""Not implemented."""
326326
raise NotImplementedError("Masking is not implemented for grids.")

src/ott/geometry/low_rank.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def subset_fn(
251251
arr: Optional[jnp.ndarray],
252252
ixs: Optional[jnp.ndarray],
253253
) -> jnp.ndarray:
254-
return arr if arr is None or ixs is None else arr[jnp.atleast_1d(ixs)]
254+
return arr if arr is None or ixs is None else arr[ixs, ...]
255255

256256
return self._mask_subset_helper(
257257
src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs
@@ -261,7 +261,7 @@ def mask( # noqa: D102
261261
self,
262262
src_mask: Optional[jnp.ndarray],
263263
tgt_mask: Optional[jnp.ndarray],
264-
mask_value: float = 0.,
264+
mask_value: float = 0.0,
265265
) -> "LRCGeometry":
266266

267267
def mask_fn(

src/ott/geometry/pointcloud.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,13 @@ def __init__(
8080
def _norm_x(self) -> Union[float, jnp.ndarray]:
8181
if self._axis_norm == 0:
8282
return self.cost_fn.norm(self.x)
83-
return 0.
83+
return 0.0
8484

8585
@property
8686
def _norm_y(self) -> Union[float, jnp.ndarray]:
8787
if self._axis_norm == 0:
8888
return self.cost_fn.norm(self.y)
89-
return 0.
89+
return 0.0
9090

9191
@property
9292
def can_LRC(self): # noqa: D102
@@ -583,7 +583,7 @@ def _cosine_to_sqeucl(self) -> "PointCloud":
583583
x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
584584
y = y / jnp.linalg.norm(y, axis=-1, keepdims=True)
585585
# TODO(michalk8): find a better way
586-
aux_data["scale_cost"] = 2. / self.inv_scale_cost
586+
aux_data["scale_cost"] = 2.0 / self.inv_scale_cost
587587
cost_fn = costs.SqEuclidean()
588588
return type(self).tree_unflatten(aux_data, [x, y] + args + [cost_fn])
589589

@@ -648,7 +648,7 @@ def subset_fn(
648648
arr: Optional[jnp.ndarray],
649649
ixs: Optional[jnp.ndarray],
650650
) -> jnp.ndarray:
651-
return arr if arr is None or ixs is None else arr[jnp.atleast_1d(ixs)]
651+
return arr if arr is None or ixs is None else arr[ixs, ...]
652652

653653
return self._mask_subset_helper(
654654
src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs
@@ -658,7 +658,7 @@ def mask( # noqa: D102
658658
self,
659659
src_mask: Optional[jnp.ndarray],
660660
tgt_mask: Optional[jnp.ndarray],
661-
mask_value: float = 0.,
661+
mask_value: float = 0.0,
662662
) -> "PointCloud":
663663

664664
def mask_fn(

src/ott/initializers/linear/initializers.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def __call__(
105105
), f"Expected `g_v` to have shape `{m,}`, found `{b.shape}`."
106106

107107
# cancel dual variables for zero weights
108-
a = jnp.where(ot_prob.a > 0., a, -jnp.inf if lse_mode else 0.)
109-
b = jnp.where(ot_prob.b > 0., b, -jnp.inf if lse_mode else 0.)
108+
a = jnp.where(ot_prob.a > 0.0, a, -jnp.inf if lse_mode else 0.0)
109+
b = jnp.where(ot_prob.b > 0.0, b, -jnp.inf if lse_mode else 0.0)
110110

111111
return a, b
112112

@@ -339,10 +339,10 @@ def init_dual_a( # noqa: D102
339339

340340
# subsample
341341
sub_x = jax.random.choice(
342-
key=rng_x, a=x, shape=(self.subsample_n_x,), replace=True, p=a, axis=0
342+
rng_x, a=x, shape=(self.subsample_n_x,), replace=True, p=a, axis=0
343343
)
344344
sub_y = jax.random.choice(
345-
key=rng_y, a=y, shape=(self.subsample_n_y,), replace=True, p=b, axis=0
345+
rng_y, a=y, shape=(self.subsample_n_y,), replace=True, p=b, axis=0
346346
)
347347

348348
# create subsampled point cloud geometry

src/ott/initializers/linear/initializers_lr.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def init_g( # noqa: D102
262262
**kwargs: Any,
263263
) -> jnp.ndarray:
264264
del kwargs
265-
init_g = jnp.abs(jax.random.uniform(rng, (self.rank,))) + 1.
265+
init_g = jnp.abs(jax.random.uniform(rng, (self.rank,))) + 1.0
266266
return init_g / jnp.sum(init_g)
267267

268268

@@ -300,7 +300,7 @@ def _compute_factor(
300300
y = (marginal - lambda_1 * x) / (1.0 - lambda_1)
301301

302302
return ((lambda_1 * x[:, None] @ g1.reshape(1, -1)) +
303-
((1 - lambda_1) * y[:, None] @ g2.reshape(1, -1)))
303+
((1.0 - lambda_1) * y[:, None] @ g2.reshape(1, -1)))
304304

305305
def init_q( # noqa: D102
306306
self,
@@ -477,7 +477,7 @@ class GeneralizedKMeansInitializer(KMeansInitializer):
477477
def __init__(
478478
self,
479479
rank: int,
480-
gamma: float = 10.,
480+
gamma: float = 10.0,
481481
min_iterations: int = 0,
482482
max_iterations: int = 100,
483483
inner_iterations: int = 10,
@@ -523,7 +523,7 @@ def _compute_factor(
523523

524524
def init_fn() -> GeneralizedKMeansInitializer.State:
525525
n = geom.shape[0]
526-
factor = jnp.abs(jax.random.normal(rng, (n, self.rank))) + 1. # (n, r)
526+
factor = jnp.abs(jax.random.normal(rng, (n, self.rank))) + 1.0 # (n, r)
527527
factor *= consts.marginal[:, None] / jnp.sum(
528528
factor, axis=1, keepdims=True
529529
)
@@ -586,7 +586,7 @@ def body_fn(
586586

587587
norm = jnp.max(jnp.abs(grad)) ** 2
588588
gamma = consts.gamma / norm
589-
eps = 1. / gamma
589+
eps = 1.0 / gamma
590590

591591
cost = grad - eps * mu.safe_log(state.factor) # (n, r)
592592
cost = geometry.Geometry(

src/ott/math/matrix_square_root.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import functools
15+
import math
1516
from typing import Tuple
1617

1718
import jax
1819
import jax.numpy as jnp
19-
import numpy as np
2020

2121
from ott.math import fixed_point_loop
2222

@@ -87,7 +87,7 @@ def body_fn(iteration, const, state, compute_error):
8787
y = 1.5 * y - jnp.matmul(y, w)
8888
z = 1.5 * z - jnp.matmul(w, z)
8989

90-
err = jnp.where(compute_error, new_err(x, norm_x, y), np.inf)
90+
err = jnp.where(compute_error, new_err(x, norm_x, y), jnp.inf)
9191

9292
errors = errors.at[iteration // inner_iterations].set(err)
9393

@@ -98,13 +98,11 @@ def new_err(x, norm_x, y):
9898
norm_fn = functools.partial(jnp.linalg.norm, axis=(-2, -1))
9999
return jnp.max(norm_fn(res) / norm_fn(x))
100100

101-
dtype = x.dtype
102101
y = x / norm_x
103-
z = jnp.eye(dimension, dtype=dtype)
102+
z = jnp.eye(dimension)
104103
if jnp.ndim(x) > 2:
105104
z = jnp.tile(z, list(x.shape[:-2]) + [1, 1])
106-
errors = -jnp.ones((np.ceil(max_iterations / inner_iterations).astype(int),),
107-
dtype=dtype)
105+
errors = -jnp.ones(math.ceil(max_iterations / inner_iterations))
108106
state = (errors, y, z)
109107
const = (x, threshold)
110108
errors, y, z = fixed_point_loop.fixpoint_iter_backprop(
@@ -139,7 +137,7 @@ def solve_sylvester_bartels_stewart(
139137
)
140138
# The solution in the transformed space will in general be complex, too.
141139
y = jnp.zeros(a.shape[:-2] + (m, n)) + 0j
142-
idx = jnp.arange(m, dtype=jnp.int32)
140+
idx = jnp.arange(m)
143141
for j in range(n):
144142
lhs = r.at[..., idx, idx].add(-s[..., j:j + 1, j])
145143
rhs = d[..., j] + jnp.matmul(y[..., :j], s[..., :j, j:j + 1])[..., 0]

src/ott/math/unbalanced_functions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def diag_jacobian_of_marginal_fit(
7575
a vector of the same size as c or h.
7676
"""
7777
if tau == 1.0:
78-
return 0.
78+
return 0.0
7979

8080
r = rho(epsilon, tau)
8181
# here no minus sign because we are taking derivative w.r.t -h
@@ -87,4 +87,4 @@ def diag_jacobian_of_marginal_fit(
8787

8888

8989
def rho(epsilon: float, tau: float) -> float: # noqa: D103
90-
return (epsilon * tau) / (1. - tau)
90+
return (epsilon * tau) / (1.0 - tau)

0 commit comments

Comments
 (0)