Skip to content

Commit

Permalink
[MRG] Debug convolutional methods that compute barycenters to work wi…
Browse files Browse the repository at this point in the history
…th different devices. (#533)

* [FEAT] Add the parameter 'type_as' to the backends

* [TEST] add tests for the 'type_as' backends

* [DEBUG] Debug dtype in pytorch

* [FIX] Add type_as every time linspace is called

* [TEST] Add test for the convolutional_barycenter2d algorithms (they are the only ones that use linspace)

* [DEBUG] PEP 8

* [DOC] Add the new changes to RELEASES.md

* [REFACTOR] Minor refactor that checks the GPU on the last line

* [DEBUG] pep8

* [REFACTOR] Add a function to generalize the creation of random images

* [REFACTOR] Maintain the same style as before

* Update gitignore

---------

Co-authored-by: Francisco Muñoz <fmunoz@ug.uchile.cl>
  • Loading branch information
framunoz and Francisco Muñoz authored Oct 18, 2023
1 parent ffdd1cf commit 8a4a5a6
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 65 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ celerybeat-schedule
# virtualenv
venv/
ENV/
.venv/

# Spyder project settings
.spyderproject
Expand All @@ -120,4 +121,4 @@ debug
.vscode

# pytest cache
.pytest_cache
.pytest_cache
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526)
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
+ The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533)
+ The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices. (PR #533)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
50 changes: 34 additions & 16 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@
#
# License: MIT License

import numpy as np
import os
import scipy
import scipy.linalg
from scipy.sparse import issparse, coo_matrix, csr_matrix
import scipy.special as special
import time
import warnings

import numpy as np
import scipy
import scipy.linalg
import scipy.special as special
from scipy.sparse import coo_matrix, csr_matrix, issparse

DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
Expand Down Expand Up @@ -650,7 +650,7 @@ def std(self, a, axis=None):
"""
raise NotImplementedError()

def linspace(self, start, stop, num):
def linspace(self, start, stop, num, type_as=None):
r"""
Returns a specified number of evenly spaced values over a given interval.
Expand Down Expand Up @@ -1208,8 +1208,11 @@ def median(self, a, axis=None):
def std(self, a, axis=None):
return np.std(a, axis=axis)

def linspace(self, start, stop, num):
return np.linspace(start, stop, num)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return np.linspace(start, stop, num)
else:
return np.linspace(start, stop, num, dtype=type_as.dtype)

def meshgrid(self, a, b):
return np.meshgrid(a, b)
Expand Down Expand Up @@ -1579,8 +1582,11 @@ def median(self, a, axis=None):
def std(self, a, axis=None):
return jnp.std(a, axis=axis)

def linspace(self, start, stop, num):
return jnp.linspace(start, stop, num)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return jnp.linspace(start, stop, num)
else:
return self._change_device(jnp.linspace(start, stop, num, dtype=type_as.dtype), type_as)

def meshgrid(self, a, b):
return jnp.meshgrid(a, b)
Expand Down Expand Up @@ -1986,6 +1992,7 @@ def concatenate(self, arrays, axis=0):

def zero_pad(self, a, pad_width, value=0):
from torch.nn.functional import pad

# pad_width is an array of ndim tuples indicating how many 0 before and after
# we need to add. We first need to make it compliant with torch syntax, that
# starts with the last dim, then second last, etc.
Expand All @@ -2006,6 +2013,7 @@ def mean(self, a, axis=None):

def median(self, a, axis=None):
from packaging import version

# Since version 1.11.0, interpolation is available
if version.parse(torch.__version__) >= version.parse("1.11.0"):
if axis is not None:
Expand All @@ -2026,8 +2034,11 @@ def std(self, a, axis=None):
else:
return torch.std(a, unbiased=False)

def linspace(self, start, stop, num):
return torch.linspace(start, stop, num, dtype=torch.float64)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return torch.linspace(start, stop, num)
else:
return torch.linspace(start, stop, num, dtype=type_as.dtype, device=type_as.device)

def meshgrid(self, a, b):
try:
Expand Down Expand Up @@ -2427,8 +2438,12 @@ def median(self, a, axis=None):
def std(self, a, axis=None):
return cp.std(a, axis=axis)

def linspace(self, start, stop, num):
return cp.linspace(start, stop, num)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return cp.linspace(start, stop, num)
else:
with cp.cuda.Device(type_as.device):
return cp.linspace(start, stop, num, dtype=type_as.dtype)

def meshgrid(self, a, b):
return cp.meshgrid(a, b)
Expand Down Expand Up @@ -2834,8 +2849,11 @@ def median(self, a, axis=None):
def std(self, a, axis=None):
return tnp.std(a, axis=axis)

def linspace(self, start, stop, num):
return tnp.linspace(start, stop, num)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return tnp.linspace(start, stop, num)
else:
return tnp.linspace(start, stop, num, dtype=type_as.dtype)

def meshgrid(self, a, b):
return tnp.meshgrid(a, b)
Expand Down
19 changes: 10 additions & 9 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import numpy as np
from scipy.optimize import fmin_l_bfgs_b

from ot.utils import unif, dist, list_to_array
from ot.utils import dist, list_to_array, unif

from .backend import get_backend


Expand Down Expand Up @@ -2217,11 +2218,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,

# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, A.shape[1])
t = nx.linspace(0, 1, A.shape[1], type_as=A)
[Y, X] = nx.meshgrid(t, t)
K1 = nx.exp(-(X - Y) ** 2 / reg)

t = nx.linspace(0, 1, A.shape[2])
t = nx.linspace(0, 1, A.shape[2], type_as=A)
[Y, X] = nx.meshgrid(t, t)
K2 = nx.exp(-(X - Y) ** 2 / reg)

Expand Down Expand Up @@ -2295,11 +2296,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width)
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M1 = - (X - Y) ** 2 / reg

t = nx.linspace(0, 1, height)
t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M2 = - (X - Y) ** 2 / reg

Expand Down Expand Up @@ -2452,11 +2453,11 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,

# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width)
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
K1 = nx.exp(-(X - Y) ** 2 / reg)

t = nx.linspace(0, 1, height)
t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
K2 = nx.exp(-(X - Y) ** 2 / reg)

Expand Down Expand Up @@ -2532,11 +2533,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width)
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M1 = - (X - Y) ** 2 / reg

t = nx.linspace(0, 1, height)
t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M2 = - (X - Y) ** 2 / reg

Expand Down
12 changes: 5 additions & 7 deletions test/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
#
# License: MIT License

import ot
import ot.backend
from ot.backend import torch, jax, tf

import pytest

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal_nulp

from ot.backend import get_backend, get_backend_list, to_numpy
import ot
import ot.backend
from ot.backend import get_backend, get_backend_list, jax, tf, to_numpy, torch


def test_get_backend_list():
Expand Down Expand Up @@ -507,6 +504,7 @@ def test_func_backends(nx):
lst_name.append('std')

A = nx.linspace(0, 1, 50)
A = nx.linspace(0, 1, 50, type_as=Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('linspace')

Expand Down
Loading

0 comments on commit 8a4a5a6

Please sign in to comment.