Skip to content

Commit b0ed993

Browse files
committed
Remove jnp.atleast_1d in subsetting
1 parent 17f3be7 commit b0ed993

File tree

6 files changed

+12
-14
lines changed

6 files changed

+12
-14
lines changed

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ markers = [
120120
"fast: Mark tests as fast.",
121121
]
122122
filterwarnings = [
123+
"ignore:\\n*.*scipy.sparse array",
123124
"ignore:jax.random.KeyArray is deprecated:DeprecationWarning",
125+
"ignore:.*jax.config:DeprecationWarning",
126+
"ignore:jax.core.Shape is deprecated:DeprecationWarning:chex",
124127
]
125128

126129
[tool.coverage.run]

src/ott/geometry/geometry.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -740,9 +740,9 @@ def subset_fn(
740740
if arr is None:
741741
return None
742742
if src_ixs is not None:
743-
arr = arr[jnp.atleast_1d(src_ixs)]
743+
arr = arr[src_ixs, ...]
744744
if tgt_ixs is not None:
745-
arr = arr[:, jnp.atleast_1d(tgt_ixs)]
745+
arr = arr[:, tgt_ixs]
746746
return arr # noqa: RET504
747747

748748
return self._mask_subset_helper(
@@ -877,12 +877,11 @@ def _m_normed_ones(self) -> jnp.ndarray:
877877
return arr / jnp.sum(arr)
878878

879879
@staticmethod
880-
def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]],
880+
def _normalize_mask(mask: Optional[jnp.ndarray],
881881
size: int) -> Optional[jnp.ndarray]:
882882
"""Convert array of indices to a boolean mask."""
883883
if mask is None:
884884
return None
885-
mask = jnp.atleast_1d(mask)
886885
if not jnp.issubdtype(mask, (bool, jnp.bool_)):
887886
mask = jnp.isin(jnp.arange(size), mask)
888887
assert mask.shape == (size,)

src/ott/geometry/low_rank.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def subset_fn(
251251
arr: Optional[jnp.ndarray],
252252
ixs: Optional[jnp.ndarray],
253253
) -> jnp.ndarray:
254-
return arr if arr is None or ixs is None else arr[jnp.atleast_1d(ixs)]
254+
return arr if arr is None or ixs is None else arr[ixs, ...]
255255

256256
return self._mask_subset_helper(
257257
src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs

src/ott/geometry/pointcloud.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def subset_fn(
648648
arr: Optional[jnp.ndarray],
649649
ixs: Optional[jnp.ndarray],
650650
) -> jnp.ndarray:
651-
return arr if arr is None or ixs is None else arr[jnp.atleast_1d(ixs)]
651+
return arr if arr is None or ixs is None else arr[ixs, ...]
652652

653653
return self._mask_subset_helper(
654654
src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs

src/ott/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
def register_pytree_node(cls: type) -> type:
4141
"""Register dataclasses as pytree_nodes."""
4242
cls = dataclasses.dataclass()(cls)
43-
flatten = lambda obj: jax.tree_flatten(dataclasses.asdict(obj))
43+
flatten = lambda obj: jax.tree_util.tree_flatten(dataclasses.asdict(obj))
4444
unflatten = lambda d, children: cls(**d.unflatten(children))
4545
jax.tree_util.register_pytree_node(cls, flatten, unflatten)
4646
return cls

tests/geometry/subsetting_test.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def geom_masked(request, pc_masked) -> Tuple[Geom_t, pointcloud.PointCloud]:
5959
@pytest.mark.fast()
6060
class TestMaskPointCloud:
6161

62-
@pytest.mark.parametrize("tgt_ixs", [7, jnp.arange(5)])
62+
@pytest.mark.parametrize("tgt_ixs", [[1], jnp.arange(5)])
6363
@pytest.mark.parametrize("src_ixs", [None, (3, 3)])
6464
@pytest.mark.parametrize(
6565
"clazz", [geometry.Geometry, pointcloud.PointCloud, low_rank.LRCGeometry]
@@ -78,12 +78,8 @@ def test_mask(
7878
geom = clazz(cost_matrix=x @ y.T, scale_cost="mean")
7979
else:
8080
geom = clazz(x, y, scale_cost="max_cost", batch_size=5)
81-
n = geom.shape[0] if src_ixs is None else 1 if isinstance(
82-
src_ixs, int
83-
) else len(src_ixs)
84-
m = geom.shape[1] if tgt_ixs is None else 1 if isinstance(
85-
tgt_ixs, int
86-
) else len(tgt_ixs)
81+
n = geom.shape[0] if src_ixs is None else len(src_ixs)
82+
m = geom.shape[1] if tgt_ixs is None else len(tgt_ixs)
8783

8884
if clazz is geometry.Geometry:
8985
geom_sub = geom.subset(src_ixs, tgt_ixs)

0 commit comments

Comments
 (0)