empirical_sinkhorn_divergence doesn't have a grad_fn #393

Closed
gabrielsantosrv opened this issue Aug 9, 2022 · 5 comments · Fixed by #394

@gabrielsantosrv

Describe the bug

I am trying to use empirical_sinkhorn_divergence as a loss function in PyTorch, but the returned tensor does not have a grad_fn, so the gradient cannot be propagated.

Code sample

loss = ot.bregman.empirical_sinkhorn_divergence(source, target, 1)
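
For context, a minimal reproduction along these lines (the tensor shapes and the requires_grad setup are my own assumptions, not part of the original report):

import torch
import ot

# Two small point clouds; gradients should flow back to source
source = torch.randn(50, 2, requires_grad=True)
target = torch.randn(60, 2)

loss = ot.bregman.empirical_sinkhorn_divergence(source, target, 1)
print(loss.grad_fn)  # None in POT 0.8.2, so loss.backward() cannot propagate gradients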

Expected behavior

Return a tensor with a grad_fn.

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: 3.8.13
  • How was POT installed (source, pip, conda): pip
  • Build command you used (if compiling from source):
  • Only for GPU related bugs:
    • CUDA version: 11.2
    • GPU models and configuration: Quadro RTX 8000

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-5.4.0-73-generic-x86_64-with-glibc2.17
Python 3.8.13 (default, Mar 28 2022, 11:38:47) 
[GCC 7.5.0]
NumPy 1.21.6
SciPy 1.8.1
POT 0.8.2
@rflamary (Collaborator)

Good catch! The backend for the Sinkhorn divergence is not properly handled, so the computation goes through NumPy and back, losing the gradient information (and a lot of its interest ;)).

We will look into it. In the meantime, I suggest that you write a function that uses ot.dist to compute the cost matrices and ot.sinkhorn2 to obtain differentiable losses, and sums the three terms; see the sketch below. It can be done in a few lines of code.
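
A minimal sketch of that workaround (the helper name sinkhorn_divergence and the uniform weights are assumptions on my part, not something prescribed by POT):

import torch
import ot

def sinkhorn_divergence(source, target, reg):
    # Uniform weights on both empirical distributions (assumption)
    n, m = source.shape[0], target.shape[0]
    a = torch.ones(n, dtype=source.dtype, device=source.device) / n
    b = torch.ones(m, dtype=target.dtype, device=target.device) / m

    # Cost matrices for the three terms (squared Euclidean by default)
    M_st = ot.dist(source, target)
    M_ss = ot.dist(source, source)
    M_tt = ot.dist(target, target)

    # ot.sinkhorn2 returns a differentiable loss with the torch backend
    loss_st = ot.sinkhorn2(a, b, M_st, reg)
    loss_ss = ot.sinkhorn2(a, a, M_ss, reg)
    loss_tt = ot.sinkhorn2(b, b, M_tt, reg)

    # Sinkhorn divergence: S(a, b) = W(a, b) - 0.5 * W(a, a) - 0.5 * W(b, b)
    return loss_st - 0.5 * (loss_ss + loss_tt)

Calling loss = sinkhorn_divergence(source, target, 1) followed by loss.backward() should then propagate gradients back to source.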

@gabrielsantosrv (Author)

gabrielsantosrv commented Aug 10, 2022

Since I am using a GPU to run my experiments, do you think it is better to use ot.sinkhorn2 with the parameter method="sinkhorn_log" to compute the Sinkhorn divergence?

@rflamary (Collaborator)

Yes, it should be more numerically stable (though slightly slower than the traditional Sinkhorn).
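
Concretely, with the sketch above this only means passing the method keyword to each ot.sinkhorn2 call, e.g. (same variables as in the sketch):

# Log-domain solver: more numerically stable for small reg, slightly slower
loss_st = ot.sinkhorn2(a, b, M_st, reg, method="sinkhorn_log")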

@gabrielsantosrv (Author)

That's ok, thank you :D

@rflamary (Collaborator)

Hello @gabrielsantosrv

The function now preserves the gradients on the master branch (we added a test to check this so that it does not happen again).
