Skip to content

Commit

Permalink
[Fix] Prevent line search from evaluating cost outside of the interpolation range (#504)
Browse files Browse the repository at this point in the history

* Explicitly check that SinkhornL1l2Transport.fit works with no warnings

* Default value for alpha_min is set to 0

* Fix random_state for SinkhornL1l2Transport test

* Mention changes in releases

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
  • Loading branch information
kachayev and rflamary authored Sep 21, 2023
1 parent 9e74f2e commit 526b72f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)


## 0.9.1
Expand Down
17 changes: 12 additions & 5 deletions ot/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def line_search_armijo(
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs
alpha0=0.99, alpha_min=0., alpha_max=None, nx=None, **kwargs
):
r"""
Armijo linesearch function that works with matrices
Expand Down Expand Up @@ -56,7 +56,7 @@ def line_search_armijo(
:math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
alpha_min : float, optional
alpha_min : float, default=0.
minimum value for alpha
alpha_max : float, optional
maximum value for alpha
Expand Down Expand Up @@ -89,6 +89,14 @@ def line_search_armijo(
fc = [0]

def phi(alpha1):
# It is necessary to enforce the boundary conditions on the step size here,
# because the callback may be evaluated at a negative value of alpha by the
# `scalar_search_armijo` function, see:
#
# https://github.com/scipy/scipy/blob/11509c4a98edded6c59423ac44ca1b7f28fba1fd/scipy/optimize/linesearch.py#L686
#
# see more details https://github.com/PythonOT/POT/issues/502
alpha1 = np.clip(alpha1, alpha_min, alpha_max)
# The callable function operates on nx backend
fc[0] += 1
alpha10 = nx.from_numpy(alpha1)
Expand All @@ -109,13 +117,12 @@ def phi(alpha1):

derphi0 = np.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min)

if alpha is None:
return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
else:
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
alpha = np.clip(alpha, alpha_min, alpha_max)
return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)


Expand Down
11 changes: 7 additions & 4 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
from numpy.testing import assert_allclose, assert_equal
import pytest
import warnings

import ot
from ot.datasets import make_data_classif
Expand Down Expand Up @@ -158,15 +159,17 @@ def test_sinkhorn_l1l2_transport_class(nx):
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Xs, ys = make_data_classif('3gauss', ns, random_state=42)
Xt, yt = make_data_classif('3gauss2', nt, random_state=43)

Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)

otda = ot.da.SinkhornL1l2Transport()
otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500)

# test its computed
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
with warnings.catch_warnings():
warnings.simplefilter("error")
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(otda, "cost_")
assert hasattr(otda, "coupling_")
assert hasattr(otda, "log_")
Expand Down

0 comments on commit 526b72f

Please sign in to comment.