From 76450dddf8dd62b9714b72e99ae075516246d433 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Mon, 25 Oct 2021 17:35:36 +0200 Subject: [PATCH] [MRG] Backend for optim (#282) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Backend for optim * Bug solve * Doc update * backend tests now with fixture * Unused imports removed * Docs * Docs * Docs * Outer product backend docs * Prettier docs * Pep8 * Mistakes corrected Co-authored-by: Rémi Flamary --- ot/backend.py | 118 +++++++++++++++++++++----------- ot/lp/__init__.py | 4 +- ot/optim.py | 155 ++++++++++++++++++++++++++----------------- test/test_backend.py | 22 +++--- test/test_optim.py | 78 +++++++++++++++++----- test/test_ot.py | 6 +- test/test_utils.py | 7 -- 7 files changed, 250 insertions(+), 140 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index a4a4757ce..876b96a84 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -123,7 +123,7 @@ def zeros(self, shape, type_as=None): r""" Creates a tensor full of zeros. - This function follow the api from :any:`numpy.zeros` + This function follows the api from :any:`numpy.zeros` See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html """ @@ -133,7 +133,7 @@ def ones(self, shape, type_as=None): r""" Creates a tensor full of ones. - This function follow the api from :any:`numpy.ones` + This function follows the api from :any:`numpy.ones` See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html """ @@ -143,7 +143,7 @@ def arange(self, stop, start=0, step=1, type_as=None): r""" Returns evenly spaced values within a given interval. - This function follow the api from :any:`numpy.arange` + This function follows the api from :any:`numpy.arange` See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html """ @@ -153,7 +153,7 @@ def full(self, shape, fill_value, type_as=None): r""" Creates a tensor with given shape, filled with given value. - This function follow the api from :any:`numpy.full` + This function follows the api from :any:`numpy.full` See: https://numpy.org/doc/stable/reference/generated/numpy.full.html """ @@ -163,7 +163,7 @@ def eye(self, N, M=None, type_as=None): r""" Creates the identity matrix of given size. - This function follow the api from :any:`numpy.eye` + This function follows the api from :any:`numpy.eye` See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html """ @@ -173,7 +173,7 @@ def sum(self, a, axis=None, keepdims=False): r""" Sums tensor elements over given dimensions. - This function follow the api from :any:`numpy.sum` + This function follows the api from :any:`numpy.sum` See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html """ @@ -183,7 +183,7 @@ def cumsum(self, a, axis=None): r""" Returns the cumulative sum of tensor elements over given dimensions. - This function follow the api from :any:`numpy.cumsum` + This function follows the api from :any:`numpy.cumsum` See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html """ @@ -193,7 +193,7 @@ def max(self, a, axis=None, keepdims=False): r""" Returns the maximum of an array or maximum along given dimensions. - This function follow the api from :any:`numpy.amax` + This function follows the api from :any:`numpy.amax` See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html """ @@ -203,7 +203,7 @@ def min(self, a, axis=None, keepdims=False): r""" Returns the minimum of an array or minimum along given dimensions.
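For orientation, downstream code consumes this abstract backend through ot.backend.get_backend, which returns the backend instance matching its array arguments. A minimal sketch of that dispatch pattern (illustrative only, not part of this patch; shown with NumPy inputs, but torch/jax arrays dispatch the same way):

    import numpy as np
    from ot.backend import get_backend

    a = np.random.rand(5, 3)
    b = np.random.rand(5, 3)

    nx = get_backend(a, b)  # NumpyBackend instance for ndarray inputs

    s = nx.sum(a, axis=0)                  # mirrors numpy.sum
    m = nx.min(b, axis=1, keepdims=True)   # mirrors numpy.amin
    z = nx.zeros((3, 2), type_as=a)        # new tensor with a's dtype/backend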
- This function follow the api from :any:`numpy.amin` + This function follows the api from :any:`numpy.amin` See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html """ @@ -213,7 +213,7 @@ def maximum(self, a, b): r""" Returns element-wise maximum of array elements. - This function follow the api from :any:`numpy.maximum` + This function follows the api from :any:`numpy.maximum` See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html """ @@ -223,7 +223,7 @@ def minimum(self, a, b): r""" Returns element-wise minimum of array elements. - This function follow the api from :any:`numpy.minimum` + This function follows the api from :any:`numpy.minimum` See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html """ @@ -233,7 +233,7 @@ def dot(self, a, b): r""" Returns the dot product of two tensors. - This function follow the api from :any:`numpy.dot` + This function follows the api from :any:`numpy.dot` See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html """ @@ -243,7 +243,7 @@ def abs(self, a): r""" Computes the absolute value element-wise. - This function follow the api from :any:`numpy.absolute` + This function follows the api from :any:`numpy.absolute` See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html """ @@ -253,7 +253,7 @@ def exp(self, a): r""" Computes the exponential value element-wise. - This function follow the api from :any:`numpy.exp` + This function follows the api from :any:`numpy.exp` See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html """ @@ -263,7 +263,7 @@ def log(self, a): r""" Computes the natural logarithm, element-wise. - This function follow the api from :any:`numpy.log` + This function follows the api from :any:`numpy.log` See: https://numpy.org/doc/stable/reference/generated/numpy.log.html """ @@ -273,7 +273,7 @@ def sqrt(self, a): r""" Returns the non-negative square root of a tensor, element-wise. - This function follow the api from :any:`numpy.sqrt` + This function follows the api from :any:`numpy.sqrt` See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html """ @@ -283,7 +283,7 @@ def power(self, a, exponents): r""" First tensor elements raised to powers from second tensor, element-wise. - This function follow the api from :any:`numpy.power` + This function follows the api from :any:`numpy.power` See: https://numpy.org/doc/stable/reference/generated/numpy.power.html """ @@ -293,7 +293,7 @@ def norm(self, a): r""" Computes the matrix Frobenius norm. - This function follow the api from :any:`numpy.linalg.norm` + This function follows the api from :any:`numpy.linalg.norm` See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html """ @@ -303,7 +303,7 @@ def any(self, a): r""" Tests whether any tensor element along given dimensions evaluates to True. - This function follow the api from :any:`numpy.any` + This function follows the api from :any:`numpy.any` See: https://numpy.org/doc/stable/reference/generated/numpy.any.html """ @@ -313,7 +313,7 @@ def isnan(self, a): r""" Tests element-wise for NaN and returns result as a boolean tensor. - This function follow the api from :any:`numpy.isnan` + This function follows the api from :any:`numpy.isnan` See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html """ @@ -323,7 +323,7 @@ def isinf(self, a): r""" Tests element-wise for positive or negative infinity and returns result as a boolean tensor.
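The boolean primitives above are what make backend-agnostic input validation possible; a small sketch (illustrative, not from the patch) of guarding a solver with them:

    import numpy as np
    from ot.backend import get_backend

    M = np.random.rand(4, 4)
    nx = get_backend(M)

    # reject degenerate cost matrices before running a solver
    if nx.any(nx.isnan(M)) or nx.any(nx.isinf(M)):
        raise ValueError("cost matrix contains NaN or infinite entries")

    M = M / nx.max(M)        # rescale costs to [0, 1]
    K = nx.exp(-M / 0.1)     # e.g. the Gibbs kernel used by Sinkhorn-type solvers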
- This function follow the api from :any:`numpy.isinf` + This function follows the api from :any:`numpy.isinf` See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html """ @@ -333,7 +333,7 @@ def einsum(self, subscripts, *operands): r""" Evaluates the Einstein summation convention on the operands. - This function follow the api from :any:`numpy.einsum` + This function follows the api from :any:`numpy.einsum` See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html """ @@ -343,7 +343,7 @@ def sort(self, a, axis=-1): r""" Returns a sorted copy of a tensor. - This function follow the api from :any:`numpy.sort` + This function follows the api from :any:`numpy.sort` See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html """ @@ -353,7 +353,7 @@ def argsort(self, a, axis=None): r""" Returns the indices that would sort a tensor. - This function follow the api from :any:`numpy.argsort` + This function follows the api from :any:`numpy.argsort` See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html """ @@ -363,7 +363,7 @@ def searchsorted(self, a, v, side='left'): r""" Finds indices where elements should be inserted to maintain order in given tensor. - This function follow the api from :any:`numpy.searchsorted` + This function follows the api from :any:`numpy.searchsorted` See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html """ @@ -373,7 +373,7 @@ def flip(self, a, axis=None): r""" Reverses the order of elements in a tensor along given dimensions. - This function follow the api from :any:`numpy.flip` + This function follows the api from :any:`numpy.flip` See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html """ @@ -383,7 +383,7 @@ def clip(self, a, a_min, a_max): """ Limits the values in a tensor. - This function follow the api from :any:`numpy.clip` + This function follows the api from :any:`numpy.clip` See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html """ @@ -393,7 +393,7 @@ def repeat(self, a, repeats, axis=None): r""" Repeats elements of a tensor. - This function follow the api from :any:`numpy.repeat` + This function follows the api from :any:`numpy.repeat` See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html """ @@ -403,7 +403,7 @@ def take_along_axis(self, arr, indices, axis): r""" Gathers elements of a tensor along given dimensions. - This function follow the api from :any:`numpy.take_along_axis` + This function follows the api from :any:`numpy.take_along_axis` See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html """ @@ -413,7 +413,7 @@ def concatenate(self, arrays, axis=0): r""" Joins a sequence of tensors along an existing dimension. - This function follow the api from :any:`numpy.concatenate` + This function follows the api from :any:`numpy.concatenate` See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html """ @@ -423,7 +423,7 @@ def zero_pad(self, a, pad_width): r""" Pads a tensor. - This function follow the api from :any:`numpy.pad` + This function follows the api from :any:`numpy.pad` See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html """ @@ -433,7 +433,7 @@ def argmax(self, a, axis=None): r""" Returns the indices of the maximum values of a tensor along given dimensions. 
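A short sketch of how a couple of the manipulation primitives above combine in practice (illustrative only, not part of the patch):

    import numpy as np
    from ot.backend import get_backend

    G = np.random.rand(3, 4)   # a coupling
    M = np.random.rand(3, 4)   # a cost matrix
    nx = get_backend(G, M)

    # total transport cost <G, M>_F through the einsum primitive
    cost = nx.einsum('ij,ij->', G, M)

    # sort costs row-wise and gather the coupling entries in the same order
    order = nx.argsort(M, axis=1)
    G_ordered = nx.take_along_axis(G, order, 1)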
- This function follow the api from :any:`numpy.argmax` + This function follows the api from :any:`numpy.argmax` See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html """ @@ -443,7 +443,7 @@ def mean(self, a, axis=None): r""" Computes the arithmetic mean of a tensor along given dimensions. - This function follow the api from :any:`numpy.mean` + This function follows the api from :any:`numpy.mean` See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html """ @@ -453,7 +453,7 @@ def std(self, a, axis=None): r""" Computes the standard deviation of a tensor along given dimensions. - This function follow the api from :any:`numpy.std` + This function follows the api from :any:`numpy.std` See: https://numpy.org/doc/stable/reference/generated/numpy.std.html """ @@ -463,7 +463,7 @@ def linspace(self, start, stop, num): r""" Returns a specified number of evenly spaced values over a given interval. - This function follow the api from :any:`numpy.linspace` + This function follows the api from :any:`numpy.linspace` See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html """ @@ -473,7 +473,7 @@ def meshgrid(self, a, b): r""" Returns coordinate matrices from coordinate vectors (Numpy convention). - This function follow the api from :any:`numpy.meshgrid` + This function follows the api from :any:`numpy.meshgrid` See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html """ @@ -483,7 +483,7 @@ def diag(self, a, k=0): r""" Extracts or constructs a diagonal tensor. - This function follow the api from :any:`numpy.diag` + This function follows the api from :any:`numpy.diag` See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html """ @@ -493,7 +493,7 @@ def unique(self, a): r""" Finds unique elements of given tensor. - This function follow the api from :any:`numpy.unique` + This function follows the api from :any:`numpy.unique` See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html """ @@ -503,7 +503,7 @@ def logsumexp(self, a, axis=None): r""" Computes the log of the sum of exponentials of input elements. - This function follow the api from :any:`scipy.special.logsumexp` + This function follows the api from :any:`scipy.special.logsumexp` See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html """ @@ -513,12 +513,32 @@ def stack(self, arrays, axis=0): r""" Joins a sequence of tensors along a new dimension. - This function follow the api from :any:`numpy.stack` + This function follows the api from :any:`numpy.stack` See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html """ raise NotImplementedError() + def outer(self, a, b): + r""" + Computes the outer product between two vectors. + + This function follows the api from :any:`numpy.outer` + + See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html + """ + raise NotImplementedError() + + def reshape(self, a, shape): + r""" + Gives a new shape to a tensor without changing its data. 
+ + This function follows the api from :any:`numpy.reshape` + + See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -644,6 +664,9 @@ def searchsorted(self, a, v, side='left'): def flip(self, a, axis=None): return np.flip(a, axis) + def outer(self, a, b): + return np.outer(a, b) + def clip(self, a, a_min, a_max): return np.clip(a, a_min, a_max) @@ -686,6 +709,9 @@ def logsumexp(self, a, axis=None): def stack(self, arrays, axis=0): return np.stack(arrays, axis) + def reshape(self, a, shape): + return np.reshape(a, shape) + class JaxBackend(Backend): """ @@ -815,6 +841,9 @@ def searchsorted(self, a, v, side='left'): def flip(self, a, axis=None): return jnp.flip(a, axis) + def outer(self, a, b): + return jnp.outer(a, b) + def clip(self, a, a_min, a_max): return jnp.clip(a, a_min, a_max) @@ -857,6 +886,9 @@ def logsumexp(self, a, axis=None): def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) + def reshape(self, a, shape): + return jnp.reshape(a, shape) + class TorchBackend(Backend): """ @@ -1035,6 +1067,9 @@ def flip(self, a, axis=None): else: return torch.flip(a, dims=axis) + def outer(self, a, b): + return torch.outer(a, b) + def clip(self, a, a_min, a_max): return torch.clamp(a, a_min, a_max) @@ -1091,3 +1126,6 @@ def logsumexp(self, a, axis=None): def stack(self, arrays, axis=0): return torch.stack(arrays, dim=axis) + + def reshape(self, a, shape): + return torch.reshape(a, shape) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index b907b10a4..c6757d113 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -281,12 +281,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): a0, b0, M0 = a, b, M nx = get_backend(M0, a0, b0) - + # convert to numpy M = nx.to_numpy(M) a = nx.to_numpy(a) b = nx.to_numpy(b) - + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) diff --git a/ot/optim.py b/ot/optim.py index 03593430a..6822e4eba 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -12,6 +12,8 @@ from scipy.optimize.linesearch import scalar_search_armijo from .lp import emd from .bregman import sinkhorn +from ot.utils import list_to_array +from .backend import get_backend # The corresponding scipy function does not work for matrices @@ -21,25 +23,25 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, """ Armijo linesearch function that works with matrices - find an approximate minimum of f(xk+alpha*pk) that satifies the + Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the armijo conditions. 
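As a usage sketch for line_search_armijo, made backend-aware just below (the snippet mirrors the call pattern of test_line_search_armijo further down in this patch; plain NumPy shown, and with this change any supported backend is accepted):

    import numpy as np
    import ot

    def f(x):
        return np.sum((x - 5.0) ** 2)   # quadratic with minimum at x = 5

    def grad(x):
        return 2 * (x - 5.0)

    xk = np.array([[-5.0, -5.0]])
    pk = np.array([[100.0, 100.0]])     # descent direction

    alpha, fc, f_val = ot.optim.line_search_armijo(f, xk, pk, grad(xk), f(xk))
    # alpha is close to 0.1 here, i.e. xk + alpha * pk lands on the minimizer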
Parameters ---------- f : callable loss function - xk : ndarray + xk : array-like initial position - pk : ndarray + pk : array-like descent direction - gfk : ndarray - gradient of f at xk + gfk : array-like + gradient of `f` at :math:`x_k` old_fval : float - loss value at xk + loss value at :math:`x_k` args : tuple, optional - arguments given to f + arguments given to `f` c1 : float, optional - c1 const in armijo rule (>0) + :math:`c_1` const in armijo rule (>0) alpha0 : float, optional initial step (>0) @@ -53,7 +55,13 @@ loss value at step alpha """ - xk = np.atleast_1d(xk) + + xk, pk, gfk = list_to_array(xk, pk, gfk) + nx = get_backend(xk, pk) + + if len(xk.shape) == 0: + xk = nx.reshape(xk, (-1,)) + fc = [0] def phi(alpha1): @@ -65,7 +73,7 @@ def phi(alpha1): else: phi0 = old_fval - derphi0 = np.sum(pk * gfk) # Quickfix for matrices + derphi0 = nx.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) @@ -79,55 +87,64 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations + Parameters ---------- cost : method Cost in the FW for the linesearch - G : ndarray, shape(ns,nt) + G : array-like, shape(ns,nt) The transport map at a given iteration of the FW - deltaG : ndarray (ns,nt) + deltaG : array-like (ns,nt) Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration - Mi : ndarray (ns,nt) + Mi : array-like (ns,nt) Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost - f_val : float - Value of the cost at G + f_val : float + Value of the cost at `G` armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - C1 : ndarray (ns,ns), optional + If True the step of the line-search is found via an armijo search. Else a closed form is used. + If there are convergence issues, use False. + C1 : array-like (ns,ns), optional Structure matrix in the source domain. Only used and necessary when armijo=False - C2 : ndarray (nt,nt), optional + C2 : array-like (nt,nt), optional Structure matrix in the target domain. Only used and necessary when armijo=False reg : float, optional - Regularization parameter. Only used and necessary when armijo=False - Gc : ndarray (ns,nt) + Regularization parameter. Only used and necessary when armijo=False + Gc : array-like (ns,nt) Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False - constC : ndarray (ns,nt) - Constant for the gromov cost. See [24]. Only used and necessary when armijo=False - M : ndarray (ns,nt), optional + constC : array-like (ns,nt) + Constant for the gromov cost. See :ref:`[24] `. Only used and necessary when armijo=False + M : array-like (ns,nt), optional Cost matrix between the features. Only used and necessary when armijo=False + Returns ------- alpha : float - The optimal step size of the FW + The optimal step size of the FW fc : int - nb of function call. Useless here - f_val : float - The value of the cost for the next iteration + number of function calls (not used here) + f_val : float + The value of the cost for the next iteration + + + .. _references-solve-linesearch: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain - and Courty Nicolas + .. 
[24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ if armijo: alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) else: # requires symetric matrices - dot1 = np.dot(C1, deltaG) - dot12 = dot1.dot(C2) - a = -2 * reg * np.sum(dot12 * deltaG) - b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, C1, C2, constC) + else: + nx = get_backend(G, deltaG, C1, C2, constC, M) + + dot = nx.dot(nx.dot(C1, deltaG), C2) + a = -2 * reg * nx.sum(dot * deltaG) + b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG)) c = cost(G) alpha = solve_1d_linesearch_quad(a, b, c) @@ -145,33 +162,33 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma) + \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg} \cdot f(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` Parameters ---------- - a : ndarray, shape (ns,) + a : array-like, shape (ns,) samples weights in the source domain - b : ndarray, shape (nt,) + b : array-like, shape (nt,) samples in the target domain - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) loss matrix reg : float Regularization term >0 - G0 : ndarray, shape (ns,nt), optional + G0 : array-like, shape (ns,nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations @@ -196,6 +213,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, log dictionary return only if log==True in parameters + .. 
_references-cg: References ---------- @@ -207,6 +225,11 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, ot.bregman.sinkhorn : Entropic regularized optimal transport """ + a, b, M, G0 = list_to_array(a, b, M, G0) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(a, b) + else: + nx = get_backend(a, b, M) loop = 1 @@ -214,12 +237,12 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, log = {'loss': []} if G0 is None: - G = np.outer(a, b) + G = nx.outer(a, b) else: G = G0 def cost(G): - return np.sum(M * G) + reg * f(G) + return nx.sum(M * G) + reg * f(G) f_val = cost(G) if log: @@ -240,7 +263,7 @@ def cost(G): # problem linearization Mi = M + reg * df(G) # set M positive - Mi += Mi.min() + Mi += nx.min(Mi) # solve linear program Gc = emd(a, b, Mi, numItermax=numItermaxEmd) @@ -286,36 +309,36 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma) + \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_ + The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] ` Parameters ---------- - a : ndarray, shape (ns,) + a : array-like, shape (ns,) samples weights in the source domain - b : ndarrayv (nt,) + b : array-like, (nt,) samples in the target domain - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) loss matrix reg1 : float Entropic Regularization term >0 reg2 : float Second Regularization term >0 - G0 : ndarray, shape (ns, nt), optional + G0 : array-like, shape (ns, nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations @@ -337,9 +360,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, log : dict log dictionary return only if log==True in parameters + + .. _references-gcg: References ---------- + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. 
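A usage sketch for gcg, mirroring the setup of test_generalized_conditional_gradient further down in this patch (the regularizers and constants are the test's own):

    import numpy as np
    import ot
    from ot.datasets import make_1D_gauss as gauss

    n_bins = 100
    x = np.arange(n_bins, dtype=np.float64)
    a = gauss(n_bins, m=20, s=5)    # source histogram
    b = gauss(n_bins, m=60, s=10)   # target histogram

    M = ot.dist(x.reshape((n_bins, 1)), x.reshape((n_bins, 1)))
    M /= M.max()

    def f(G):      # quadratic regularizer
        return 0.5 * np.sum(G ** 2)

    def df(G):     # its gradient
        return G

    # entropic (reg1) + quadratic (reg2) regularization via generalized CG
    G = ot.optim.gcg(a, b, M, reg1=1e-3, reg2=1e-1, f=f, df=df)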
See Also @@ -347,6 +374,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, ot.optim.cg : conditional gradient """ + a, b, M, G0 = list_to_array(a, b, M, G0) + nx = get_backend(a, b, M) loop = 1 @@ -354,12 +383,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, log = {'loss': []} if G0 is None: - G = np.outer(a, b) + G = nx.outer(a, b) else: G = G0 def cost(G): - return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G) + return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G) f_val = cost(G) if log: @@ -387,7 +416,7 @@ def cost(G): deltaG = Gc - G # line search - dcost = Mi + reg1 * (1 + np.log(G)) # ?? + dcost = Mi + reg1 * (1 + nx.log(G)) # ?? alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val) G = G + alpha * deltaG @@ -419,9 +448,11 @@ def cost(G): def solve_1d_linesearch_quad(a, b, c): """ - For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem: + For any convex or non-convex 1d quadratic function `f`, solve the following problem: + .. math:: - \argmin f(x)=a*x^{2}+b*x+c + + arg\min_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c Parameters ---------- diff --git a/test/test_backend.py b/test/test_backend.py index 859da5ab1..58532824f 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -17,9 +17,6 @@ from ot.backend import get_backend, get_backend_list, to_numpy -backend_list = get_backend_list() - - def test_get_backend_list(): lst = get_backend_list() @@ -28,7 +25,6 @@ def test_get_backend_list(): assert isinstance(lst[0], ot.backend.NumpyBackend) -@pytest.mark.parametrize('nx', backend_list) def test_to_numpy(nx): v = nx.zeros(10) @@ -92,7 +88,6 @@ def test_get_backend(): get_backend(A, B2) -@pytest.mark.parametrize('nx', backend_list) def test_convert_between_backends(nx): A = np.zeros((3, 2)) @@ -180,6 +175,8 @@ def test_empty_backend(): nx.searchsorted(v, v) with pytest.raises(NotImplementedError): nx.flip(M) + with pytest.raises(NotImplementedError): + nx.outer(v, v) with pytest.raises(NotImplementedError): nx.clip(M, -1, 1) with pytest.raises(NotImplementedError): @@ -208,10 +205,11 @@ def test_empty_backend(): nx.logsumexp(M) with pytest.raises(NotImplementedError): nx.stack([M, M]) + with pytest.raises(NotImplementedError): + nx.reshape(M, (5, 3, 2)) -@pytest.mark.parametrize('backend', backend_list) -def test_func_backends(backend): +def test_func_backends(nx): rnd = np.random.RandomState(0) M = rnd.randn(10, 3) @@ -220,7 +218,7 @@ def test_func_backends(backend): lst_tot = [] - for nx in [ot.backend.NumpyBackend(), backend]: + for nx in [ot.backend.NumpyBackend(), nx]: print('Backend: ', nx.__name__) @@ -371,6 +369,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.outer(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('outer') + A = nx.clip(vb, 0, 1) lst_b.append(nx.to_numpy(A)) lst_name.append('clip') @@ -432,6 +434,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('stack') + A = nx.reshape(Mb, (5, 3, 2)) + lst_b.append(nx.to_numpy(A)) + lst_name.append('reshape') + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_optim.py b/test/test_optim.py index 94995d56d..4efd9b161 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -8,7 +8,7 @@ import ot -def test_conditional_gradient(): +def test_conditional_gradient(nx): n_bins = 100 # nb bins np.random.seed(0) @@ -29,15 +29,25 @@ def f(G): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + + ab = 
nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + reg = 1e-1 G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True) + Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1)) - np.testing.assert_allclose(b, G.sum(0)) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1)) + np.testing.assert_allclose(b, Gb.sum(0)) -def test_conditional_gradient_itermax(): +def test_conditional_gradient_itermax(nx): n = 100 # nb samples mu_s = np.array([0, 0]) @@ -61,16 +71,27 @@ def f(G): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + reg = 1e-1 G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000, verbose=True, log=True) + Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000, + verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1)) - np.testing.assert_allclose(b, G.sum(0)) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1)) + np.testing.assert_allclose(b, Gb.sum(0)) -def test_generalized_conditional_gradient(): +def test_generalized_conditional_gradient(nx): n_bins = 100 # nb bins np.random.seed(0) @@ -91,13 +112,23 @@ def f(G): def df(G): return G + def fb(G): + return 0.5 * nx.sum(G ** 2) + reg1 = 1e-3 reg2 = 1e-1 + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True) + Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True) + Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(a, G.sum(1), atol=1e-05) - np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05) + np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05) def test_solve_1d_linesearch_quad_funct(): @@ -106,24 +137,31 @@ def test_solve_1d_linesearch_quad_funct(): np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1) -def test_line_search_armijo(): +def test_line_search_armijo(nx): xk = np.array([[0.25, 0.25], [0.25, 0.25]]) pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) old_fval = -123 # Should not throw an exception and return None for alpha - alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) + alpha, a, b = ot.optim.line_search_armijo( + lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval + ) + alpha_np, anp, bnp = ot.optim.line_search_armijo( + lambda x: 1, xk, pk, gfk, old_fval + ) + assert a == anp + assert b == bnp assert alpha is None # check line search armijo def f(x): - return np.sum((x - 5.0) ** 2) + return nx.sum((x - 5.0) ** 2) def grad(x): return 2 * (x - 5.0) - xk = np.array([[[-5.0, -5.0]]]) - pk = np.array([[[100.0, 100.0]]]) + xk = nx.from_numpy(np.array([[[-5.0, -5.0]]])) + pk = nx.from_numpy(np.array([[[100.0, 100.0]]])) gfk = grad(xk) old_fval = f(xk) @@ -132,10 +170,18 @@ def grad(x): np.testing.assert_allclose(alpha, 0.1) # check the case where the direction is not far enough - pk = np.array([[[3.0, 3.0]]]) + pk = nx.from_numpy(np.array([[[3.0, 3.0]]])) alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) np.testing.assert_allclose(alpha, 1.0) - # check the case where the checking the wrong direction + # check the case 
where checking the wrong direction alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) assert alpha <= 0 + + # check the case where the point is not a vector + xk = nx.from_numpy(np.array(-5.0)) + pk = nx.from_numpy(np.array(100.0)) + gfk = grad(xk) + old_fval = f(xk) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1) diff --git a/test/test_ot.py b/test/test_ot.py index 3e953dcbe..4dfc510e6 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,7 @@ import ot from ot.datasets import make_1D_gauss as gauss -from ot.backend import get_backend_list, torch - -backend_list = get_backend_list() +from ot.backend import torch def test_emd_dimension_and_mass_mismatch(): @@ -37,7 +35,6 @@ def test_emd_dimension_and_mass_mismatch(): np.testing.assert_raises(AssertionError, ot.emd, a, b, M) -@pytest.mark.parametrize('nx', backend_list) def test_emd_backends(nx): n_samples = 100 n_features = 2 @@ -59,7 +56,6 @@ def test_emd_backends(nx): np.allclose(G, nx.to_numpy(Gb)) -@pytest.mark.parametrize('nx', backend_list) def test_emd2_backends(nx): n_samples = 100 n_features = 2 diff --git a/test/test_utils.py b/test/test_utils.py index 76b1faac6..60ad5d3de 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,17 +4,11 @@ # # License: MIT License -import pytest import ot import numpy as np import sys -from ot.backend import get_backend_list -backend_list = get_backend_list() - - -@pytest.mark.parametrize('nx', backend_list) def test_proj_simplex(nx): n = 10 rng = np.random.RandomState(0) @@ -119,7 +113,6 @@ def test_dist(): np.testing.assert_allclose(D, D3, atol=1e-14) -@ pytest.mark.parametrize('nx', backend_list) def test_dist_backends(nx): n = 100
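The nx fixture these tests now rely on lives in the test suite's conftest.py, which is not part of this diff. For readers who want to run the converted tests standalone, a plausible minimal version looks like the following (hypothetical sketch; POT's actual conftest.py may differ):

    # conftest.py (sketch -- the real fixture ships with POT's test suite)
    import pytest
    from ot.backend import get_backend_list

    @pytest.fixture(params=get_backend_list(), ids=lambda backend: backend.__name__)
    def nx(request):
        # yields the NumpyBackend, plus torch/jax backends when installed
        return request.param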