Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ POT provides the following generic OT solvers:
Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation [73] and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
* [Sliced Unbalanced OT and Unbalanced Sliced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT.html) [82]
* [Wasserstein distance on the
circle](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_compute_wasserstein_circle.html)
[44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
Expand Down Expand Up @@ -367,7 +368,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer

[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021.

[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.

Expand Down Expand Up @@ -449,5 +450,4 @@ Artificial Intelligence.

[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS).


```
[82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2024). [Slicing Unbalanced Optimal Transport](https://openreview.net/forum?id=AjJTg5M0r8). Transactions on Machine Learning Research.
6 changes: 6 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Releases

## 0.9.7dev

#### New features
- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #)
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #)

## 0.9.6.post1

*September 2025*
Expand Down
80 changes: 75 additions & 5 deletions examples/unbalanced-partial/plot_UOT_1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ot
import ot.plot
from ot.datasets import make_1D_gauss as gauss
import torch

##############################################################################
# Generate data
Expand All @@ -41,7 +42,6 @@

# loss matrix
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
M /= M.max()


##############################################################################
Expand Down Expand Up @@ -69,17 +69,31 @@

epsilon = 0.1 # entropy parameter
alpha = 1.0 # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M / M.max(), epsilon, alpha, verbose=True)

pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, "UOT matrix Sinkhorn")
pl.show()

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.fill(x, Gs.sum(1), "b", alpha=0.5, label="Transported source")
pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", Gs.sum())

# %%
# plot the transported mass
# -------------------------

##############################################################################
# Solve Unbalanced OT in closed form
# -----------------------------------

alpha = 1.0 # Unbalanced KL relaxation parameter

Gs = ot.unbalanced.mm_unbalanced(a, b, M / M.max(), alpha, verbose=False)

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
Expand All @@ -88,3 +102,59 @@
pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", Gs.sum())


##############################################################################
# Solve 1D UOT with Frank-Wolfe
# -----------------------------

alpha = M.max() # Unbalanced KL relaxation parameter

a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d(
x, x, alpha, u_weights=a, v_weights=b
)

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source")
pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", a_reweighted.sum())


##############################################################################
# Solve 1D UOT with Frank-Wolfe
# -----------------------------

alpha = M.max() # Unbalanced KL relaxation parameter

a_reweighted, b_reweighted, loss = ot.unbalanced.unbalanced_sliced_ot(
torch.tensor(x.reshape((n, 1)), dtype=torch.float64),
torch.tensor(x.reshape((n, 1)), dtype=torch.float64),
alpha,
torch.tensor(a, dtype=torch.float64),
torch.tensor(b, dtype=torch.float64),
mode="backprop",
)


# plot the transported mass
# -------------------------

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.fill(x, a_reweighted.numpy(), "b", alpha=0.5, label="Transported source")
pl.fill(x, b_reweighted.numpy(), "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", a_reweighted.sum())
Loading
Loading