Skip to content

Commit

Permalink
[MRG] update LP barycenter with new scipy solvers (#537)
Browse files Browse the repository at this point in the history
* update lp barycenter with new scipy solvers

* use default solver in exmaple qnd add release info
  • Loading branch information
rflamary authored Oct 26, 2023
1 parent a9de7a0 commit e7cba02
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 8 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
+ The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533)
+ The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices.. (PR #533)
+ New API for Gromov-Wasserstein solvers with `ot.solve_gromov` function (PR #536)
+ New LP solvers from scipy used by default for LP barycenter (PR #537)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
6 changes: 3 additions & 3 deletions examples/barycenters/plot_barycenter_lp_vs_entropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@


ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
bary_wass2 = ot.lp.barycenter(A, M, weights)
ot.toc()

pl.figure(2)
Expand Down Expand Up @@ -149,7 +149,7 @@


ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
bary_wass2 = ot.lp.barycenter(A, M, weights)
ot.toc()


Expand Down Expand Up @@ -223,7 +223,7 @@


ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
bary_wass2 = ot.lp.barycenter(A, M, weights)
ot.toc()


Expand Down
6 changes: 3 additions & 3 deletions ot/lp/cvx.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def scipy_sparse_to_spmatrix(A):
return SP


def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
def barycenter(A, M, weights=None, verbose=False, log=False, solver='highs-ipm'):
r"""Compute the Wasserstein barycenter of distributions A
The function solves the following optimization problem [16]:
Expand Down Expand Up @@ -115,13 +115,13 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
A_eq = sps.vstack((A_eq1, A_eq2))
b_eq = np.concatenate((b_eq1, b_eq2))

if not cvxopt or solver in ['interior-point']:
if not cvxopt or solver in ['interior-point', 'highs', 'highs-ipm', 'highs-ds']:
# cvxopt not installed or interior point

if solver is None:
solver = 'interior-point'

options = {'sparse': True, 'disp': verbose}
options = {'disp': verbose}
sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,
options=options)
x = sol.x
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy>=1.20
scipy>=1.3
scipy>=1.6
matplotlib
autograd
pymanopt
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
scripts=[],
data_files=[],
setup_requires=["oldest-supported-numpy", "cython>=0.23"],
install_requires=["numpy>=1.16", "scipy>=1.0"],
install_requires=["numpy>=1.16", "scipy>=1.6"],
python_requires=">=3.6",
classifiers=[
'Development Status :: 5 - Production/Stable',
Expand Down

0 comments on commit e7cba02

Please sign in to comment.