diff --git a/RELEASES.md b/RELEASES.md index a5fcbe15c..94c853b1a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,6 @@ # Releases + ## 0.8.2dev Development #### New features @@ -7,10 +8,12 @@ - Better list of related examples in quick start guide with `minigallery` (PR #334) - Add optional log-domain Sinkhorn implementation in WDA to support smaller values of the regularization parameter (PR #336) +- Backend implementation for `ot.lp.free_support_barycenter` (PR #340) #### Closed issues -- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR #338) +- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR + #338) ## 0.8.1.0 *December 2021* diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 5da897d0b..2ff7c1feb 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -535,18 +535,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Parameters ---------- - measures_locations : list of N (k_i,d) numpy.ndarray + measures_locations : list of N (k_i,d) array-like The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space (:math:`k_i` can be different for each element of the list) - measures_weights : list of N (k_i,) numpy.ndarray + measures_weights : list of N (k_i,) array-like Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one representing the weights of each discrete input measure - X_init : (k,d) np.ndarray + X_init : (k,d) array-like Initialization of the support locations (on `k` atoms) of the barycenter - b : (k,) np.ndarray + b : (k,) array-like Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (N,) np.ndarray + weights : (N,) array-like Initialization of the coefficients of the barycenter (non-negatives, sum to 1) numItermax : int, optional @@ -564,7 +564,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Returns ------- - X : (k,d) np.ndarray + X : (k,d) array-like Support locations (on k atoms) of the barycenter @@ -577,15 +577,17 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None """ + nx = get_backend(*measures_locations,*measures_weights,X_init) + iter_count = 0 N = len(measures_locations) k = X_init.shape[0] d = X_init.shape[1] if b is None: - b = np.ones((k,)) / k + b = nx.ones((k,),type_as=X_init) / k if weights is None: - weights = np.ones((N,)) / N + weights = nx.ones((N,),type_as=X_init) / N X = X_init @@ -596,15 +598,15 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None while (displacement_square_norm > stopThr and iter_count < numItermax): - T_sum = np.zeros((k, d)) + T_sum = nx.zeros((k, d),type_as=X_init) + - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, - weights.tolist()): + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) + T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i) - displacement_square_norm = np.sum(np.square(T_sum - X)) + displacement_square_norm = nx.sum((T_sum - X)**2) if log: displacement_square_norms.append(displacement_square_norm) diff --git a/test/test_ot.py b/test/test_ot.py index 53edf4f65..e8e2d9753 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -302,6 +302,23 @@ def test_free_support_barycenter(): np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) +def test_free_support_barycenter_backends(nx): + + measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + measures_weights = [np.array([1.]), np.array([1.])] + X_init = np.array([-12.]).reshape((1, 1)) + + X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) + + measures_locations2 = [nx.from_numpy(x) for x in measures_locations] + measures_weights2 = [nx.from_numpy(x) for x in measures_weights] + X_init2 = nx.from_numpy(X_init) + + X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2) + + np.testing.assert_allclose(X, nx.to_numpy(X2)) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] diff --git a/test/test_utils.py b/test/test_utils.py index 6b476b2af..8b23c224e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -122,7 +122,7 @@ def test_dist(): 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', - 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule' + 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule' ] # those that support weights metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version