-
Notifications
You must be signed in to change notification settings - Fork 508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Translation Invariant Sinkhorn for Unbalanced OT #676
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #676 +/- ##
==========================================
+ Coverage 96.99% 97.03% +0.04%
==========================================
Files 96 96
Lines 19117 19255 +138
==========================================
+ Hits 18542 18685 +143
+ Misses 575 570 -5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very ice thanks @clbonet ,
I have two comments about missing tests since you code handle infinite values it should be tested.
Hello @clbonet , Sorry but we did another merging and now there is a conflict with your PR (only minor things such as reference number in redame and release file). Could you please have a quick look (and also update reference numbers in the doc of functions) so that we can finally merge |
Types of changes
This PR aims to add the Translation Invariant Sinkhorn algorithm proposed in Faster Unbalanced Optimal Transport: Translation invariant Sinkhorn and 1-D Frank-Wolfe
to solve the Unbalanced OT problem.
sinkhorn_unbalanced_translation_invariant
inunbalanced/_sinkhorn.py
plot_conv_sinkhorn_ti.py
to demonstrate the convergence benefitsunbalanced/test_sinkhorn.py
Motivation and context / Related issue
This version of Sinkhorn converges faster than the classical Sinkhorn algorithm for UOT.
How has this been tested (if it applies)
The function can be called by using the method "sinkhorn_translation_invariant" in
sinkhorn_unbalanced
, so I added this method in the itertools of tests ofunbalanced/test_sinkhorn.py
.PR checklist