From 6f9b2dff336596383e82c4ef456386f74a8f760d Mon Sep 17 00:00:00 2001 From: Clement Date: Mon, 28 Jul 2025 21:41:23 +0200 Subject: [PATCH 01/19] 1st try potentials OT 1d --- ot/__init__.py | 2 + ot/lp/__init__.py | 2 + ot/lp/solver_1d.py | 126 ++++++++++++++++++++++++++++++++++++++++- test/test_1d_solver.py | 34 +++++++++-- 4 files changed, 157 insertions(+), 7 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index 5e21d6a76..f675cb8f1 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -44,6 +44,7 @@ emd2, emd_1d, emd2_1d, + emd_1d_dual, wasserstein_1d, binary_search_circle, wasserstein_circle, @@ -91,6 +92,7 @@ "toq", "gromov", "emd2_1d", + "emd_1d_dual", "wasserstein_1d", "backend", "gaussian", diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 932b261df..58cc04b6d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -23,6 +23,7 @@ emd_1d, emd2_1d, wasserstein_1d, + emd_1d_dual, binary_search_circle, wasserstein_circle, semidiscrete_wasserstein2_unif_circle, @@ -38,6 +39,7 @@ "emd_1d", "emd2_1d", "wasserstein_1d", + "emd_1d_dual", "generalized_free_support_barycenter", "binary_search_circle", "wasserstein_circle", diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index c308549f8..ee8328529 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -16,7 +16,7 @@ from ..utils import list_to_array -def quantile_function(qs, cws, xs): +def quantile_function(qs, cws, xs, return_index=False): r"""Computes the quantile function of an empirical distribution Parameters @@ -27,6 +27,7 @@ def quantile_function(qs, cws, xs): cumulative weights of the 1D empirical distribution, if batched, must be similar to xs xs: array-like, shape (n, ...) 
locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions + return_index: bool Returns ------- @@ -43,8 +44,14 @@ def quantile_function(qs, cws, xs): else: cws = cws.T qs = qs.T - idx = nx.searchsorted(cws, qs).T - return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + # idx = nx.searchsorted(cws, qs).T + # return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + + idx = nx.clip(nx.searchsorted(cws, qs).T, 0, n - 1) + if return_index: + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0), idx + else: + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) def wasserstein_1d( @@ -399,6 +406,119 @@ def emd2_1d( return cost +def emd_1d_dual( + u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True +): + r""" + TODO + + Parameters + ---------- + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1, default is 1 + require_sort: bool, optional + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True + + Returns + ------- + f: array-like shape (n, ...) + First dual potential + g: array-like shape (m, ...) + Second dual potential + loss: float/array-like, shape (...) 
+ the batched EMD + """ + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + # Init weights or broadcast if necessary + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + if v_weights is None: + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + # Sort w.r.t. support if not already done + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + # eps trick to have strictly increasing cdf and avoid zero mass issues + eps = 1e-12 + u_cdf = nx.cumsum(u_weights + eps, 0) - eps + v_cdf = nx.cumsum(v_weights + eps, 0) - eps + + cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf), 0), 0) + + u_icdf, u_index = quantile_function(cdf_axis, u_cdf, u_values, return_index=True) + v_icdf, v_index = quantile_function(cdf_axis, v_cdf, v_values, return_index=True) + + diff_dist = nx.power(nx.abs(u_icdf - v_icdf), p) + cdf_axis = nx.zero_pad( + cdf_axis, pad_width=[(1, 0)] + (cdf_axis.ndim - 1) * [(0, 0)] + ) + + # delta = cdf_axis[1:, ...] - cdf_axis[:-1, ...] + # print(delta.dtype) + # print("?", diff_dist) + # # print("!!", nx.sum(delta * diff_dist, axis=0)) + + # parallel North-West corner rule (?) + mask_u = u_index[1:, ...] - u_index[:-1, ...] + mask_u = nx.zero_pad(mask_u, pad_width=[(1, 0)] + (mask_u.ndim - 1) * [(0, 0)]) + mask_v = v_index[1:, ...] - v_index[:-1, ...] 
+ mask_v = nx.zero_pad(mask_v, pad_width=[(1, 0)] + (mask_v.ndim - 1) * [(0, 0)]) + + c1 = nx.where((mask_u[:-1, ...] + mask_u[1:, ...]) > 1, -1, 0) + c1 = nx.cumsum(c1 * diff_dist[:-1, ...], axis=0) + c1 = nx.zero_pad(c1, pad_width=[(1, 0)] + (c1.ndim - 1) * [(0, 0)]) + + c2 = nx.where((mask_v[:-1, ...] + mask_v[1:, ...]) > 1, -1, 0) + c2 = nx.cumsum(c2 * diff_dist[:-1, ...], axis=0) + c2 = nx.zero_pad(c2, pad_width=[(1, 0)] + (c2.ndim - 1) * [(0, 0)]) + + masked_u_dist = mask_u * diff_dist + masked_v_dist = mask_v * diff_dist + + T = nx.cumsum(masked_u_dist - masked_v_dist, axis=0) + c1 - c2 + + tmp = nx.copy(mask_u > 0) # avoid in-place problem + tmp[0, ...] = 1 + f = nx.reshape(T[tmp], u_values.shape) + f[0, ...] = 0 + + tmp = nx.copy(mask_v > 0) # avoid in-place problem + tmp[0, ...] = 1 + g = -nx.reshape(T[tmp], v_values.shape) + + loss = nx.sum(f * u_weights) + nx.sum(g * v_weights) + return f, g, loss + + def roll_cols(M, shifts): r""" Utils functions which allow to shift the order of each row of a 2d matrix diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 7ab1009af..db7a88085 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -94,7 +94,7 @@ def test_wasserstein_1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) + # print(nx.dtype_device(tp)) xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) @@ -178,7 +178,7 @@ def test_emd1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) + # print(nx.dtype_device(tp)) xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) @@ -218,6 +218,32 @@ def test_emd1d_device_tf(): assert nx.dtype_device(emd)[1].startswith("GPU") +def test_emd_dual_with_weights(): + # test emd1d_dual gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + w_u = rng.uniform(0.0, 1.0, n) + w_u = w_u / w_u.sum() + + w_v = 
rng.uniform(0.0, 1.0, m) + w_v = w_v / w_v.sum() + + M = ot.dist(u, v, metric="sqeuclidean") + + G, log = ot.emd(w_u, w_v, M, log=True) + wass = log["cost"] + + f, g, wass1d = ot.emd_1d_dual(u, v, w_u, w_v, p=2) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, np.sum(f * w_u) + np.sum(g * w_v)) + + def test_wasserstein_1d_circle(): # test binary_search_circle and wasserstein_circle give similar results as emd n = 20 @@ -267,7 +293,7 @@ def test_wasserstein1d_circle_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) + # print(nx.dtype_device(tp)) xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) @@ -317,7 +343,7 @@ def test_wasserstein1d_unif_circle_devices(nx): rho_u /= rho_u.sum() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) + # print(nx.dtype_device(tp)) xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) From cb5660ddb2b530251a8413c39387b9854a492b8c Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 29 Jul 2025 15:44:05 +0200 Subject: [PATCH 02/19] emd1d_dual ok without batch --- ot/lp/solver_1d.py | 24 ++++++++++++++---------- test/test_1d_solver.py | 39 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index ee8328529..ab4595ddf 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -482,12 +482,7 @@ def emd_1d_dual( cdf_axis, pad_width=[(1, 0)] + (cdf_axis.ndim - 1) * [(0, 0)] ) - # delta = cdf_axis[1:, ...] - cdf_axis[:-1, ...] - # print(delta.dtype) - # print("?", diff_dist) - # # print("!!", nx.sum(delta * diff_dist, axis=0)) - - # parallel North-West corner rule (?) + # parallel North-West corner rule mask_u = u_index[1:, ...] - u_index[:-1, ...] mask_u = nx.zero_pad(mask_u, pad_width=[(1, 0)] + (mask_u.ndim - 1) * [(0, 0)]) mask_v = v_index[1:, ...] - v_index[:-1, ...] 
@@ -511,11 +506,20 @@ def emd_1d_dual( f = nx.reshape(T[tmp], u_values.shape) f[0, ...] = 0 - tmp = nx.copy(mask_v > 0) # avoid in-place problem - tmp[0, ...] = 1 - g = -nx.reshape(T[tmp], v_values.shape) + # Complementary slackness + C = nx.power(nx.abs(u_values[:, None] - v_values[None]), p) - f[:, None] + g = nx.min(C, axis=0) + + loss = nx.sum(f * u_weights, axis=0) + nx.sum(g * v_weights, axis=0) + + # unsort potentials + if require_sort: + u_rev_sorter = nx.argsort(u_sorter, 0) + f = nx.take_along_axis(f, u_rev_sorter, 0) + + v_rev_sorter = nx.argsort(v_sorter, 0) + g = nx.take_along_axis(g, v_rev_sorter, 0) - loss = nx.sum(f * u_weights) + nx.sum(g * v_weights) return f, g, loss diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index db7a88085..f01048cb3 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -218,7 +218,7 @@ def test_emd1d_device_tf(): assert nx.dtype_device(emd)[1].startswith("GPU") -def test_emd_dual_with_weights(): +def test_emd1d_dual_with_weights(): # test emd1d_dual gives similar results as emd n = 20 m = 30 @@ -241,7 +241,42 @@ def test_emd_dual_with_weights(): # check loss is similar np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, np.sum(f * w_u) + np.sum(g * w_v)) + np.testing.assert_allclose(wass, np.sum(f[:, 0] * w_u) + np.sum(g[:, 0] * w_v)) + + +def test_emd1d_dual_batch(nx): + rng = np.random.RandomState(0) + + n = 100 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) + + X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) + Xb = nx.from_numpy(X) + f, g, res = ot.emd_1d_dual(Xb, Xb, rho_ub, rho_vb, p=2) + np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) + + +def test_emd1d_dual_type_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= 
rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + # print(nx.dtype_device(tp)) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) + f, g, res = ot.emd_1d_dual(xb, xb, rho_ub, rho_vb, p=1) + nx.assert_same_dtype_device(xb, res) def test_wasserstein_1d_circle(): From 0a9d38b105c2d688f050d11d7be0777a9c55e6fd Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 29 Jul 2025 17:25:12 +0200 Subject: [PATCH 03/19] batched emd1d_dual --- ot/backend.py | 43 ++++++++++++++++++++++++++++++++++++++++++ ot/lp/solver_1d.py | 11 ++++++++++- test/test_1d_solver.py | 1 + 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index 3d59639fa..3f0cb4189 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1081,6 +1081,20 @@ def slogdet(self, a): """ raise NotImplementedError() + def index_select(self, input, axis, index): + r""" + TODO + + See: https://docs.pytorch.org/docs/stable/generated/torch.index_select.html + """ + + def nonzero(self, input, as_tuple=False): + r""" + TODO + + See: https://docs.pytorch.org/docs/stable/generated/torch.nonzero.html + """ + class NumpyBackend(Backend): """ @@ -1444,6 +1458,16 @@ def det(self, a): def slogdet(self, a): return np.linalg.slogdet(a) + def index_select(self, input, axis, index): + return np.take(input, index, axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return np.nonzero(input) + else: # TOCHECK + L_tuple = np.nonzero(input) + return np.concatenate([t[None] for t in L_tuple], axis=0) + _register_backend_implementation(NumpyBackend) @@ -1840,6 +1864,16 @@ def det(self, x): def slogdet(self, a): return jnp.linalg.slogdet(a) + def index_select(self, input, axis, index): + return jnp.take(input, index, axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return jnp.nonzero(input) + else: # TOCHECK + L_tuple = jnp.nonzero(input) + return jnp.concatenate([t[None] for t in L_tuple], axis=0) + if jax: # Only 
register jax backend if it is installed @@ -2376,6 +2410,12 @@ def det(self, x): def slogdet(self, a): return torch.linalg.slogdet(a) + def index_select(self, input, axis, index): + return torch.index_select(input, axis, index) + + def nonzero(self, input, as_tuple=False): + return torch.nonzero(input, as_tuple=as_tuple) + if torch: # Only register torch backend if it is installed @@ -2787,6 +2827,9 @@ def det(self, x): def slogdet(self, a): return cp.linalg.slogdet(a) + def index_select(self, input, axis, index): + return cp.take(input, index, axis) + if cp: # Only register cp backend if it is installed diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index ab4595ddf..246951f50 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -503,7 +503,16 @@ def emd_1d_dual( tmp = nx.copy(mask_u > 0) # avoid in-place problem tmp[0, ...] = 1 - f = nx.reshape(T[tmp], u_values.shape) + # f = nx.reshape(T[tmp], u_values.shape) # work only with one axis + f = nx.reshape( + nx.index_select( + nx.reshape(T.T, (-1,)), + 0, + # nx.reshape(tmp.T, (-1,)).nonzero().squeeze() + nx.nonzero(nx.reshape(tmp.T, (-1,))).squeeze(), + ), + u_values.T.shape, + ).T f[0, ...] 
= 0 # Complementary slackness diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index f01048cb3..8f363a221 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -244,6 +244,7 @@ def test_emd1d_dual_with_weights(): np.testing.assert_allclose(wass, np.sum(f[:, 0] * w_u) + np.sum(g[:, 0] * w_v)) +@pytest.skip_backend("jax") def test_emd1d_dual_batch(nx): rng = np.random.RandomState(0) From 0be92a71fdd31a11638221ffe2a4fb1d2682051b Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 6 Aug 2025 18:43:22 +0200 Subject: [PATCH 04/19] 1d potentials with backprop, 1d uot 1st try --- ot/__init__.py | 2 + ot/lp/__init__.py | 2 + ot/lp/solver_1d.py | 106 ++++++++++++++++++++- ot/unbalanced/__init__.py | 3 + ot/unbalanced/_solver_1d.py | 182 ++++++++++++++++++++++++++++++++++++ test/test_1d_solver.py | 32 +++++++ 6 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 ot/unbalanced/_solver_1d.py diff --git a/ot/__init__.py b/ot/__init__.py index f675cb8f1..f0d554b37 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -45,6 +45,7 @@ emd_1d, emd2_1d, emd_1d_dual, + emd_1d_dual_backprop, wasserstein_1d, binary_search_circle, wasserstein_circle, @@ -93,6 +94,7 @@ "gromov", "emd2_1d", "emd_1d_dual", + "emd_1d_dual_backprop", "wasserstein_1d", "backend", "gaussian", diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 58cc04b6d..09bdd3777 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -24,6 +24,7 @@ emd2_1d, wasserstein_1d, emd_1d_dual, + emd_1d_dual_backprop, binary_search_circle, wasserstein_circle, semidiscrete_wasserstein2_unif_circle, @@ -40,6 +41,7 @@ "emd2_1d", "wasserstein_1d", "emd_1d_dual", + "emd_1d_dual_backprop", "generalized_free_support_barycenter", "binary_search_circle", "wasserstein_circle", diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 246951f50..bf2dfd2d4 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -410,7 +410,18 @@ def emd_1d_dual( u_values, v_values, u_weights=None, v_weights=None, 
p=1, require_sort=True ): r""" - TODO + Computes the 1 dimensional OT loss between two (batched) empirical + distributions + + .. math: + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq + + and returns the dual potentials and the loss, i.e. such that + + .. math: + OT_{loss}(u,v) = \int f(x)\mathrm{d}u(x) + \int g(y)\mathrm{d}v(y). + + We do so by solving the dual problem using a parallel North-West corner rule. Parameters ---------- @@ -532,6 +543,99 @@ def emd_1d_dual( return f, g, loss +def emd_1d_dual_backprop( + u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True +): + r""" + Computes the 1 dimensional OT loss between two (batched) empirical + distributions + + .. math: + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq + + and returns the dual potentials and the loss, i.e. such that + + .. math: + OT_{loss}(u,v) = \int f(x)\mathrm{d}u(x) + \int g(y)\mathrm{d}v(y). + + We do so by backpropagating through the `wasserstein_1d` function. Thus, the function + only works in torch and jax. + + Parameters + ---------- + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1, default is 1 + require_sort: bool, optional + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True + + Returns + ------- + f: array-like shape (n, ...) + First dual potential + g: array-like shape (m, ...) + Second dual potential + loss: float/array-like, shape (...) 
+ the batched EMD + """ + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax" + + n = u_values.shape[0] + m = v_values.shape[0] + + # Init weights or broadcast if necessary + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + if v_weights is None: + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if nx.__name__ == "torch": + u_weights.requires_grad_(True) + v_weights.requires_grad_(True) + cost_output = wasserstein_1d( + u_values, v_values, u_weights, v_weights, p=p, require_sort=require_sort + ) + loss = cost_output.sum() + loss.backward() + + return ( + u_weights.grad, + v_weights.grad, + cost_output.detach(), + ) # value can not be backward anymore + elif nx.__name__ == "jax": + import jax + + def ot_1d(a, b): + return wasserstein_1d( + u_values, v_values, a, b, p=p, require_sort=require_sort + ).sum() + + f, g = jax.grad(ot_1d, argnums=[0, 1])(u_weights, v_weights) + cost_output = wasserstein_1d( + u_values, v_values, u_weights, v_weights, p=p, require_sort=require_sort + ) + return f, g, cost_output + + def roll_cols(M, shifts): r""" Utils functions which allow to shift the order of each row of a 2d matrix diff --git a/ot/unbalanced/__init__.py b/ot/unbalanced/__init__.py index 771452954..06423008d 100644 --- a/ot/unbalanced/__init__.py +++ b/ot/unbalanced/__init__.py @@ -24,6 +24,8 @@ from ._lbfgs import lbfgsb_unbalanced, lbfgsb_unbalanced2 +from ._solver_1d import uot_1d + __all__ = [ "sinkhorn_knopp_unbalanced", "sinkhorn_unbalanced", @@ -38,4 +40,5 @@ "_get_loss_unbalanced", 
"lbfgsb_unbalanced", "lbfgsb_unbalanced2", + "uot_1d", ] diff --git a/ot/unbalanced/_solver_1d.py b/ot/unbalanced/_solver_1d.py new file mode 100644 index 000000000..4f7ffb939 --- /dev/null +++ b/ot/unbalanced/_solver_1d.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- +""" +1D Unbalanced OT solvers +""" + +# Author: +# +# License: MIT License + +from ..backend import get_backend +from ..utils import get_parameter_pair +from ..lp.solver_1d import emd_1d_dual, emd_1d_dual_backprop + + +def rescale_potentials(f, g, a, b, rho1, rho2, nx): + r""" + TODO + """ + tau = (rho1 * rho2) / (rho1 + rho2) + num = nx.logsumexp(-f / rho1 + nx.log(a)) + denom = nx.logsumexp(-g / rho2 + nx.log(b)) + transl = tau * (num - denom) + return transl + + +def uot_1d( + u_values, + v_values, + reg_m, + u_weights=None, + v_weights=None, + p=1, + require_sort=True, + numItermax=1000, + stopThr=1e-6, + log=False, + mode="icdf", +): + r""" + TODO, TOTEST, seems not very stable? + + Solves the 1D unbalanced OT problem with KL regularization. + The function implements the Frank-Wolfe algorithm to solve the dual problem, + as proposed in [73]. + + TODO: add math equation + + Parameters + ---------- + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. + For semi-relaxed case, use either + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. 
+ u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1, default is 1 + require_sort: bool, optional + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True + numItermax: int, optional + log: bool, optional + mode: str, optional + "icdf" for inverse CDF, "backprop" for backpropagation mode. + Default is "icdf". + + Returns + ------- + f: array-like shape (n, ...) + First dual potential + g: array-like shape (m, ...) + Second dual potential + loss: float/array-like, shape (...) + the batched EMD + + References + --------- + .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + """ + assert mode in ["backprop", "icdf"] + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + reg_m1, reg_m2 = get_parameter_pair(reg_m) + + n = u_values.shape[0] + m = v_values.shape[0] + + # Init weights or broadcast if necessary + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + if v_weights is None: + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + # Sort w.r.t. 
support if not already done + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_rev_sorter = nx.argsort(u_sorter, 0) + u_values_sorted = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_rev_sorter = nx.argsort(v_sorter, 0) + v_values_sorted = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights_sorted = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights_sorted = nx.take_along_axis(v_weights, v_sorter, 0) + + f = nx.zeros(u_weights.shape, type_as=u_weights) + g = nx.zeros(v_weights.shape, type_as=v_weights) + + for i in range(numItermax): + transl = rescale_potentials( + f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx + ) + + f = f + transl + g = g - transl + + u_reweighted = u_weights_sorted * nx.exp(-f / reg_m1) + v_reweighted = v_weights_sorted * nx.exp(-g / reg_m2) + + if mode == "icdf": + fd, gd, loss = emd_1d_dual( + u_values_sorted, + v_values_sorted, + u_weights=u_reweighted, + v_weights=v_reweighted, + p=p, + require_sort=False, + ) + elif mode == "backprop": + fd, gd, loss = emd_1d_dual_backprop( + u_values_sorted, + v_values_sorted, + u_weights=u_reweighted, + v_weights=v_reweighted, + p=p, + require_sort=False, + ) + + t = 2.0 / (2.0 + i) + f = f + t * (fd - f) + g = g + t * (gd - g) + + if require_sort: + f = nx.take_along_axis(f, u_rev_sorter, 0) + g = nx.take_along_axis(g, v_rev_sorter, 0) + u_reweighted = nx.take_along_axis(u_reweighted, u_rev_sorter, 0) + v_reweighted = nx.take_along_axis(v_reweighted, v_rev_sorter, 0) + + uot_loss = ( + loss + + reg_m1 * nx.kl_div(u_reweighted, u_weights) + + reg_m2 * nx.kl_div(v_reweighted, v_weights) + ) + + if log: + dico = {"f": f, "g": g} + return u_reweighted, v_reweighted, uot_loss, dico + return u_reweighted, v_reweighted, uot_loss diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 8f363a221..53d9ef0eb 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -263,6 +263,30 @@ def test_emd1d_dual_batch(nx): 
np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) +def test_emd1d_dual_backprop_batch(nx): + rng = np.random.RandomState(0) + + n = 100 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) + + X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) + Xb = nx.from_numpy(X) + + if nx.__name__ in ["torch", "jax"]: + f, g, res = ot.emd_1d_dual_backprop(Xb, Xb, rho_ub, rho_vb, p=2) + np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) + else: + np.testing.assert_raises( + AssertionError, ot.emd_1d_dual_backprop, Xb, Xb, rho_ub, rho_vb, p=2 + ) + + def test_emd1d_dual_type_devices(nx): rng = np.random.RandomState(0) @@ -278,6 +302,14 @@ def test_emd1d_dual_type_devices(nx): xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) f, g, res = ot.emd_1d_dual(xb, xb, rho_ub, rho_vb, p=1) nx.assert_same_dtype_device(xb, res) + nx.assert_same_dtype_device(xb, f) + nx.assert_same_dtype_device(xb, g) + + if nx.__name__ == "torch" or nx.__name__ == "jax": + f, g, res = ot.emd_1d_dual_backprop(xb, xb, rho_ub, rho_vb, p=1) + nx.assert_same_dtype_device(xb, res) + nx.assert_same_dtype_device(xb, f) + nx.assert_same_dtype_device(xb, g) def test_wasserstein_1d_circle(): From cade9d58099987a88db3376df05e3c085ccfaff0 Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 6 Aug 2025 21:47:34 +0200 Subject: [PATCH 05/19] up tests 1d solvers --- test/test_1d_solver.py | 5 ++++ test/unbalanced/test_1d_solver.py | 39 +++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 test/unbalanced/test_1d_solver.py diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 53d9ef0eb..251f26310 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -281,6 +281,11 @@ def test_emd1d_dual_backprop_batch(nx): if nx.__name__ in ["torch", "jax"]: f, g, res = 
ot.emd_1d_dual_backprop(Xb, Xb, rho_ub, rho_vb, p=2) np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) + + cost_dual = nx.sum(f * rho_ub[:, None], axis=0) + nx.sum( + g * rho_vb[:, None], axis=0 + ) + np.testing.assert_allclose(cost_dual, res) else: np.testing.assert_raises( AssertionError, ot.emd_1d_dual_backprop, Xb, Xb, rho_ub, rho_vb, p=2 diff --git a/test/unbalanced/test_1d_solver.py b/test/unbalanced/test_1d_solver.py new file mode 100644 index 000000000..595eba6aa --- /dev/null +++ b/test/unbalanced/test_1d_solver.py @@ -0,0 +1,39 @@ +"""Tests for module 1D Unbalanced OT""" + +# Author: +# +# License: MIT License + +import itertools +import numpy as np +import ot +import pytest + + +def test_uot_1d(nx): + pass + + n_samples = 20 # nb samples + + rng = np.random.RandomState(42) + xs = rng.randn(n_samples, 1) + xt = rng.randn(n_samples, 1) + + a_np = ot.utils.unif(n_samples) + b_np = ot.utils.unif(n_samples) + + reg_m = 1.0 + + M = ot.dist(xs, xt) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + loss_mm = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div="kl") + + print("??", loss_mm) + + f, g, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m) + + print("???", loss_1d[0]) + + np.testing.assert_allclose(loss_1d, loss_mm) From b0550441e7b94e59fb550ac4a061ef71bfb374d0 Mon Sep 17 00:00:00 2001 From: Clement Date: Sat, 9 Aug 2025 17:05:52 +0200 Subject: [PATCH 06/19] file sliced uot --- ot/unbalanced/_sliced.py | 8 ++++++++ ot/unbalanced/_solver_1d.py | 6 +++--- 2 files changed, 11 insertions(+), 3 deletions(-) create mode 100644 ot/unbalanced/_sliced.py diff --git a/ot/unbalanced/_sliced.py b/ot/unbalanced/_sliced.py new file mode 100644 index 000000000..d1de5b684 --- /dev/null +++ b/ot/unbalanced/_sliced.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +""" +Sliced Unbalanced OT solvers +""" + +# Author: +# +# License: MIT License diff --git a/ot/unbalanced/_solver_1d.py b/ot/unbalanced/_solver_1d.py index 4f7ffb939..00eb0dc20 100644 --- 
a/ot/unbalanced/_solver_1d.py +++ b/ot/unbalanced/_solver_1d.py @@ -31,10 +31,10 @@ def uot_1d( v_weights=None, p=1, require_sort=True, - numItermax=1000, + numItermax=10, stopThr=1e-6, - log=False, mode="icdf", + log=False, ): r""" TODO, TOTEST, seems not very stable? @@ -71,10 +71,10 @@ def uot_1d( sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to the function, default is True numItermax: int, optional - log: bool, optional mode: str, optional "icdf" for inverse CDF, "backprop" for backpropagation mode. Default is "icdf". + log: bool, optional Returns ------- From c4971a8e1babd1bd3a117a2cfa51f744887d571e Mon Sep 17 00:00:00 2001 From: Clement Date: Sat, 9 Aug 2025 18:36:10 +0200 Subject: [PATCH 07/19] clip max cdf in wasserstein_1d --- ot/backend.py | 12 +- ot/lp/solver_1d.py | 4 +- ot/unbalanced/_sliced.py | 218 ++++++++++++++++++++++++++++++ test/unbalanced/test_1d_solver.py | 7 +- test/unbalanced/test_sliced.py | 10 ++ 5 files changed, 240 insertions(+), 11 deletions(-) create mode 100644 test/unbalanced/test_sliced.py diff --git a/ot/backend.py b/ot/backend.py index 3f0cb4189..efd129838 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -569,7 +569,7 @@ def flip(self, a, axis=None): """ raise NotImplementedError() - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): """ Limits the values in a tensor. 
@@ -1233,7 +1233,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return np.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return np.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -1640,7 +1640,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return jnp.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return jnp.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -2103,7 +2103,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return torch.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return torch.clamp(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -2577,7 +2577,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return cp.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return cp.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -3002,7 +3002,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return tnp.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return tnp.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index bf2dfd2d4..0538a9b98 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -127,8 +127,8 @@ def wasserstein_1d( u_weights = nx.take_along_axis(u_weights, u_sorter, 0) v_weights = nx.take_along_axis(v_weights, v_sorter, 0) - u_cumweights = nx.cumsum(u_weights, 0) - v_cumweights = nx.cumsum(v_weights, 0) + u_cumweights = nx.clip(nx.cumsum(u_weights, 0), a_max=1) + v_cumweights = nx.clip(nx.cumsum(v_weights, 0), a_max=1) qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0) u_quantiles = quantile_function(qs, u_cumweights, u_values) diff --git a/ot/unbalanced/_sliced.py b/ot/unbalanced/_sliced.py index d1de5b684..008bf5e70 100644 --- 
a/ot/unbalanced/_sliced.py +++ b/ot/unbalanced/_sliced.py @@ -6,3 +6,221 @@ # Author: # # License: MIT License + +from ..backend import get_backend +from ..utils import get_parameter_pair, list_to_array +from ..sliced import get_random_projections +from ._solver_1d import rescale_potentials +from ..lp.solver_1d import emd_1d_dual, emd_1d_dual_backprop, wasserstein_1d + + +def unbalanced_sliced_ot_pot( + X_s, + X_t, + reg_m, + a=None, + b=None, + n_projections=50, + p=2, + projections=None, + seed=None, + numItermax=10, + mode="backprop", + stochastic_proj=False, + log=False, +): + r""" + Compute USOT + + TODO + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. + For semi-relaxed case, use either + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. 
+ a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional, by default =2 + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + numItermax: int, optional + mode: str, optional + "icdf" for inverse CDF, "backprop" for backpropagation mode. + Default is "backprop". + stochastic_proj: bool, default False + log: bool, optional + if True, also returns a dict containing the projections used. + + Returns + ------- + f: array-like shape (n, ...) + First dual potential + g: array-like shape (m, ...) + Second dual potential + loss: float/array-like, shape (...) + the batched EMD + + References + ---------- + [] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2025). + Slicing Unbalanced Optimal Transport. 
Transactions on Machine Learning Research + """ + assert mode in ["backprop", "icdf"] + + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) + + reg_m1, reg_m2 = get_parameter_pair(reg_m) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( + X_s.shape[1], X_t.shape[1] + ) + ) + + if a is None: + a = nx.full(n, 1 / n, type_as=X_s) + if b is None: + b = nx.full(m, 1 / m, type_as=X_s) + + d = X_s.shape[1] + + if projections is None and not stochastic_proj: + projections = get_random_projections( + d, n_projections, seed, backend=nx, type_as=X_s + ) + else: + n_projections = projections.shape[1] + + if not stochastic_proj: + X_s_projections = nx.dot(X_s, projections).T # shape (n_projs, n) + X_t_projections = nx.dot(X_t, projections).T + + X_s_sorter = nx.argsort(X_s_projections, -1) + X_s_rev_sorter = nx.argsort(X_s_sorter, -1) + X_s_sorted = nx.take_along_axis(X_s_projections, X_s_sorter, -1) + + X_t_sorter = nx.argsort(X_t_projections, -1) + X_t_rev_sorter = nx.argsort(X_t_sorter, -1) + X_t_sorted = nx.take_along_axis(X_t_projections, X_t_sorter, -1) + + # Initialize potentials - WARNING: They correspond to non-sorted samples + f = nx.zeros(a.shape, type_as=a) + g = nx.zeros(b.shape, type_as=b) + + for i in range(numItermax): + # Output FW descent direction + # translate potentials + transl = rescale_potentials(f, g, a, b, reg_m1, reg_m2, nx) + + f = f + transl + g = g - transl + + # If stochastic version then sample new directions and re-sort data + # TODO: add functions to sample and project + if 
stochastic_proj: + projections = get_random_projections( + d, n_projections, seed, backend=nx, type_as=X_s + ) + + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) + + X_s_sorter = nx.argsort(X_s_projections, -1) + X_s_rev_sorter = nx.argsort(X_s_sorter, -1) + X_s_sorted = nx.take_along_axis(X_s_projections, X_s_sorter, -1) + + X_t_sorter = nx.argsort(X_t_projections, -1) + X_t_rev_sorter = nx.argsort(X_t_sorter, -1) + X_t_sorted = nx.take_along_axis(X_t_projections, X_t_sorter, -1) + + # update measures + a_reweighted = (a * nx.exp(-f / reg_m1))[..., X_s_sorter] + b_reweighted = (b * nx.exp(-g / reg_m2))[..., X_t_sorter] + + # solve for new potentials + if mode == "icdf": + fd, gd, loss = emd_1d_dual( + X_s_sorted.T, + X_t_sorted.T, + u_weights=a_reweighted.T, + v_weights=b_reweighted.T, + p=p, + require_sort=False, + ) + fd, gd = fd.T, gd.T + + elif mode == "backprop": + fd, gd, loss = emd_1d_dual_backprop( + X_s_sorted.T, + X_t_sorted.T, + u_weights=a_reweighted.T, + v_weights=b_reweighted.T, + p=p, + require_sort=False, + ) + fd, gd = fd.T, gd.T + + # default step for FW + t = 2.0 / (2.0 + i) + + f = f + t * (nx.mean(nx.take_along_axis(fd, X_s_rev_sorter, 1), axis=0) - f) + g = g + t * (nx.mean(nx.take_along_axis(gd, X_t_rev_sorter, 1), axis=0) - g) + + # Last iter before output + transl = rescale_potentials(f, g, a, b, reg_m1, reg_m2, nx) + f, g = f + transl, g - transl + + a_reweighted = (a * nx.exp(-f / reg_m1))[..., X_s_sorter] + b_reweighted = (b * nx.exp(-g / reg_m2))[..., X_t_sorter] + + loss = nx.mean( + wasserstein_1d( + X_s_sorted, + X_t_sorted, + u_weights=a_reweighted.T, + v_weights=b_reweighted.T, + p=p, + require_sort=False, + ) + ) + a_reweighted, b_reweighted = a * nx.exp(-f / reg_m1), b * nx.exp(-g / reg_m2) + uot_loss = ( + loss + reg_m1 * nx.kl_div(a_reweighted, a) + reg_m2 * nx.kl_div(b_reweighted, b) + ) + + if log: + return a_reweighted, b_reweighted, uot_loss, {"projections": projections} + + return 
a_reweighted, b_reweighted, uot_loss diff --git a/test/unbalanced/test_1d_solver.py b/test/unbalanced/test_1d_solver.py index 595eba6aa..622f194c1 100644 --- a/test/unbalanced/test_1d_solver.py +++ b/test/unbalanced/test_1d_solver.py @@ -32,8 +32,9 @@ def test_uot_1d(nx): print("??", loss_mm) - f, g, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m) + if nx.__name__ in ["jax", "torch"]: + f, g, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, mode="backprop") - print("???", loss_1d[0]) + print("???", loss_1d[0]) - np.testing.assert_allclose(loss_1d, loss_mm) + np.testing.assert_allclose(loss_1d, loss_mm) diff --git a/test/unbalanced/test_sliced.py b/test/unbalanced/test_sliced.py new file mode 100644 index 000000000..15a7a72b2 --- /dev/null +++ b/test/unbalanced/test_sliced.py @@ -0,0 +1,10 @@ +"""Tests for module sliced Unbalanced OT""" + +# Author: +# +# License: MIT License + +import itertools +import numpy as np +import ot +import pytest From 246d0929ae6d6653c8d4abf66ae631c802bb65a7 Mon Sep 17 00:00:00 2001 From: Clement Date: Sat, 9 Aug 2025 23:35:49 +0200 Subject: [PATCH 08/19] Example UOT 1d --- examples/unbalanced-partial/plot_UOT_1D.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index ade4bbb0c..37d51bed0 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -88,3 +88,24 @@ pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target") pl.legend(loc="upper right") pl.title("Distributions and transported mass for UOT") + + +# %% +############################################################################## +# Solve Unbalanced UOT with Frank-Wolfe +# ------------------------- + +alpha = 1000.0 # Unbalanced KL relaxation parameter +f, g, loss = ot.unbalanced.uot_1d(x, x, a, b, alpha) + + +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, 
"b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") +pl.fill(x, f, "b", alpha=0.5, label="Transported source") +pl.fill(x, g, "r", alpha=0.5, label="Transported target") +pl.legend(loc="upper right") +pl.title("Distributions and transported mass for UOT") From b0c791ca88ae5f17a4f4724962e586000805b203 Mon Sep 17 00:00:00 2001 From: Clement Date: Sun, 10 Aug 2025 11:41:21 +0200 Subject: [PATCH 09/19] normalize weights --- examples/unbalanced-partial/plot_UOT_1D.py | 6 ++--- ot/lp/solver_1d.py | 4 +-- ot/unbalanced/_sliced.py | 29 ++++++++++++++-------- ot/unbalanced/_solver_1d.py | 7 ++++++ 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 37d51bed0..752e7b79f 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -96,7 +96,7 @@ # ------------------------- alpha = 1000.0 # Unbalanced KL relaxation parameter -f, g, loss = ot.unbalanced.uot_1d(x, x, a, b, alpha) +a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d(x, x, a, b, alpha) # plot the transported mass @@ -105,7 +105,7 @@ pl.figure(4, figsize=(6.4, 3)) pl.plot(x, a, "b", label="Source distribution") pl.plot(x, b, "r", label="Target distribution") -pl.fill(x, f, "b", alpha=0.5, label="Transported source") -pl.fill(x, g, "r", alpha=0.5, label="Transported target") +pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source") +pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target") pl.legend(loc="upper right") pl.title("Distributions and transported mass for UOT") diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 0538a9b98..bf2dfd2d4 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -127,8 +127,8 @@ def wasserstein_1d( u_weights = nx.take_along_axis(u_weights, u_sorter, 0) v_weights = nx.take_along_axis(v_weights, v_sorter, 0) - u_cumweights = nx.clip(nx.cumsum(u_weights, 0), 
a_max=1) - v_cumweights = nx.clip(nx.cumsum(v_weights, 0), a_max=1) + u_cumweights = nx.cumsum(u_weights, 0) + v_cumweights = nx.cumsum(v_weights, 0) qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0) u_quantiles = quantile_function(qs, u_cumweights, u_values) diff --git a/ot/unbalanced/_sliced.py b/ot/unbalanced/_sliced.py index 008bf5e70..54ecf8a51 100644 --- a/ot/unbalanced/_sliced.py +++ b/ot/unbalanced/_sliced.py @@ -169,6 +169,10 @@ def unbalanced_sliced_ot_pot( a_reweighted = (a * nx.exp(-f / reg_m1))[..., X_s_sorter] b_reweighted = (b * nx.exp(-g / reg_m2))[..., X_t_sorter] + # normalize the weights for compatibility with wasserstein_1d + a_reweighted = a_reweighted / nx.sum(a_reweighted, axis=1, keepdims=True) + b_reweighted = b_reweighted / nx.sum(b_reweighted, axis=1, keepdims=True) + # solve for new potentials if mode == "icdf": fd, gd, loss = emd_1d_dual( @@ -205,19 +209,24 @@ def unbalanced_sliced_ot_pot( a_reweighted = (a * nx.exp(-f / reg_m1))[..., X_s_sorter] b_reweighted = (b * nx.exp(-g / reg_m2))[..., X_t_sorter] - loss = nx.mean( - wasserstein_1d( - X_s_sorted, - X_t_sorted, - u_weights=a_reweighted.T, - v_weights=b_reweighted.T, - p=p, - require_sort=False, - ) + a_reweighted = a_reweighted / nx.sum(a_reweighted, axis=1, keepdims=True) + b_reweighted = b_reweighted / nx.sum(b_reweighted, axis=1, keepdims=True) + + ot_loss = wasserstein_1d( + X_s_sorted, + X_t_sorted, + u_weights=a_reweighted.T, + v_weights=b_reweighted.T, + p=p, + require_sort=False, ) + sot_loss = nx.mean(ot_loss * nx.sum(a_reweighted, axis=1)) + a_reweighted, b_reweighted = a * nx.exp(-f / reg_m1), b * nx.exp(-g / reg_m2) uot_loss = ( - loss + reg_m1 * nx.kl_div(a_reweighted, a) + reg_m2 * nx.kl_div(b_reweighted, b) + sot_loss + + reg_m1 * nx.kl_div(a_reweighted, a) + + reg_m2 * nx.kl_div(b_reweighted, b) ) if log: diff --git a/ot/unbalanced/_solver_1d.py b/ot/unbalanced/_solver_1d.py index 00eb0dc20..b2dd65545 100644 --- a/ot/unbalanced/_solver_1d.py +++ 
b/ot/unbalanced/_solver_1d.py @@ -141,6 +141,10 @@ def uot_1d( u_reweighted = u_weights_sorted * nx.exp(-f / reg_m1) v_reweighted = v_weights_sorted * nx.exp(-g / reg_m2) + # Normalize weights + u_reweighted = u_reweighted / nx.sum(u_reweighted, axis=0, keepdims=True) + v_reweighted = v_reweighted / nx.sum(v_reweighted, axis=0, keepdims=True) + if mode == "icdf": fd, gd, loss = emd_1d_dual( u_values_sorted, @@ -170,6 +174,9 @@ def uot_1d( u_reweighted = nx.take_along_axis(u_reweighted, u_rev_sorter, 0) v_reweighted = nx.take_along_axis(v_reweighted, v_rev_sorter, 0) + # rescale OT loss + loss = loss * nx.sum(u_reweighted, axis=0) + uot_loss = ( loss + reg_m1 * nx.kl_div(u_reweighted, u_weights) From f9dc43a455397f375f6fcdf0ccd47584339b4d99 Mon Sep 17 00:00:00 2001 From: Clement Date: Sun, 10 Aug 2025 15:59:30 +0200 Subject: [PATCH 10/19] add suot --- ot/unbalanced/_sliced.py | 136 ++++++++++++++++++++++++++++++++++-- ot/unbalanced/_solver_1d.py | 10 +-- 2 files changed, 134 insertions(+), 12 deletions(-) diff --git a/ot/unbalanced/_sliced.py b/ot/unbalanced/_sliced.py index 54ecf8a51..b3d2f6343 100644 --- a/ot/unbalanced/_sliced.py +++ b/ot/unbalanced/_sliced.py @@ -10,11 +10,133 @@ from ..backend import get_backend from ..utils import get_parameter_pair, list_to_array from ..sliced import get_random_projections -from ._solver_1d import rescale_potentials +from ._solver_1d import rescale_potentials, uot_1d from ..lp.solver_1d import emd_1d_dual, emd_1d_dual_backprop, wasserstein_1d -def unbalanced_sliced_ot_pot( +def sliced_unbalanced_ot( + X_s, + X_t, + reg_m, + a=None, + b=None, + n_projections=50, + p=2, + projections=None, + seed=None, + numItermax=10, + mode="backprop", + log=False, +): + r""" + Compute SUOT + + TODO + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation 
term. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. + For semi-relaxed case, use either + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional, by default =2 + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + numItermax: int, optional + mode: str, optional + "icdf" for inverse CDF, "backprop" for backpropagation mode. + Default is "backprop". + log: bool, optional + if True, returns the projections used and their associated UOTs and reweighted marginals. + + Returns + ------- + loss: float/array-like, shape (...) + SUOT + + References + ---------- + [] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2025). + Slicing Unbalanced Optimal Transport. 
Transactions on Machine Learning Research + """ + assert mode in ["backprop", "icdf"] + + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( + X_s.shape[1], X_t.shape[1] + ) + ) + + if a is None: + a = nx.full(n, 1 / n, type_as=X_s) + if b is None: + b = nx.full(m, 1 / m, type_as=X_s) + + d = X_s.shape[1] + + if projections is None: + projections = get_random_projections( + d, n_projections, seed, backend=nx, type_as=X_s + ) + else: + n_projections = projections.shape[1] + + X_s_projections = nx.dot(X_s, projections) # shape (n, n_projs) + X_t_projections = nx.dot(X_t, projections) + + a_reweighted, b_reweighted, projected_uot = uot_1d( + X_s_projections, X_t_projections, reg_m, a, b, p, require_sort=True, mode=mode + ) + + res = nx.mean(projected_uot) ** (1.0 / p) + + if log: + dico = { + "projection": projections, + "projected_uots": projected_uot, + "a_reweighted": a_reweighted, + "b_reweighted": b_reweighted, + } + return res, dico + + return res + + +def unbalanced_sliced_ot( X_s, X_t, reg_m, @@ -72,12 +194,12 @@ def unbalanced_sliced_ot_pot( Returns ------- - f: array-like shape (n, ...) - First dual potential - g: array-like shape (m, ...) - Second dual potential + a_reweighted: array-like shape (n, ...) + First marginal reweighted + b_reweighted: array-like shape (m, ...) + Second marginal reweighted loss: float/array-like, shape (...) 
- the batched EMD + USOT References ---------- diff --git a/ot/unbalanced/_solver_1d.py b/ot/unbalanced/_solver_1d.py index b2dd65545..5cd85461f 100644 --- a/ot/unbalanced/_solver_1d.py +++ b/ot/unbalanced/_solver_1d.py @@ -78,12 +78,12 @@ def uot_1d( Returns ------- - f: array-like shape (n, ...) - First dual potential - g: array-like shape (m, ...) - Second dual potential + u_reweighted: array-like shape (n, ...) + First marginal reweighted + v_reweighted: array-like shape (m, ...) + Second marginal reweighted loss: float/array-like, shape (...) - the batched EMD + the batched 1D UOT References --------- From 6fe05ead12e66b5023cb7d8d19052c767f8680c0 Mon Sep 17 00:00:00 2001 From: Clement Date: Sun, 10 Aug 2025 16:46:45 +0200 Subject: [PATCH 11/19] add code example (to test) --- README.md | 4 +- RELEASES.md | 2 + .../unbalanced-partial/plot_UOT_sliced.py | 278 ++++++++++++++++++ ot/__init__.py | 8 +- ot/unbalanced/__init__.py | 4 + ot/unbalanced/_sliced.py | 15 +- ot/unbalanced/_solver_1d.py | 4 +- 7 files changed, 301 insertions(+), 14 deletions(-) create mode 100644 examples/unbalanced-partial/plot_UOT_sliced.py diff --git a/README.md b/README.md index 8b4cca7f7..c06e5900e 100644 --- a/README.md +++ b/README.md @@ -320,7 +320,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. @@ -389,3 +389,5 @@ Artificial Intelligence. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. + +[76] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2024). [Slicing Unbalanced Optimal Transport](https://openreview.net/forum?id=AjJTg5M0r8). Transactions on Machine Learning Research. 
\ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 542f94851..8c33b8819 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -20,6 +20,8 @@ - Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731) - Added to each example in the examples gallery the information about the release version in which it was introduced (PR #743) - Removed release information from quickstart guide (PR #744) +- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #) +- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/examples/unbalanced-partial/plot_UOT_sliced.py b/examples/unbalanced-partial/plot_UOT_sliced.py new file mode 100644 index 000000000..a7b0ab1ee --- /dev/null +++ b/examples/unbalanced-partial/plot_UOT_sliced.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- +""" +=================================== +Sliced Unbalanced optimal transport +=================================== + +This example illustrates the behavior of Sliced UOT versus +Unbalanced Sliced OT. + +The first one removes outliers on each slice while the second one +removes outliers of the original marginals. 
+""" + +# Author: +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +import torch +import matplotlib.pyplot as plt +import matplotlib as mpl + +from sklearn.neighbors import KernelDensity + +############################################################################## +# Generate data +# ------------- + + +# %% parameters + +get_rot = lambda theta: np.array( + [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] +) + + +# regular distribution of Gaussians around a circle +def make_blobs_reg(n_samples, n_blobs, scale=0.5): + per_blob = int(n_samples / n_blobs) + result = np.random.randn(per_blob, 2) * scale + 5 + theta = (2 * np.pi) / (n_blobs) + for r in range(1, n_blobs): + new_blob = (np.random.randn(per_blob, 2) * scale + 5).dot(get_rot(theta * r)) + result = np.vstack((result, new_blob)) + return result + + +def make_blobs_random(n_samples, n_blobs, scale=0.5, offset=3): + per_blob = int(n_samples / n_blobs) + result = np.random.randn(per_blob, 2) * scale + np.random.randn(1, 2) * offset + for r in range(1, n_blobs): + new_blob = np.random.randn(per_blob, 2) * scale + np.random.randn(1, 2) * offset + result = np.vstack((result, new_blob)) + return result + + +def make_spiral(n_samples, noise=0.5): + n = np.sqrt(np.random.rand(n_samples, 1)) * 780 * (2 * np.pi) / 360 + d1x = -np.cos(n) * n + np.random.rand(n_samples, 1) * noise + d1y = np.sin(n) * n + np.random.rand(n_samples, 1) * noise + return np.array(np.hstack((d1x, d1y))) + + +n_samples = 500 +expe = "outlier" + +np.random.seed(42) + +nb_outliers = 200 +Xs = make_blobs_random(n_samples=n_samples, scale=0.2, n_blobs=1, offset=0) - 0.5 +Xs_outlier = make_blobs_random( + n_samples=nb_outliers, scale=0.05, n_blobs=1, offset=0 +) - [2, 0.5] + +Xs = np.vstack((Xs, Xs_outlier)) +Xt = make_blobs_random(n_samples=n_samples, scale=0.2, n_blobs=1, offset=0) + 1.5 +y = np.hstack(([0] * (n_samples + nb_outliers), [1] * n_samples)) +X = np.vstack((Xs, Xt)) + + 
+Xs_torch = torch.from_numpy(Xs).type(torch.float) +Xt_torch = torch.from_numpy(Xt).type(torch.float) + +p = 2 +num_proj = 180 + +a = torch.ones(Xs.shape[0], dtype=torch.float) +b = torch.ones(Xt.shape[0], dtype=torch.float) + +# construct projections +thetas = np.linspace(0, np.pi, num_proj) +dir = np.array([(np.cos(theta), np.sin(theta)) for theta in thetas]) +dir_torch = torch.from_numpy(dir).type(torch.float) + + +Xps = torch.dot(Xs_torch, dir_torch.T) # shape (n, n_projs) +Xpt = torch.dot(Xt_torch, dir_torch.T) + +############################################################################## +# Compute SUOT and USOT +# ------------- + +# %% + +rho1_SUOT = 1 +rho2_SUOT = 1 +_, log = ot.unbalanced.sliced_unbalanced_ot( + Xs_torch, + Xt_torch, + (rho1_SUOT, rho2_SUOT), + a, + b, + num_proj, + p, + numItermax=10, + projections=dir_torch.T, + mode="backprop", + log=True, +) +A_SUOT, B_SUOT = log["a_reweighted"].T, log["b_reweighted"].T + + +rho1_USOT = 1 +rho2_USOT = 1 +A_USOT, B_USOT, _ = ot.unbalanced_sliced_ot( + Xs_torch, + Xt_torch, + (rho1_USOT, rho2_USOT), + a, + b, + num_proj, + p, + numItermax=10, + projections=dir_torch.T, + mode="backprop", +) + + +############################################################################## +# Plot reweighted distributions on several slices +# ------------- + +# %% + + +def kde_sklearn(x, x_grid, weights=None, bandwidth=0.2, **kwargs): + """Kernel Density Estimation with Scikit-learn""" + kde_skl = KernelDensity(bandwidth=bandwidth, **kwargs) + if weights is not None: + kde_skl.fit(x[:, np.newaxis], sample_weight=weights) + else: + kde_skl.fit(x[:, np.newaxis]) + # score_samples() returns the log-likelihood of the samples + log_pdf = kde_skl.score_samples(x_grid[:, np.newaxis]) + return np.exp(log_pdf) + + +c1 = np.array(mpl.colors.to_rgb("lightcoral")) +c2 = np.array(mpl.colors.to_rgb("steelblue")) + +# define plotting grid +xlim_min = -3 +xlim_max = 3 +x_grid = np.linspace(xlim_min, xlim_max, 200) +bw = 0.05 + +# 
visu parameters +nb_slices = 6 +offset_degree = int(180 / nb_slices) + +delta_degree = np.pi / nb_slices +colors = plt.cm.Reds(np.linspace(0.3, 1, nb_slices)) + +X1 = np.array([-4, 0]) +X2 = np.array([4, 0]) + +fig = plt.figure(figsize=(28, 8)) +ax1 = plt.subplot2grid((nb_slices, 3), (0, 0), rowspan=nb_slices) + + +for i in range(nb_slices): + R = get_rot(delta_degree * (-i)) + X1_r = X1.dot(R) + X2_r = X2.dot(R) + if i == 0: + ax1.plot( + [X1_r[0], X2_r[0]], + [X1_r[1], X2_r[1]], + color=colors[i], + alpha=0.8, + zorder=0, + label="Directions", + ) + else: + ax1.plot( + [X1_r[0], X2_r[0]], [X1_r[1], X2_r[1]], color=colors[i], alpha=0.8, zorder=0 + ) +ax1.scatter(Xs[:, 0], Xs[:, 1], zorder=1, color=c2, label="Source data") +ax1.scatter(Xt[:, 0], Xt[:, 1], zorder=1, color=c1, label="Target data") +ax1.set_xlim([-3, 3]) +ax1.set_ylim([-3, 3]) +ax1.set_yticks([]) +ax1.set_xticks([]) +ax1.legend(loc="best", fontsize=18) +ax1.set_xlabel("Original distributions", fontsize=22) + +# ***** plot SUOT +fig.subplots_adjust(hspace=0) +fig.subplots_adjust(wspace=0.1) + +for i in range(nb_slices): + ax = plt.subplot2grid((nb_slices, 3), (i, 1)) + weights_src = A_SUOT[i * offset_degree, :].cpu().numpy() + weights_tgt = B_SUOT[i * offset_degree, :].cpu().numpy() + samples_src = Xps[i * offset_degree, :].cpu().numpy() + samples_tgt = Xpt[i * offset_degree, :].cpu().numpy() + pdf_source = kde_sklearn(samples_src, x_grid, weights=weights_src, bandwidth=bw) + pdf_target = kde_sklearn(samples_tgt, x_grid, weights=weights_tgt, bandwidth=bw) + pdf_source_without_w = kde_sklearn(samples_src, x_grid, bandwidth=bw) + pdf_target_without_w = kde_sklearn(samples_tgt, x_grid, bandwidth=bw) + + ax.scatter(samples_src, [-0.2] * samples_src.shape[0], color=c2, s=2) + ax.plot(x_grid, pdf_source, color=c2, alpha=0.8, lw=2) + ax.fill(x_grid, pdf_source_without_w, ec="grey", fc="grey", alpha=0.3) + ax.fill(x_grid, pdf_source, ec=c2, fc=c2, alpha=0.3) + + ax.scatter(samples_tgt, [-0.2] * 
samples_tgt.shape[0], color=c1, s=2) + ax.plot(x_grid, pdf_target, color=c1, alpha=0.8, lw=2) + ax.fill(x_grid, pdf_target_without_w, ec="grey", fc="grey", alpha=0.3) + ax.fill(x_grid, pdf_target, ec=c2, fc=c1, alpha=0.3) + + # frac_mass = int(100*weights_src.sum()) + # plt.text(.9, .9, '% mass={}%'.format(frac_mass), ha='right', va='top', color='red',fontsize=14, transform=ax.transAxes) + + ax.set_xlim(xlim_min, xlim_max) + ax.set_ylabel( + r"$\theta=${}$^o$".format(i * offset_degree), color=colors[i], fontsize=16 + ) + ax.set_yticks([]) + ax.set_yticks([]) +ax.set_xlabel( + r"SUOT $\rho_1={}$ $\rho_2={}$".format(rho1_SUOT, rho2_SUOT), fontsize=22 +) +# ***** plot USOT + +for i in range(nb_slices): + ax = plt.subplot2grid((nb_slices, 3), (i, 2)) + weights_src = A_USOT.cpu().numpy() + weights_tgt = B_USOT.cpu().numpy() + samples_src = Xps[i * offset_degree, :].cpu().numpy() + samples_tgt = Xpt[i * offset_degree, :].cpu().numpy() + pdf_source = kde_sklearn(samples_src, x_grid, weights=weights_src, bandwidth=bw) + pdf_target = kde_sklearn(samples_tgt, x_grid, weights=weights_tgt, bandwidth=bw) + pdf_source_without_w = kde_sklearn(samples_src, x_grid, bandwidth=bw) + pdf_target_without_w = kde_sklearn(samples_tgt, x_grid, bandwidth=bw) + + ax.scatter(samples_src, [-0.2] * samples_src.shape[0], color=c2, s=2) + ax.plot(x_grid, pdf_source, color=c2, alpha=0.8, lw=2) + ax.fill(x_grid, pdf_source_without_w, ec="grey", fc="grey", alpha=0.3) + ax.fill(x_grid, pdf_source, ec=c2, fc=c2, alpha=0.3) + + ax.scatter(samples_tgt, [-0.2] * samples_tgt.shape[0], color=c1, s=2) + ax.plot(x_grid, pdf_target, color=c1, alpha=0.8, lw=2) + ax.fill(x_grid, pdf_target_without_w, ec="grey", fc="grey", alpha=0.3) + ax.fill(x_grid, pdf_target, ec=c2, fc=c1, alpha=0.3) + + ax.set_xlim(xlim_min, xlim_max) + ax.set_ylabel( + r"$\theta=${}$^o$".format(i * offset_degree), color=colors[i], fontsize=16 + ) + ax.set_yticks([]) +ax.set_xlabel( + r"USOT $\rho_1={}$ $\rho_2={}$".format(rho1_USOT, 
rho2_USOT), fontsize=22 +) + +plt.show() diff --git a/ot/__init__.py b/ot/__init__.py index f0d554b37..43f8e05dc 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -52,7 +52,12 @@ semidiscrete_wasserstein2_unif_circle, ) from .bregman import sinkhorn, sinkhorn2, barycenter -from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 +from .unbalanced import ( + sinkhorn_unbalanced, + barycenter_unbalanced, + sinkhorn_unbalanced2, + unbalanced_sliced_ot, +) from .da import sinkhorn_lpl1_mm from .sliced import ( sliced_wasserstein_distance, @@ -109,6 +114,7 @@ "sinkhorn_unbalanced2", "sliced_wasserstein_distance", "sliced_wasserstein_sphere", + "unbalanced_sliced_ot", "gromov_wasserstein", "gromov_wasserstein2", "gromov_barycenters", diff --git a/ot/unbalanced/__init__.py b/ot/unbalanced/__init__.py index 06423008d..b7a526182 100644 --- a/ot/unbalanced/__init__.py +++ b/ot/unbalanced/__init__.py @@ -26,6 +26,8 @@ from ._solver_1d import uot_1d +from ._sliced import sliced_unbalanced_ot, unbalanced_sliced_ot + __all__ = [ "sinkhorn_knopp_unbalanced", "sinkhorn_unbalanced", @@ -41,4 +43,6 @@ "lbfgsb_unbalanced", "lbfgsb_unbalanced2", "uot_1d", + "sliced_unbalanced_ot", + "unbalanced_sliced_ot", ] diff --git a/ot/unbalanced/_sliced.py b/ot/unbalanced/_sliced.py index b3d2f6343..c26dcd4a3 100644 --- a/ot/unbalanced/_sliced.py +++ b/ot/unbalanced/_sliced.py @@ -291,6 +291,8 @@ def unbalanced_sliced_ot( a_reweighted = (a * nx.exp(-f / reg_m1))[..., X_s_sorter] b_reweighted = (b * nx.exp(-g / reg_m2))[..., X_t_sorter] + full_mass = nx.sum(a_reweighted, axis=1) + # normalize the weights for compatibility with wasserstein_1d a_reweighted = a_reweighted / nx.sum(a_reweighted, axis=1, keepdims=True) b_reweighted = b_reweighted / nx.sum(b_reweighted, axis=1, keepdims=True) @@ -324,16 +326,6 @@ def unbalanced_sliced_ot( f = f + t * (nx.mean(nx.take_along_axis(fd, X_s_rev_sorter, 1), axis=0) - f) g = g + t * (nx.mean(nx.take_along_axis(gd, 
X_t_rev_sorter, 1), axis=0) - g) - # Last iter before output - transl = rescale_potentials(f, g, a, b, reg_m1, reg_m2, nx) - f, g = f + transl, g - transl - - a_reweighted = (a * nx.exp(-f / reg_m1))[..., X_s_sorter] - b_reweighted = (b * nx.exp(-g / reg_m2))[..., X_t_sorter] - - a_reweighted = a_reweighted / nx.sum(a_reweighted, axis=1, keepdims=True) - b_reweighted = b_reweighted / nx.sum(b_reweighted, axis=1, keepdims=True) - ot_loss = wasserstein_1d( X_s_sorted, X_t_sorted, @@ -342,9 +334,10 @@ def unbalanced_sliced_ot( p=p, require_sort=False, ) - sot_loss = nx.mean(ot_loss * nx.sum(a_reweighted, axis=1)) + sot_loss = nx.mean(ot_loss * full_mass) a_reweighted, b_reweighted = a * nx.exp(-f / reg_m1), b * nx.exp(-g / reg_m2) + uot_loss = ( sot_loss + reg_m1 * nx.kl_div(a_reweighted, a) diff --git a/ot/unbalanced/_solver_1d.py b/ot/unbalanced/_solver_1d.py index 5cd85461f..5d721d750 100644 --- a/ot/unbalanced/_solver_1d.py +++ b/ot/unbalanced/_solver_1d.py @@ -141,6 +141,8 @@ def uot_1d( u_reweighted = u_weights_sorted * nx.exp(-f / reg_m1) v_reweighted = v_weights_sorted * nx.exp(-g / reg_m2) + full_mass = nx.sum(u_reweighted, axis=0) + # Normalize weights u_reweighted = u_reweighted / nx.sum(u_reweighted, axis=0, keepdims=True) v_reweighted = v_reweighted / nx.sum(v_reweighted, axis=0, keepdims=True) @@ -175,7 +177,7 @@ def uot_1d( v_reweighted = nx.take_along_axis(v_reweighted, v_rev_sorter, 0) # rescale OT loss - loss = loss * nx.sum(u_reweighted, axis=0) + loss = loss * full_mass uot_loss = ( loss From c361c32289d3e377866be23387149cd4b7709dbb Mon Sep 17 00:00:00 2001 From: Clement Date: Sun, 10 Aug 2025 17:54:15 +0200 Subject: [PATCH 12/19] tests backend --- ot/backend.py | 14 ++++++++------ test/test_backend.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index efd129838..4448703df 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1083,17 +1083,19 @@ def slogdet(self, a): def 
index_select(self, input, axis, index): r""" - TODO + Returns a new tensor which indexes the input tensor along dimension dim using the entries in index. See: https://docs.pytorch.org/docs/stable/generated/torch.index_select.html """ + raise NotImplementedError() def nonzero(self, input, as_tuple=False): r""" - TODO + Returns a tensor containing the indices of all non-zero elements of input. See: https://docs.pytorch.org/docs/stable/generated/torch.nonzero.html """ + raise NotImplementedError() class NumpyBackend(Backend): @@ -1464,9 +1466,9 @@ def index_select(self, input, axis, index): def nonzero(self, input, as_tuple=False): if as_tuple: return np.nonzero(input) - else: # TOCHECK + else: L_tuple = np.nonzero(input) - return np.concatenate([t[None] for t in L_tuple], axis=0) + return np.concatenate([t[None] for t in L_tuple], axis=0).T _register_backend_implementation(NumpyBackend) @@ -1870,9 +1872,9 @@ def index_select(self, input, axis, index): def nonzero(self, input, as_tuple=False): if as_tuple: return jnp.nonzero(input) - else: # TOCHECK + else: L_tuple = jnp.nonzero(input) - return jnp.concatenate([t[None] for t in L_tuple], axis=0) + return jnp.concatenate([t[None] for t in L_tuple], axis=0).T if jax: diff --git a/test/test_backend.py b/test/test_backend.py index ff5685f6a..c110c93fa 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -97,6 +97,7 @@ def test_empty_backend(): rnd = np.random.RandomState(0) M = rnd.randn(10, 3) v = rnd.randn(3) + inds = rnd.randint(10) nx = ot.backend.Backend() @@ -273,6 +274,10 @@ def test_empty_backend(): nx.det(M) with pytest.raises(NotImplementedError): nx.slogdet(M) + with pytest.raises(NotImplementedError): + nx.index_select(M, 0, inds) + with pytest.raises(NotImplementedError): + nx.nonzero(M) def test_func_backends(nx): @@ -702,6 +707,14 @@ def test_func_backends(nx): lst_b.append(np.array([s, logabsd])) lst_name.append("slogdet") + vec = nx.index_select(vb, 0, nx.from_numpy(np.array([0, 1]))) + 
lst_b.append(nx.to_numpy(vec)) + lst_name.append("index_select") + + vec = nx.nonzero(Mb) + lst_b.append(nx.to_numpy(vec)) + lst_name.append("nonzero") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( From c08655cae058f45596b0d300b4341c679dffb3b0 Mon Sep 17 00:00:00 2001 From: Clement Date: Fri, 22 Aug 2025 23:01:25 +0200 Subject: [PATCH 13/19] up code example 1D UOT --- examples/unbalanced-partial/plot_UOT_1D.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 752e7b79f..2a6aedfe4 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -90,13 +90,15 @@ pl.title("Distributions and transported mass for UOT") -# %% ############################################################################## -# Solve Unbalanced UOT with Frank-Wolfe -# ------------------------- +# Solve 1D UOT with Frank-Wolfe (TODO: check the behaviour) +# ----------------------------- alpha = 1000.0 # Unbalanced KL relaxation parameter -a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d(x, x, a, b, alpha) + +a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d( + x, x, alpha, u_weights=a, v_weights=b +) # plot the transported mass From 26473e13947019970d84003d451c4791edd7be88 Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 27 Aug 2025 20:21:12 +0200 Subject: [PATCH 14/19] Examples UOT 1D --- examples/unbalanced-partial/plot_UOT_1D.py | 105 ++++++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 2a6aedfe4..126aa2c54 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -90,11 +90,27 @@ pl.title("Distributions and transported mass for UOT") 
+############################################################################## +# Solve Unbalanced OT +# ------------------------- + +alpha = 1.0 # Unbalanced KL relaxation parameter +Gs = ot.unbalanced.mm_unbalanced(a, b, M, alpha, verbose=False) + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") +pl.fill(x, Gs.sum(1), "b", alpha=0.5, label="Transported source") +pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target") +pl.legend(loc="upper right") +pl.title("Distributions and transported mass for UOT") + + ############################################################################## # Solve 1D UOT with Frank-Wolfe (TODO: check the behaviour) # ----------------------------- -alpha = 1000.0 # Unbalanced KL relaxation parameter +alpha = 10000.0 # Unbalanced KL relaxation parameter a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d( x, x, alpha, u_weights=a, v_weights=b @@ -111,3 +127,90 @@ pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target") pl.legend(loc="upper right") pl.title("Distributions and transported mass for UOT") + + +############################################################################## +# Solve 1D UOT with Frank-Wolfe (TODO: check the behaviour) +# ----------------------------- +import torch + +alpha = 10000.0 # Unbalanced KL relaxation parameter + +a_reweighted, b_reweighted, loss = ot.unbalanced.unbalanced_sliced_ot( + torch.tensor(x.reshape((n, 1)), dtype=torch.float64), + torch.tensor(x.reshape((n, 1)), dtype=torch.float64), + alpha, + torch.tensor(a, dtype=torch.float64), + torch.tensor(b, dtype=torch.float64), + mode="backprop", +) + + +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") +pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source") +pl.fill(x, b_reweighted, "r", 
alpha=0.5, label="Transported target") +pl.legend(loc="upper right") +pl.title("Distributions and transported mass for UOT") + + +############################################################################## +# Solve 1D UOT with Frank-Wolfe (TODO: check the behaviour) +# ----------------------------- +import torch + +alpha = 10000.0 # (10000, 10000) # Unbalanced KL relaxation parameter + +a_reweighted, b_reweighted, loss = ot.unbalanced.unbalanced_sliced_ot( + torch.tensor(x.reshape((n, 1)), dtype=torch.float64), + torch.tensor(x.reshape((n, 1)), dtype=torch.float64), + alpha, + torch.tensor(a, dtype=torch.float64), + torch.tensor(b, dtype=torch.float64), + mode="backprop", +) + + +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") +pl.fill(x, a_reweighted.detach().numpy(), "b", alpha=0.5, label="Transported source") +pl.fill(x, b_reweighted.detach().numpy(), "r", alpha=0.5, label="Transported target") +pl.legend(loc="upper right") +pl.title("Distributions and transported mass for UOT") + + +############################################################################## +# Solve 1D UOT with Frank-Wolfe (TODO: check the behaviour) +# ----------------------------- +import torch + +alpha = 10000.0 # (10000, 10000) # Unbalanced KL relaxation parameter + +a_reweighted, b_reweighted, loss = ot.unbalanced.unbalanced_sliced_ot( + torch.tensor(x.reshape((n, 1)), dtype=torch.float32), + torch.tensor(x.reshape((n, 1)), dtype=torch.float32), + alpha, + torch.tensor(a, dtype=torch.float32), + torch.tensor(b, dtype=torch.float32), + mode="backprop", +) + + +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") +pl.fill(x, a_reweighted.detach().numpy(), "b", alpha=0.5, label="Transported source") +pl.fill(x, 
b_reweighted.detach().numpy(), "r", alpha=0.5, label="Transported target") +pl.legend(loc="upper right") +pl.title("Distributions and transported mass for UOT") From 0ca65a6ae1c114f4a5c9d46acbd18e3f675572a3 Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 28 Aug 2025 17:20:28 +0200 Subject: [PATCH 15/19] fix output loss uot_1d --- examples/unbalanced-partial/plot_UOT_1D.py | 100 +++++--------------- ot/lp/solver_1d.py | 4 +- ot/unbalanced/_sliced.py | 4 +- ot/unbalanced/_solver_1d.py | 103 +++++++++++++++------ 4 files changed, 100 insertions(+), 111 deletions(-) diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 126aa2c54..b2ba4f230 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -19,6 +19,7 @@ import ot import ot.plot from ot.datasets import make_1D_gauss as gauss +import torch ############################################################################## # Generate data @@ -41,7 +42,6 @@ # loss matrix M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) -M /= M.max() ############################################################################## @@ -69,18 +69,12 @@ epsilon = 0.1 # entropy parameter alpha = 1.0 # Unbalanced KL relaxation parameter -Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M / M.max(), epsilon, alpha, verbose=True) pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, "UOT matrix Sinkhorn") - pl.show() - -# %% -# plot the transported mass -# ------------------------- - pl.figure(4, figsize=(6.4, 3)) pl.plot(x, a, "b", label="Source distribution") pl.plot(x, b, "r", label="Target distribution") @@ -88,14 +82,18 @@ pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target") pl.legend(loc="upper right") pl.title("Distributions and transported mass for UOT") +pl.show() + +print("Mass of reweighted marginals:", Gs.sum()) 
############################################################################## -# Solve Unbalanced OT -# ------------------------- +# Solve Unbalanced OT in closed form +# ----------------------------------- alpha = 1.0 # Unbalanced KL relaxation parameter -Gs = ot.unbalanced.mm_unbalanced(a, b, M, alpha, verbose=False) + +Gs = ot.unbalanced.mm_unbalanced(a, b, M / M.max(), alpha, verbose=False) pl.figure(4, figsize=(6.4, 3)) pl.plot(x, a, "b", label="Source distribution") @@ -104,22 +102,21 @@ pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target") pl.legend(loc="upper right") pl.title("Distributions and transported mass for UOT") +pl.show() + +print("Mass of reweighted marginals:", Gs.sum()) ############################################################################## -# Solve 1D UOT with Frank-Wolfe (TODO: check the behaviour) +# Solve 1D UOT with Frank-Wolfe # ----------------------------- -alpha = 10000.0 # Unbalanced KL relaxation parameter +alpha = M.max() # Unbalanced KL relaxation parameter a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d( x, x, alpha, u_weights=a, v_weights=b ) - -# plot the transported mass -# ------------------------- - pl.figure(4, figsize=(6.4, 3)) pl.plot(x, a, "b", label="Source distribution") pl.plot(x, b, "r", label="Target distribution") @@ -127,43 +124,16 @@ pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target") pl.legend(loc="upper right") pl.title("Distributions and transported mass for UOT") +pl.show() - -############################################################################## -# Solve 1D UOT with Frank-Wolfe (TODO: check the behaviour) -# ----------------------------- -import torch - -alpha = 10000.0 # Unbalanced KL relaxation parameter - -a_reweighted, b_reweighted, loss = ot.unbalanced.unbalanced_sliced_ot( - torch.tensor(x.reshape((n, 1)), dtype=torch.float64), - torch.tensor(x.reshape((n, 1)), dtype=torch.float64), - alpha, - torch.tensor(a, dtype=torch.float64), - torch.tensor(b, 
dtype=torch.float64), - mode="backprop", -) - - -# plot the transported mass -# ------------------------- - -pl.figure(4, figsize=(6.4, 3)) -pl.plot(x, a, "b", label="Source distribution") -pl.plot(x, b, "r", label="Target distribution") -pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source") -pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target") -pl.legend(loc="upper right") -pl.title("Distributions and transported mass for UOT") +print("Mass of reweighted marginals:", a_reweighted.sum()) ############################################################################## -# Solve 1D UOT with Frank-Wolfe (TODO: check the behaviour) +# Solve 1D UOT with Frank-Wolfe # ----------------------------- -import torch -alpha = 10000.0 # (10000, 10000) # Unbalanced KL relaxation parameter +alpha = M.max() # Unbalanced KL relaxation parameter a_reweighted, b_reweighted, loss = ot.unbalanced.unbalanced_sliced_ot( torch.tensor(x.reshape((n, 1)), dtype=torch.float64), @@ -181,36 +151,10 @@ pl.figure(4, figsize=(6.4, 3)) pl.plot(x, a, "b", label="Source distribution") pl.plot(x, b, "r", label="Target distribution") -pl.fill(x, a_reweighted.detach().numpy(), "b", alpha=0.5, label="Transported source") -pl.fill(x, b_reweighted.detach().numpy(), "r", alpha=0.5, label="Transported target") +pl.fill(x, a_reweighted.numpy(), "b", alpha=0.5, label="Transported source") +pl.fill(x, b_reweighted.numpy(), "r", alpha=0.5, label="Transported target") pl.legend(loc="upper right") pl.title("Distributions and transported mass for UOT") +pl.show() - -############################################################################## -# Solve 1D UOT with Frank-Wolfe (TODO: check the behaviour) -# ----------------------------- -import torch - -alpha = 10000.0 # (10000, 10000) # Unbalanced KL relaxation parameter - -a_reweighted, b_reweighted, loss = ot.unbalanced.unbalanced_sliced_ot( - torch.tensor(x.reshape((n, 1)), dtype=torch.float32), - torch.tensor(x.reshape((n, 1)), 
dtype=torch.float32), - alpha, - torch.tensor(a, dtype=torch.float32), - torch.tensor(b, dtype=torch.float32), - mode="backprop", -) - - -# plot the transported mass -# ------------------------- - -pl.figure(4, figsize=(6.4, 3)) -pl.plot(x, a, "b", label="Source distribution") -pl.plot(x, b, "r", label="Target distribution") -pl.fill(x, a_reweighted.detach().numpy(), "b", alpha=0.5, label="Transported source") -pl.fill(x, b_reweighted.detach().numpy(), "r", alpha=0.5, label="Transported target") -pl.legend(loc="upper right") -pl.title("Distributions and transported mass for UOT") +print("Mass of reweighted marginals:", a_reweighted.sum()) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index f8a64ec58..f27bd7dc0 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -617,8 +617,8 @@ def emd_1d_dual_backprop( loss.backward() return ( - u_weights.grad, - v_weights.grad, + u_weights.grad.detach(), + v_weights.grad.detach(), cost_output.detach(), ) # value can not be backward anymore elif nx.__name__ == "jax": diff --git a/ot/unbalanced/_sliced.py b/ot/unbalanced/_sliced.py index 247059ff5..700c69727 100644 --- a/ot/unbalanced/_sliced.py +++ b/ot/unbalanced/_sliced.py @@ -340,8 +340,8 @@ def unbalanced_sliced_ot( uot_loss = ( sot_loss - + reg_m1 * nx.kl_div(a_reweighted, a) - + reg_m2 * nx.kl_div(b_reweighted, b) + + reg_m1 * nx.kl_div(a_reweighted, a, mass=True) + + reg_m2 * nx.kl_div(b_reweighted, b, mass=True) ) if log: diff --git a/ot/unbalanced/_solver_1d.py b/ot/unbalanced/_solver_1d.py index 5d721d750..f705e7b46 100644 --- a/ot/unbalanced/_solver_1d.py +++ b/ot/unbalanced/_solver_1d.py @@ -14,11 +14,41 @@ def rescale_potentials(f, g, a, b, rho1, rho2, nx): r""" - TODO + Find the optimal :math:`\lambda` in the translation invariant dual of UOT + with KL regularization and return it, see Proposition 2 in :ref:`[73] <references-uot>`. + + Parameters + ---------- + f: array-like, shape (n, ...) + first dual potential + g: array-like, shape (m, ...)
+ second dual potential + a: array-like, shape (n, ...) + weights of the first empirical distribution + b: array-like, shape (m, ...) + weights of the second empirical distribution + rho1: float + Marginal relaxation term for the first marginal + rho2: float + Marginal relaxation term for the second marginal + nx: module + backend module + + Returns + ------- + transl: array-like, shape (...) + optimal translation + + .. _references-uot: + References + ---------- + .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. """ tau = (rho1 * rho2) / (rho1 + rho2) - num = nx.logsumexp(-f / rho1 + nx.log(a)) - denom = nx.logsumexp(-g / rho2 + nx.log(b)) + num = nx.logsumexp(-f / rho1 + nx.log(a), axis=0) + denom = nx.logsumexp(-g / rho2 + nx.log(b), axis=0) transl = tau * (num - denom) return transl @@ -32,18 +62,18 @@ def uot_1d( p=1, require_sort=True, numItermax=10, - stopThr=1e-6, mode="icdf", + returnCost="linear", log=False, ): r""" - TODO, TOTEST, seems not very stable? - Solves the 1D unbalanced OT problem with KL regularization. The function implements the Frank-Wolfe algorithm to solve the dual problem, - as proposed in [73]. + as proposed in :ref:`[73] <references-uot>`. - TODO: add math equation + The unbalanced OT problem reads + .. math:: + \mathrm{UOT}(\mu,\nu) = \min_{\gamma \in \mathcal{M}_{+}(\mathbb{R}\times\mathbb{R})} W_2^2(\pi^1_\#\gamma,\pi^2_\#\gamma) + \mathrm{reg_{m}}_1 \mathrm{KL}(\pi^1_\#\gamma|\mu) + \mathrm{reg_{m}}_2 \mathrm{KL}(\pi^2_\#\gamma|\nu). Parameters ---------- @@ -55,12 +85,12 @@ def uot_1d( Marginal relaxation term. If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. - The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. + (TODO?)
The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. For semi-relaxed case, use either :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. If :math:`\mathrm{reg_{m}}` is an array, - it must have the same backend as input arrays `(a, b, M)`. + it must have the same backend as input arrays `(a, b)`. u_weights: array-like, shape (n, ...), optional weights of the first empirical distribution, if None then uniform weights are used v_weights: array-like, shape (m, ...), optional @@ -74,6 +104,9 @@ def uot_1d( mode: str, optional "icdf" for inverse CDF, "backprop" for backpropagation mode. Default is "icdf". + returnCost: string, optional (default = "linear") + If `returnCost` = "linear", then return the linear part of the unbalanced OT loss. + If `returnCost` = "total", then return the total unbalanced OT loss. log: bool, optional Returns @@ -83,8 +116,9 @@ def uot_1d( v_reweighted: array-like shape (m, ...) Second marginal reweighted loss: float/array-like, shape (...) - the batched 1D UOT + The batched 1D UOT + .. _references-uot: References --------- .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). 
@@ -128,15 +162,21 @@ def uot_1d( v_weights_sorted = nx.take_along_axis(v_weights, v_sorter, 0) f = nx.zeros(u_weights.shape, type_as=u_weights) + fd = nx.zeros(u_weights.shape, type_as=u_weights) g = nx.zeros(v_weights.shape, type_as=v_weights) + gd = nx.zeros(v_weights.shape, type_as=v_weights) for i in range(numItermax): + t = 2.0 / (2.0 + i - 1) + f = f + t * (fd - f) + g = g + t * (gd - g) + transl = rescale_potentials( f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx ) - f = f + transl - g = g - transl + f = f + transl[None] + g = g - transl[None] u_reweighted = u_weights_sorted * nx.exp(-f / reg_m1) v_reweighted = v_weights_sorted * nx.exp(-g / reg_m2) @@ -144,15 +184,15 @@ def uot_1d( full_mass = nx.sum(u_reweighted, axis=0) # Normalize weights - u_reweighted = u_reweighted / nx.sum(u_reweighted, axis=0, keepdims=True) - v_reweighted = v_reweighted / nx.sum(v_reweighted, axis=0, keepdims=True) + u_rescaled = u_reweighted / nx.sum(u_reweighted, axis=0, keepdims=True) + v_rescaled = v_reweighted / nx.sum(v_reweighted, axis=0, keepdims=True) if mode == "icdf": fd, gd, loss = emd_1d_dual( u_values_sorted, v_values_sorted, - u_weights=u_reweighted, - v_weights=v_reweighted, + u_weights=u_rescaled, + v_weights=v_rescaled, p=p, require_sort=False, ) @@ -160,15 +200,15 @@ def uot_1d( fd, gd, loss = emd_1d_dual_backprop( u_values_sorted, v_values_sorted, - u_weights=u_reweighted, - v_weights=v_reweighted, + u_weights=u_rescaled, + v_weights=v_rescaled, p=p, require_sort=False, ) - t = 2.0 / (2.0 + i) - f = f + t * (fd - f) - g = g + t * (gd - g) + # t = 2.0 / (2.0 + i) + # f = f + t * (fd - f) + # g = g + t * (gd - g) if require_sort: f = nx.take_along_axis(f, u_rev_sorter, 0) @@ -177,15 +217,20 @@ def uot_1d( v_reweighted = nx.take_along_axis(v_reweighted, v_rev_sorter, 0) # rescale OT loss - loss = loss * full_mass + linear_loss = loss * full_mass uot_loss = ( - loss - + reg_m1 * nx.kl_div(u_reweighted, u_weights) - + reg_m2 * nx.kl_div(v_reweighted, 
v_weights) + linear_loss + + reg_m1 * nx.kl_div(u_reweighted, u_weights, mass=True) + + reg_m2 * nx.kl_div(v_reweighted, v_weights, mass=True) ) + if returnCost == "linear": + out_loss = linear_loss + elif returnCost == "total": + out_loss = uot_loss + if log: - dico = {"f": f, "g": g} - return u_reweighted, v_reweighted, uot_loss, dico - return u_reweighted, v_reweighted, uot_loss + dico = {"f": f, "g": g, "total_cost": uot_loss, "linear_cost": linear_loss} + return u_reweighted, v_reweighted, out_loss, dico + return u_reweighted, v_reweighted, out_loss From c6301b8bd9037b2ce4cc01bff6cd531c0b26280a Mon Sep 17 00:00:00 2001 From: Clement Date: Sat, 13 Sep 2025 20:50:35 +0200 Subject: [PATCH 16/19] Example USOT vs SUOT --- README.md | 3 +- .../unbalanced-partial/plot_UOT_sliced.py | 153 +++++++++--------- ot/lp/solver_1d.py | 1 + ot/sliced.py | 1 + ot/unbalanced/_sliced.py | 12 +- ot/unbalanced/_solver_1d.py | 2 +- 6 files changed, 90 insertions(+), 82 deletions(-) diff --git a/README.md b/README.md index 3cc7d55d9..a42bda27e 100644 --- a/README.md +++ b/README.md @@ -54,9 +54,10 @@ POT provides the following generic OT solvers: Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) * [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. -* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. 
Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] +* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation [73] and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. 
+* [Sliced Unbalanced OT and Unbalanced Sliced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT.html) [80] * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_compute_wasserstein_circle.html) [44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] diff --git a/examples/unbalanced-partial/plot_UOT_sliced.py b/examples/unbalanced-partial/plot_UOT_sliced.py index a7b0ab1ee..0d9c9233c 100644 --- a/examples/unbalanced-partial/plot_UOT_sliced.py +++ b/examples/unbalanced-partial/plot_UOT_sliced.py @@ -7,11 +7,11 @@ This example illustrates the behavior of Sliced UOT versus Unbalanced Sliced OT. -The first one removes outliers on each sliced while the second one +The first one removes outliers on each slice while the second one removes outliers of the original marginals. """ -# Author: +# Author: Clément Bonet # # License: MIT License @@ -94,9 +94,8 @@ def make_spiral(n_samples, noise=0.5): dir = np.array([(np.cos(theta), np.sin(theta)) for theta in thetas]) dir_torch = torch.from_numpy(dir).type(torch.float) - -Xps = torch.dot(Xs_torch, dir_torch.T) # shape (n, n_projs) -Xpt = torch.dot(Xt_torch, dir_torch.T) +Xps = (Xs_torch @ dir_torch.T).T # shape (n_projs, n) +Xpt = (Xt_torch @ dir_torch.T).T ############################################################################## # Compute SUOT and USOT @@ -139,8 +138,8 @@ def make_spiral(n_samples, noise=0.5): ############################################################################## -# Plot reweighted distributions on several slices -# ------------- +# Utils plot +# ---------- # %% @@ -157,8 +156,62 @@ def kde_sklearn(x, x_grid, weights=None, bandwidth=0.2, **kwargs): return np.exp(log_pdf) -c1 = np.array(mpl.colors.to_rgb("lightcoral")) -c2 = np.array(mpl.colors.to_rgb("steelblue")) +def plot_slices( + col, nb_slices, x_grid, Xps, Xpt, Xps_weights, Xpt_weights, method, 
rho1, rho2 +): + for i in range(nb_slices): + ax = plt.subplot2grid((nb_slices, 3), (i, col)) + if len(Xps_weights.shape) > 1: # SUOT + weights_src = Xps_weights[i * offset_degree, :].cpu().numpy() + weights_tgt = Xpt_weights[i * offset_degree, :].cpu().numpy() + else: # USOT + weights_src = Xps_weights.cpu().numpy() + weights_tgt = Xpt_weights.cpu().numpy() + + samples_src = Xps[i * offset_degree, :].cpu().numpy() + samples_tgt = Xpt[i * offset_degree, :].cpu().numpy() + + pdf_source = kde_sklearn(samples_src, x_grid, weights=weights_src, bandwidth=bw) + pdf_target = kde_sklearn(samples_tgt, x_grid, weights=weights_tgt, bandwidth=bw) + pdf_source_without_w = kde_sklearn(samples_src, x_grid, bandwidth=bw) + pdf_target_without_w = kde_sklearn(samples_tgt, x_grid, bandwidth=bw) + + ax.plot(x_grid, pdf_source, color=c2, alpha=0.8, lw=2) + ax.fill(x_grid, pdf_source_without_w, ec="grey", fc="grey", alpha=0.3) + ax.fill(x_grid, pdf_source, ec=c2, fc=c2, alpha=0.3) + + ax.plot(x_grid, pdf_target, color=c1, alpha=0.8, lw=2) + ax.fill(x_grid, pdf_target_without_w, ec="grey", fc="grey", alpha=0.3) + ax.fill(x_grid, pdf_target, ec=c2, fc=c1, alpha=0.3) + + ax.set_xlim(xlim_min, xlim_max) + + if col == 1: + ax.set_ylabel( + r"$\theta=${}$^o$".format(i * offset_degree), + color=colors[i], + fontsize=13, + ) + + ax.set_yticks([]) + ax.set_xticks([]) + + ax.set_xlabel( + r"{} $\rho_1={}$ $\rho_2={}$".format(method, rho1, rho2), fontsize=13 + ) + + +############################################################################## +# Plot reweighted distributions on several slices +# ------------- +# We plot the reweighted distributions on several slices. We see that for SUOT, +# the mode of outliers is kept on some slices (e.g. for :math:`\theta=120°`) while USOT +# is able to get rid of the outlier mode.
+ +# %% + +c1 = np.array(mpl.colors.to_rgb("red")) +c2 = np.array(mpl.colors.to_rgb("blue")) # define plotting grid xlim_min = -3 @@ -167,7 +220,7 @@ def kde_sklearn(x, x_grid, weights=None, bandwidth=0.2, **kwargs): bw = 0.05 # visu parameters -nb_slices = 6 +nb_slices = 3 # 4 offset_degree = int(180 / nb_slices) delta_degree = np.pi / nb_slices @@ -176,9 +229,10 @@ def kde_sklearn(x, x_grid, weights=None, bandwidth=0.2, **kwargs): X1 = np.array([-4, 0]) X2 = np.array([4, 0]) -fig = plt.figure(figsize=(28, 8)) -ax1 = plt.subplot2grid((nb_slices, 3), (0, 0), rowspan=nb_slices) +fig = plt.figure(figsize=(9, 3)) + +ax1 = plt.subplot2grid((nb_slices, 3), (0, 0), rowspan=nb_slices) for i in range(nb_slices): R = get_rot(delta_degree * (-i)) @@ -197,82 +251,25 @@ def kde_sklearn(x, x_grid, weights=None, bandwidth=0.2, **kwargs): ax1.plot( [X1_r[0], X2_r[0]], [X1_r[1], X2_r[1]], color=colors[i], alpha=0.8, zorder=0 ) + ax1.scatter(Xs[:, 0], Xs[:, 1], zorder=1, color=c2, label="Source data") ax1.scatter(Xt[:, 0], Xt[:, 1], zorder=1, color=c1, label="Target data") ax1.set_xlim([-3, 3]) ax1.set_ylim([-3, 3]) ax1.set_yticks([]) ax1.set_xticks([]) -ax1.legend(loc="best", fontsize=18) -ax1.set_xlabel("Original distributions", fontsize=22) +# ax1.legend(loc='best',fontsize=13) +ax1.set_xlabel("Original distributions", fontsize=13) + -# ***** plot SUOT fig.subplots_adjust(hspace=0) -fig.subplots_adjust(wspace=0.1) +fig.subplots_adjust(wspace=0.15) -for i in range(nb_slices): - ax = plt.subplot2grid((nb_slices, 3), (i, 1)) - weights_src = A_SUOT[i * offset_degree, :].cpu().numpy() - weights_tgt = B_SUOT[i * offset_degree, :].cpu().numpy() - samples_src = Xps[i * offset_degree, :].cpu().numpy() - samples_tgt = Xpt[i * offset_degree, :].cpu().numpy() - pdf_source = kde_sklearn(samples_src, x_grid, weights=weights_src, bandwidth=bw) - pdf_target = kde_sklearn(samples_tgt, x_grid, weights=weights_tgt, bandwidth=bw) - pdf_source_without_w = kde_sklearn(samples_src, x_grid, 
bandwidth=bw) - pdf_target_without_w = kde_sklearn(samples_tgt, x_grid, bandwidth=bw) - - ax.scatter(samples_src, [-0.2] * samples_src.shape[0], color=c2, s=2) - ax.plot(x_grid, pdf_source, color=c2, alpha=0.8, lw=2) - ax.fill(x_grid, pdf_source_without_w, ec="grey", fc="grey", alpha=0.3) - ax.fill(x_grid, pdf_source, ec=c2, fc=c2, alpha=0.3) - - ax.scatter(samples_tgt, [-0.2] * samples_tgt.shape[0], color=c1, s=2) - ax.plot(x_grid, pdf_target, color=c1, alpha=0.8, lw=2) - ax.fill(x_grid, pdf_target_without_w, ec="grey", fc="grey", alpha=0.3) - ax.fill(x_grid, pdf_target, ec=c2, fc=c1, alpha=0.3) - - # frac_mass = int(100*weights_src.sum()) - # plt.text(.9, .9, '% mass={}%'.format(frac_mass), ha='right', va='top', color='red',fontsize=14, transform=ax.transAxes) - - ax.set_xlim(xlim_min, xlim_max) - ax.set_ylabel( - r"$\theta=${}$^o$".format(i * offset_degree), color=colors[i], fontsize=16 - ) - ax.set_yticks([]) - ax.set_yticks([]) -ax.set_xlabel( - r"SUOT $\rho_1={}$ $\rho_2={}$".format(rho1_SUOT, rho2_SUOT), fontsize=22 +plot_slices( + 1, nb_slices, x_grid, Xps, Xpt, A_SUOT, B_SUOT, "SUOT", rho1_SUOT, rho2_SUOT ) -# ***** plot USOT - -for i in range(nb_slices): - ax = plt.subplot2grid((nb_slices, 3), (i, 2)) - weights_src = A_USOT.cpu().numpy() - weights_tgt = B_USOT.cpu().numpy() - samples_src = Xps[i * offset_degree, :].cpu().numpy() - samples_tgt = Xpt[i * offset_degree, :].cpu().numpy() - pdf_source = kde_sklearn(samples_src, x_grid, weights=weights_src, bandwidth=bw) - pdf_target = kde_sklearn(samples_tgt, x_grid, weights=weights_tgt, bandwidth=bw) - pdf_source_without_w = kde_sklearn(samples_src, x_grid, bandwidth=bw) - pdf_target_without_w = kde_sklearn(samples_tgt, x_grid, bandwidth=bw) - - ax.scatter(samples_src, [-0.2] * samples_src.shape[0], color=c2, s=2) - ax.plot(x_grid, pdf_source, color=c2, alpha=0.8, lw=2) - ax.fill(x_grid, pdf_source_without_w, ec="grey", fc="grey", alpha=0.3) - ax.fill(x_grid, pdf_source, ec=c2, fc=c2, alpha=0.3) - - 
ax.scatter(samples_tgt, [-0.2] * samples_tgt.shape[0], color=c1, s=2) - ax.plot(x_grid, pdf_target, color=c1, alpha=0.8, lw=2) - ax.fill(x_grid, pdf_target_without_w, ec="grey", fc="grey", alpha=0.3) - ax.fill(x_grid, pdf_target, ec=c2, fc=c1, alpha=0.3) - - ax.set_xlim(xlim_min, xlim_max) - ax.set_ylabel( - r"$\theta=${}$^o$".format(i * offset_degree), color=colors[i], fontsize=16 - ) - ax.set_yticks([]) -ax.set_xlabel( - r"USOT $\rho_1={}$ $\rho_2={}$".format(rho1_USOT, rho2_USOT), fontsize=22 +plot_slices( + 2, nb_slices, x_grid, Xps, Xpt, A_USOT, B_USOT, "USOT", rho1_USOT, rho2_USOT ) plt.show() diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index f27bd7dc0..47f2aeb09 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -5,6 +5,7 @@ # Author: Remi Flamary # Author: Nicolas Courty +# Author: Clément Bonet # # License: MIT License diff --git a/ot/sliced.py b/ot/sliced.py index 3cf2002e7..29c499b2e 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -6,6 +6,7 @@ # Author: Adrien Corenflos # Nicolas Courty # Rémi Flamary +# Clément Bonet # # License: MIT License diff --git a/ot/unbalanced/_sliced.py b/ot/unbalanced/_sliced.py index 700c69727..938b0fe89 100644 --- a/ot/unbalanced/_sliced.py +++ b/ot/unbalanced/_sliced.py @@ -3,7 +3,7 @@ Sliced Unbalanced OT solvers """ -# Author: +# Author: Clément Bonet # # License: MIT License @@ -119,7 +119,15 @@ def sliced_unbalanced_ot( X_t_projections = nx.dot(X_t, projections) a_reweighted, b_reweighted, projected_uot = uot_1d( - X_s_projections, X_t_projections, reg_m, a, b, p, require_sort=True, mode=mode + X_s_projections, + X_t_projections, + reg_m, + a, + b, + p, + require_sort=True, + mode=mode, + numItermax=numItermax, ) res = nx.mean(projected_uot) ** (1.0 / p) diff --git a/ot/unbalanced/_solver_1d.py b/ot/unbalanced/_solver_1d.py index f705e7b46..5d174b7f0 100644 --- a/ot/unbalanced/_solver_1d.py +++ b/ot/unbalanced/_solver_1d.py @@ -3,7 +3,7 @@ 1D Unbalanced OT solvers """ -# Author: +# Author: Clément 
Bonet # # License: MIT License From 504c07afb902a1a2085613e7ed3dab4254c95b42 Mon Sep 17 00:00:00 2001 From: Clement Date: Sun, 14 Sep 2025 18:57:41 +0200 Subject: [PATCH 17/19] Center dual potentials --- ignore-words.txt | 3 ++- ot/lp/_network_simplex.py | 33 ++++++++++++++++++++++++--------- ot/lp/solver_1d.py | 15 ++++++++++----- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/ignore-words.txt b/ignore-words.txt index 00c1f5edb..573400137 100644 --- a/ignore-words.txt +++ b/ignore-words.txt @@ -6,4 +6,5 @@ wass ccompiler ist lik -ges \ No newline at end of file +ges +mapp diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index 492e4c7ac..cf7025301 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -44,31 +44,46 @@ def center_ot_dual(alpha0, beta0, a=None, b=None): Parameters ---------- - alpha0 : (ns,) numpy.ndarray, float64 + alpha0 : (ns, ...) numpy.ndarray, float64 Source dual potential - beta0 : (nt,) numpy.ndarray, float64 + beta0 : (nt, ...) numpy.ndarray, float64 Target dual potential - a : (ns,) numpy.ndarray, float64 + a : (ns, ...) numpy.ndarray, float64 Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 + b : (nt, ...) numpy.ndarray, float64 Target histogram (uniform weight if empty list) Returns ------- - alpha : (ns,) numpy.ndarray, float64 + alpha : (ns, ...) numpy.ndarray, float64 Source centered dual potential - beta : (nt,) numpy.ndarray, float64 + beta : (nt, ...) 
numpy.ndarray, float64 Target centered dual potential """ + if a is not None and b is not None: + nx = get_backend(alpha0, beta0, a, b) + else: + nx = get_backend(alpha0, beta0) + + n = alpha0.shape[0] + m = beta0.shape[0] + # if no weights are provided, use uniform if a is None: - a = np.ones(alpha0.shape[0]) / alpha0.shape[0] + a = nx.full(alpha0.shape, 1.0 / n, type_as=alpha0) + elif a.ndim != alpha0.ndim: + a = nx.repeat(a[..., None], alpha0.shape[-1], -1) + if b is None: - b = np.ones(beta0.shape[0]) / beta0.shape[0] + b = nx.full(beta0.shape, 1.0 / m, type_as=beta0) + elif b.ndim != beta0.ndim: + b = nx.repeat(b[..., None], beta0.shape[-1], -1) # compute constant that balances the weighted sums of the duals - c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) + ips = nx.sum(b * beta0, axis=0) - nx.sum(a * alpha0, axis=0) + denom = nx.sum(a, axis=0) + nx.sum(b, axis=0) + c = ips / denom # update duals alpha = alpha0 + c diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 47f2aeb09..609008f45 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -15,6 +15,7 @@ from .emd_wrap import emd_1d_sorted from ..backend import get_backend from ..utils import list_to_array +from ._network_simplex import center_ot_dual def quantile_function(qs, cws, xs, return_index=False): @@ -541,6 +542,8 @@ def emd_1d_dual( v_rev_sorter = nx.argsort(v_sorter, 0) g = nx.take_along_axis(g, v_rev_sorter, 0) + f, g = center_ot_dual(f, g, u_weights, v_weights) + return f, g, loss @@ -617,11 +620,11 @@ def emd_1d_dual_backprop( loss = cost_output.sum() loss.backward() - return ( - u_weights.grad.detach(), - v_weights.grad.detach(), - cost_output.detach(), - ) # value can not be backward anymore + f, g = center_ot_dual( + u_weights.grad.detach(), v_weights.grad.detach(), u_weights, v_weights + ) + + return f, g, cost_output.detach() # value can not be backward anymore elif nx.__name__ == "jax": import jax @@ -634,6 +637,8 @@ def ot_1d(a, b): cost_output = wasserstein_1d( 
u_values, v_values, u_weights, v_weights, p=p, require_sort=require_sort ) + + f, g = center_ot_dual(f, g, u_weights, v_weights) return f, g, cost_output From 812b4da159e134fec9cbbaa5997ac92f9334ded8 Mon Sep 17 00:00:00 2001 From: Clement Date: Sun, 14 Sep 2025 23:09:04 +0200 Subject: [PATCH 18/19] up tests --- .../unbalanced-partial/plot_UOT_sliced.py | 4 ++-- ot/lp/solver_1d.py | 20 +++++++++++++---- test/unbalanced/test_1d_solver.py | 22 +++++++++++++------ test/unbalanced/test_sliced.py | 2 +- 4 files changed, 34 insertions(+), 14 deletions(-) diff --git a/examples/unbalanced-partial/plot_UOT_sliced.py b/examples/unbalanced-partial/plot_UOT_sliced.py index 0d9c9233c..d5937a71d 100644 --- a/examples/unbalanced-partial/plot_UOT_sliced.py +++ b/examples/unbalanced-partial/plot_UOT_sliced.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -=============================== +=================================== Sliced Unbalanced optimal transport -=============================== +=================================== This example illustrates the behavior of Sliced UOT versus Unbalanced Sliced OT. 
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 609008f45..155d834b0 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -612,16 +612,28 @@ def emd_1d_dual_backprop( v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) if nx.__name__ == "torch": - u_weights.requires_grad_(True) - v_weights.requires_grad_(True) + u_weights_diff = nx.copy(u_weights) + v_weights_diff = nx.copy(v_weights) + + u_weights_diff.requires_grad_(True) + v_weights_diff.requires_grad_(True) + cost_output = wasserstein_1d( - u_values, v_values, u_weights, v_weights, p=p, require_sort=require_sort + u_values, + v_values, + u_weights_diff, + v_weights_diff, + p=p, + require_sort=require_sort, ) loss = cost_output.sum() loss.backward() f, g = center_ot_dual( - u_weights.grad.detach(), v_weights.grad.detach(), u_weights, v_weights + u_weights_diff.grad.detach(), + v_weights_diff.grad.detach(), + u_weights, + v_weights, ) return f, g, cost_output.detach() # value can not be backward anymore diff --git a/test/unbalanced/test_1d_solver.py b/test/unbalanced/test_1d_solver.py index 622f194c1..3c885bce9 100644 --- a/test/unbalanced/test_1d_solver.py +++ b/test/unbalanced/test_1d_solver.py @@ -1,6 +1,6 @@ """Tests for module 1D Unbalanced OT""" -# Author: +# Author: Clément Bonet # # License: MIT License @@ -11,8 +11,6 @@ def test_uot_1d(nx): - pass - n_samples = 20 # nb samples rng = np.random.RandomState(42) @@ -25,16 +23,26 @@ def test_uot_1d(nx): reg_m = 1.0 M = ot.dist(xs, xt) - M = M / M.max() + # M = M / M.max() a, b, M = nx.from_numpy(a_np, b_np, M) + xs, xt = nx.from_numpy(xs, xt) loss_mm = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div="kl") - print("??", loss_mm) + print("?", nx.__name__) + + if nx.__name__ != "jax": + f, g, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, mode="icdf", numItermax=100) + print("!! 
", loss_1d.item()) + np.testing.assert_allclose(loss_1d, loss_mm) if nx.__name__ in ["jax", "torch"]: - f, g, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, mode="backprop") + print("??", loss_mm.item()) + + f, g, loss_1d = ot.unbalanced.uot_1d( + xs, xt, reg_m, mode="backprop", numItermax=100 + ) - print("???", loss_1d[0]) + print("???", loss_1d.item()) np.testing.assert_allclose(loss_1d, loss_mm) diff --git a/test/unbalanced/test_sliced.py b/test/unbalanced/test_sliced.py index 15a7a72b2..bdd917f19 100644 --- a/test/unbalanced/test_sliced.py +++ b/test/unbalanced/test_sliced.py @@ -1,6 +1,6 @@ """Tests for module sliced Unbalanced OT""" -# Author: +# Author: Clément Bonet # # License: MIT License From 801aa89e6be4f5c5b46972ddb9d12d33bec2a0cd Mon Sep 17 00:00:00 2001 From: Clement Date: Sun, 5 Oct 2025 13:44:28 +0200 Subject: [PATCH 19/19] up citation --- README.md | 2 +- RELEASES.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 907b6f74b..3e69af448 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ POT provides the following generic OT solvers: * [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation [73] and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations). 
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. -* [Sliced Unbalanced OT and Unbalanced Sliced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT.html) [80] +* [Sliced Unbalanced OT and Unbalanced Sliced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT.html) [82] * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_compute_wasserstein_circle.html) [44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] diff --git a/RELEASES.md b/RELEASES.md index 4a1d19445..dfca17a0d 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,6 +1,6 @@ # Releases -## 0.9.7 +## 0.9.7dev #### New features - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #)