From 05ed99281b591cb58e59b32f193369da6b76dd3f Mon Sep 17 00:00:00 2001 From: ncassereau Date: Mon, 25 Oct 2021 16:22:51 +0200 Subject: [PATCH] Mistakes corrected --- ot/backend.py | 101 +++++++++++++++++++++++++------------------ ot/optim.py | 2 +- test/test_backend.py | 6 +++ test/test_optim.py | 18 +++++--- 4 files changed, 80 insertions(+), 47 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 7f8e7cf6c..876b96a84 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -123,7 +123,7 @@ def zeros(self, shape, type_as=None): r""" Creates a tensor full of zeros. - This function follow the api from :any:`numpy.zeros` + This function follows the api from :any:`numpy.zeros` See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html """ @@ -133,7 +133,7 @@ def ones(self, shape, type_as=None): r""" Creates a tensor full of ones. - This function follow the api from :any:`numpy.ones` + This function follows the api from :any:`numpy.ones` See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html """ @@ -143,7 +143,7 @@ def arange(self, stop, start=0, step=1, type_as=None): r""" Returns evenly spaced values within a given interval. - This function follow the api from :any:`numpy.arange` + This function follows the api from :any:`numpy.arange` See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html """ @@ -153,7 +153,7 @@ def full(self, shape, fill_value, type_as=None): r""" Creates a tensor with given shape, filled with given value. - This function follow the api from :any:`numpy.full` + This function follows the api from :any:`numpy.full` See: https://numpy.org/doc/stable/reference/generated/numpy.full.html """ @@ -163,7 +163,7 @@ def eye(self, N, M=None, type_as=None): r""" Creates the identity matrix of given size. - This function follow the api from :any:`numpy.eye` + This function follows the api from :any:`numpy.eye` See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html """ @@ -173,7 +173,7 @@ def sum(self, a, axis=None, keepdims=False): r""" Sums tensor elements over given dimensions. - This function follow the api from :any:`numpy.sum` + This function follows the api from :any:`numpy.sum` See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html """ @@ -183,7 +183,7 @@ def cumsum(self, a, axis=None): r""" Returns the cumulative sum of tensor elements over given dimensions. - This function follow the api from :any:`numpy.cumsum` + This function follows the api from :any:`numpy.cumsum` See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html """ @@ -193,7 +193,7 @@ def max(self, a, axis=None, keepdims=False): r""" Returns the maximum of an array or maximum along given dimensions. - This function follow the api from :any:`numpy.amax` + This function follows the api from :any:`numpy.amax` See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html """ @@ -203,7 +203,7 @@ def min(self, a, axis=None, keepdims=False): r""" Returns the maximum of an array or maximum along given dimensions. - This function follow the api from :any:`numpy.amin` + This function follows the api from :any:`numpy.amin` See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html """ @@ -213,7 +213,7 @@ def maximum(self, a, b): r""" Returns element-wise maximum of array elements. - This function follow the api from :any:`numpy.maximum` + This function follows the api from :any:`numpy.maximum` See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html """ @@ -223,7 +223,7 @@ def minimum(self, a, b): r""" Returns element-wise minimum of array elements. - This function follow the api from :any:`numpy.minimum` + This function follows the api from :any:`numpy.minimum` See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html """ @@ -233,7 +233,7 @@ def dot(self, a, b): r""" Returns the dot product of two tensors. - This function follow the api from :any:`numpy.dot` + This function follows the api from :any:`numpy.dot` See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html """ @@ -243,7 +243,7 @@ def abs(self, a): r""" Computes the absolute value element-wise. - This function follow the api from :any:`numpy.absolute` + This function follows the api from :any:`numpy.absolute` See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html """ @@ -253,7 +253,7 @@ def exp(self, a): r""" Computes the exponential value element-wise. - This function follow the api from :any:`numpy.exp` + This function follows the api from :any:`numpy.exp` See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html """ @@ -263,7 +263,7 @@ def log(self, a): r""" Computes the natural logarithm, element-wise. - This function follow the api from :any:`numpy.log` + This function follows the api from :any:`numpy.log` See: https://numpy.org/doc/stable/reference/generated/numpy.log.html """ @@ -273,7 +273,7 @@ def sqrt(self, a): r""" Returns the non-ngeative square root of a tensor, element-wise. - This function follow the api from :any:`numpy.sqrt` + This function follows the api from :any:`numpy.sqrt` See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html """ @@ -283,7 +283,7 @@ def power(self, a, exponents): r""" First tensor elements raised to powers from second tensor, element-wise. - This function follow the api from :any:`numpy.power` + This function follows the api from :any:`numpy.power` See: https://numpy.org/doc/stable/reference/generated/numpy.power.html """ @@ -293,7 +293,7 @@ def norm(self, a): r""" Computes the matrix frobenius norm. - This function follow the api from :any:`numpy.linalg.norm` + This function follows the api from :any:`numpy.linalg.norm` See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html """ @@ -303,7 +303,7 @@ def any(self, a): r""" Tests whether any tensor element along given dimensions evaluates to True. - This function follow the api from :any:`numpy.any` + This function follows the api from :any:`numpy.any` See: https://numpy.org/doc/stable/reference/generated/numpy.any.html """ @@ -313,7 +313,7 @@ def isnan(self, a): r""" Tests element-wise for NaN and returns result as a boolean tensor. - This function follow the api from :any:`numpy.isnan` + This function follows the api from :any:`numpy.isnan` See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html """ @@ -323,7 +323,7 @@ def isinf(self, a): r""" Tests element-wise for positive or negative infinity and returns result as a boolean tensor. - This function follow the api from :any:`numpy.isinf` + This function follows the api from :any:`numpy.isinf` See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html """ @@ -333,7 +333,7 @@ def einsum(self, subscripts, *operands): r""" Evaluates the Einstein summation convention on the operands. - This function follow the api from :any:`numpy.einsum` + This function follows the api from :any:`numpy.einsum` See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html """ @@ -343,7 +343,7 @@ def sort(self, a, axis=-1): r""" Returns a sorted copy of a tensor. - This function follow the api from :any:`numpy.sort` + This function follows the api from :any:`numpy.sort` See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html """ @@ -353,7 +353,7 @@ def argsort(self, a, axis=None): r""" Returns the indices that would sort a tensor. - This function follow the api from :any:`numpy.argsort` + This function follows the api from :any:`numpy.argsort` See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html """ @@ -363,7 +363,7 @@ def searchsorted(self, a, v, side='left'): r""" Finds indices where elements should be inserted to maintain order in given tensor. - This function follow the api from :any:`numpy.searchsorted` + This function follows the api from :any:`numpy.searchsorted` See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html """ @@ -373,7 +373,7 @@ def flip(self, a, axis=None): r""" Reverses the order of elements in a tensor along given dimensions. - This function follow the api from :any:`numpy.flip` + This function follows the api from :any:`numpy.flip` See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html """ @@ -383,7 +383,7 @@ def clip(self, a, a_min, a_max): """ Limits the values in a tensor. - This function follow the api from :any:`numpy.clip` + This function follows the api from :any:`numpy.clip` See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html """ @@ -393,7 +393,7 @@ def repeat(self, a, repeats, axis=None): r""" Repeats elements of a tensor. - This function follow the api from :any:`numpy.repeat` + This function follows the api from :any:`numpy.repeat` See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html """ @@ -403,7 +403,7 @@ def take_along_axis(self, arr, indices, axis): r""" Gathers elements of a tensor along given dimensions. - This function follow the api from :any:`numpy.take_along_axis` + This function follows the api from :any:`numpy.take_along_axis` See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html """ @@ -413,7 +413,7 @@ def concatenate(self, arrays, axis=0): r""" Joins a sequence of tensors along an existing dimension. - This function follow the api from :any:`numpy.concatenate` + This function follows the api from :any:`numpy.concatenate` See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html """ @@ -423,7 +423,7 @@ def zero_pad(self, a, pad_width): r""" Pads a tensor. - This function follow the api from :any:`numpy.pad` + This function follows the api from :any:`numpy.pad` See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html """ @@ -433,7 +433,7 @@ def argmax(self, a, axis=None): r""" Returns the indices of the maximum values of a tensor along given dimensions. - This function follow the api from :any:`numpy.argmax` + This function follows the api from :any:`numpy.argmax` See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html """ @@ -443,7 +443,7 @@ def mean(self, a, axis=None): r""" Computes the arithmetic mean of a tensor along given dimensions. - This function follow the api from :any:`numpy.mean` + This function follows the api from :any:`numpy.mean` See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html """ @@ -453,7 +453,7 @@ def std(self, a, axis=None): r""" Computes the standard deviation of a tensor along given dimensions. - This function follow the api from :any:`numpy.std` + This function follows the api from :any:`numpy.std` See: https://numpy.org/doc/stable/reference/generated/numpy.std.html """ @@ -463,7 +463,7 @@ def linspace(self, start, stop, num): r""" Returns a specified number of evenly spaced values over a given interval. - This function follow the api from :any:`numpy.linspace` + This function follows the api from :any:`numpy.linspace` See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html """ @@ -473,7 +473,7 @@ def meshgrid(self, a, b): r""" Returns coordinate matrices from coordinate vectors (Numpy convention). - This function follow the api from :any:`numpy.meshgrid` + This function follows the api from :any:`numpy.meshgrid` See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html """ @@ -483,7 +483,7 @@ def diag(self, a, k=0): r""" Extracts or constructs a diagonal tensor. - This function follow the api from :any:`numpy.diag` + This function follows the api from :any:`numpy.diag` See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html """ @@ -493,7 +493,7 @@ def unique(self, a): r""" Finds unique elements of given tensor. - This function follow the api from :any:`numpy.unique` + This function follows the api from :any:`numpy.unique` See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html """ @@ -503,7 +503,7 @@ def logsumexp(self, a, axis=None): r""" Computes the log of the sum of exponentials of input elements. - This function follow the api from :any:`scipy.special.logsumexp` + This function follows the api from :any:`scipy.special.logsumexp` See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html """ @@ -513,7 +513,7 @@ def stack(self, arrays, axis=0): r""" Joins a sequence of tensors along a new dimension. - This function follow the api from :any:`numpy.stack` + This function follows the api from :any:`numpy.stack` See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html """ @@ -523,12 +523,22 @@ def outer(self, a, b): r""" Computes the outer product between two vectors. - This function follow the api from :any:`numpy.outer` + This function follows the api from :any:`numpy.outer` See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html """ raise NotImplementedError() + def reshape(self, a, shape): + r""" + Gives a new shape to a tensor without changing its data. + + This function follows the api from :any:`numpy.reshape` + + See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -699,6 +709,9 @@ def logsumexp(self, a, axis=None): def stack(self, arrays, axis=0): return np.stack(arrays, axis) + def reshape(self, a, shape): + return np.reshape(a, shape) + class JaxBackend(Backend): """ @@ -873,6 +886,9 @@ def logsumexp(self, a, axis=None): def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) + def reshape(self, a, shape): + return jnp.reshape(a, shape) + class TorchBackend(Backend): """ @@ -1110,3 +1126,6 @@ def logsumexp(self, a, axis=None): def stack(self, arrays, axis=0): return torch.stack(arrays, dim=axis) + + def reshape(self, a, shape): + return torch.reshape(a, shape) diff --git a/ot/optim.py b/ot/optim.py index 8e4633b88..6822e4eba 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -57,7 +57,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, """ xk, pk, gfk = list_to_array(xk, pk, gfk) - nx = get_backend(xk, pk, gfk) + nx = get_backend(xk, pk) if len(xk.shape) == 0: xk = nx.reshape(xk, (-1,)) diff --git a/test/test_backend.py b/test/test_backend.py index 8ab4ddcec..58532824f 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -205,6 +205,8 @@ def test_empty_backend(): nx.logsumexp(M) with pytest.raises(NotImplementedError): nx.stack([M, M]) + with pytest.raises(NotImplementedError): + nx.reshape(M, (5, 3, 2)) def test_func_backends(nx): @@ -432,6 +434,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('stack') + A = nx.reshape(Mb, (5, 3, 2)) + lst_b.append(nx.to_numpy(A)) + lst_name.append('reshape') + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_optim.py b/test/test_optim.py index 0b9bbe72f..4efd9b161 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -155,13 +155,13 @@ def test_line_search_armijo(nx): # check line search armijo def f(x): - return np.sum((x - 5.0) ** 2) + return nx.sum((x - 5.0) ** 2) def grad(x): return 2 * (x - 5.0) - xk = np.array([[[-5.0, -5.0]]]) - pk = np.array([[[100.0, 100.0]]]) + xk = nx.from_numpy(np.array([[[-5.0, -5.0]]])) + pk = nx.from_numpy(np.array([[[100.0, 100.0]]])) gfk = grad(xk) old_fval = f(xk) @@ -170,10 +170,18 @@ def grad(x): np.testing.assert_allclose(alpha, 0.1) # check the case where the direction is not far enough - pk = np.array([[[3.0, 3.0]]]) + pk = nx.from_numpy(np.array([[[3.0, 3.0]]])) alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) np.testing.assert_allclose(alpha, 1.0) - # check the case where the checking the wrong direction + # check the case where checking the wrong direction alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) assert alpha <= 0 + + # check the case where the point is not a vector + xk = nx.from_numpy(np.array(-5.0)) + pk = nx.from_numpy(np.array(100.0)) + gfk = grad(xk) + old_fval = f(xk) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1)