Vectorize cost and gradient for ot.da.sinkhorn_l1l2_gl #507

Merged · 15 commits · Sep 21, 2023

Conversation

@kachayev (Collaborator) commented Aug 20, 2023

Types of changes

The computation of the cost and gradient callbacks for the generalized conditional gradient used by ot.da.sinkhorn_l1l2_gl is now vectorized.

One potential issue here is that the norm function for the backends was implemented as a direct computation of the Euclidean 2-norm, without leveraging the built-in norms. I don't know if that was done on purpose (maybe there is a known problem with the norm exposed by each library). I switched the call to the built-in norm to get access to additional arguments, like axis and keepdims. It could still be implemented manually here, if necessary.
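
To illustrate the difference, here is a minimal NumPy sketch (not the actual POT backend code; names are illustrative) of a hand-rolled Euclidean norm versus a thin wrapper over the built-in norm that exposes `axis` and `keepdims`:

```python
import numpy as np

# Hand-rolled Euclidean 2-norm: always reduces over the whole array.
def norm_manual(x):
    return np.sqrt(np.sum(np.square(x)))

# Wrapper over the built-in norm: exposes axis/keepdims, which the
# vectorized cost/gradient rely on to reduce over a single axis.
def norm_builtin(x, axis=None, keepdims=False):
    return np.linalg.norm(x, axis=axis, keepdims=keepdims)

x = np.arange(6.0).reshape(2, 3)
assert np.isclose(norm_manual(x), norm_builtin(x))
norm_builtin(x, axis=0, keepdims=True)  # shape (1, 3): per-column norms
```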

Motivation and context / Related issue

Overall, sequential code is much harder to optimize than vectorized execution, especially when working with GPUs. In this particular case, it is also easy to spot that the labels_a == lab indices are recomputed inside a nested for-loop on each call of the function (n_samples * n_labels * n_calls times) despite being static. The new implementation prepares the indexing matrix only once and then computes the result using the normalization/summation operations on matrices provided by the corresponding backend.
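
As a rough illustration (a hedged NumPy sketch with illustrative names, not the code from this PR), the loop-based group-lasso cost and a vectorized equivalent built from a one-hot label-membership matrix look like this:

```python
import numpy as np

def cost_loop(G, labels_a):
    # Original formulation: for every target column and every source class,
    # take the 2-norm of the coupling entries belonging to that class.
    res = 0.0
    for j in range(G.shape[1]):
        for lab in np.unique(labels_a):
            res += np.linalg.norm(G[labels_a == lab, j])
    return res

def cost_vectorized(G, labels_a):
    # Build the (n_samples, n_labels) one-hot membership matrix only once,
    # then reduce over the sample axis with a single batched norm call.
    labels_u, labels_idx = np.unique(labels_a, return_inverse=True)
    member = np.eye(len(labels_u))[labels_idx]      # one-hot rows
    masked = member.T[:, :, None] * G[None, :, :]   # (n_labels, n_samples, n_cols)
    return np.linalg.norm(masked, axis=1).sum()

rng = np.random.default_rng(0)
G = rng.random((6, 4))
labels_a = np.array([0, 0, 1, 1, 2, 2])
assert np.isclose(cost_loop(G, labels_a), cost_vectorized(G, labels_a))
```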

Also, it seems that this function can now work on the JAX backend (the previous version didn't work because of the immutability of JAX tensors).

How has this been tested (if it applies)

A new unit test, test_sinkhorn_l1l2_gl_cost_vectorized, is added to the test_da.py test suite. This might not be the best way to do unit testing, but for this specific use case it seems fair: the test simply contains both implementations (old and new) and checks that they return the same result.
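
For reference, a hedged sketch of what such a comparison test can look like (the real test lives in test_da.py; apart from the test name, everything below is illustrative):

```python
import numpy as np

def test_sinkhorn_l1l2_gl_cost_vectorized():
    rng = np.random.default_rng(42)
    G = rng.random((10, 7))
    labels_a = rng.integers(0, 3, size=10)

    def cost_old(G):
        # loop-based reference implementation
        res = 0.0
        for j in range(G.shape[1]):
            for lab in np.unique(labels_a):
                res += np.linalg.norm(G[labels_a == lab, j])
        return res

    def cost_new(G):
        # vectorized implementation (see the sketch above)
        member = np.eye(3)[labels_a]
        return np.linalg.norm(member.T[:, :, None] * G[None], axis=1).sum()

    np.testing.assert_allclose(cost_old(G), cost_new(G))
```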

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

codecov bot commented Aug 20, 2023

Codecov Report

Merging #507 (24f860c) into master (526b72f) will increase coverage by 0.01%.
The diff coverage is 100.00%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #507      +/-   ##
==========================================
+ Coverage   96.07%   96.08%   +0.01%     
==========================================
  Files          65       65              
  Lines       13735    13784      +49     
==========================================
+ Hits        13196    13245      +49     
  Misses        539      539              

@kachayev (Collaborator, Author):
Tests are okay, finally.

@rflamary merged commit 4cf4492 into PythonOT:master on Sep 21, 2023