empirical_sinkhorn_divergence doesn't have a grad_fn #393
Comments
Good catch! The backend for the Sinkhorn divergence is not properly done, so it converts to numpy and back, losing the gradient information (and a lot of interest ;)). We will look into it. In the meantime, I suggest that you write a function that uses ot.dist to compute the cost matrices and ot.sinkhorn2 to get differentiable losses, and sums the three terms, as sketched below. It can be done in a few lines of code.
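A minimal sketch of such a helper, assuming a recent POT version where ot.dist and ot.sinkhorn2 accept torch tensors and preserve gradients; the function name, the uniform sample weights, and the squared Euclidean cost are illustrative choices, not part of the original comment:

```python
import torch
import ot

def sinkhorn_divergence(x_s, x_t, reg):
    # Uniform weights over the empirical samples
    a = torch.full((x_s.shape[0],), 1.0 / x_s.shape[0], dtype=x_s.dtype, device=x_s.device)
    b = torch.full((x_t.shape[0],), 1.0 / x_t.shape[0], dtype=x_t.dtype, device=x_t.device)

    # Pairwise squared Euclidean cost matrices
    M_st = ot.dist(x_s, x_t)
    M_ss = ot.dist(x_s, x_s)
    M_tt = ot.dist(x_t, x_t)

    # Entropic OT losses, differentiable when the inputs are torch tensors
    loss_st = ot.sinkhorn2(a, b, M_st, reg)
    loss_ss = ot.sinkhorn2(a, a, M_ss, reg)
    loss_tt = ot.sinkhorn2(b, b, M_tt, reg)

    # Sinkhorn divergence: cross term minus half of each self term
    return loss_st - 0.5 * (loss_ss + loss_tt)
```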
Since I am using a GPU to run my experiments, do you think it is better to use
Yes, it should be more numerically stable (slightly slower than the traditional Sinkhorn, though).
That's ok, thank you :D
Hello @gabrielsantosrv. The function should now preserve the gradients on the master branch (we added a test to check this so that it does not happen again).
Describe the bug
I am trying to use empirical_sinkhorn_divergence as a loss function in PyTorch, but the returned tensor does not have a grad_fn, so the gradient can't be propagated.
Code sample
loss = ot.bregman.empirical_sinkhorn_divergence(source, target, 1)
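A fuller reproducer along these lines (the toy tensors below are hypothetical stand-ins for the actual source and target samples) shows the missing grad_fn:

```python
import torch
import ot

# Hypothetical toy samples standing in for the real source/target data
source = torch.randn(64, 2, requires_grad=True)
target = torch.randn(64, 2)

loss = ot.bregman.empirical_sinkhorn_divergence(source, target, 1)
print(loss.grad_fn)  # None on the affected version, so loss.backward() fails
```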
Expected behavior
Return a tensor with a grad_fn.
Environment (please complete the following information):
How was POT installed (source, pip, conda): pip
Output of the following code snippet: