-
Notifications
You must be signed in to change notification settings - Fork 508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Low rank sinkhorn algorithm #568
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## master #568 +/- ##
==========================================
+ Coverage 96.71% 96.73% +0.01%
==========================================
Files 75 77 +2
Lines 15651 15820 +169
==========================================
+ Hits 15137 15303 +166
- Misses 514 517 +3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @laudavid I have a few more comments after code review.
Also please add the paper to the references in README.md (with a number that should be the same as teh ref in the docstring of the function) at the PR in the new features in RELEASE.md, and your name in the contributors lists in CONTRIBUTORS.md
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some importnat changes to have proper gradients
ot/lowrank.py
Outdated
|
||
# First low rank decomposition of the cost matrix (A) | ||
M1 = nx.zeros((ns,(d+2)), type_as=X_s) | ||
norm_M1 = list_to_array([nx.norm(X_s[i,:])**2 for i in range(ns)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
norm_M1 = list_to_array([nx.norm(X_s[i,:])**2 for i in range(ns)]) | |
norm_M1 = nx.sum(X_s**2, 1) |
ot/lowrank.py
Outdated
# First low rank decomposition of the cost matrix (A) | ||
M1 = nx.zeros((ns,(d+2)), type_as=X_s) | ||
norm_M1 = list_to_array([nx.norm(X_s[i,:])**2 for i in range(ns)]) | ||
M1[:,0] = nx.from_numpy(norm_M1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
M1[:,0] = nx.from_numpy(norm_M1) | |
M1[:,0] = norm_M1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @laudavid ,
Thank you for this PR. Here are some remarks to improve the code.
Note also that you will have to fix some conflicts after that you pull the current master branch that got updated with other PRs.
ot/lowrank.py
Outdated
|
||
References | ||
---------- | ||
.. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to put the reference of the published paper.
Scetbon, M., Cuturi, M., & Peyré, G. (2021). "Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
Also the number [63] should be updated after the merge, as new references have been included and [63] relates to another paper.
ot/lowrank.py
Outdated
|
||
def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None): | ||
""" | ||
Compute low rank decomposition of a sqeuclidean cost matrix. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
squared euclidean distance matrix* in the documentation. Plus could you specify the parameters as in lowrank_sinkhorn
?
|
||
def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None): | ||
""" | ||
Implementation of the Dykstra algorithm for the Low Rank sinkhorn OT solver. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you specify the parameters here too ?
ot/lowrank.py
Outdated
return M1, M2 | ||
|
||
|
||
def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rename to _LR_Dysktra
as it seems specific to lowrank_sinkhorn
ot/lowrank.py
Outdated
Regularization term >0 | ||
rank: int, default "auto" | ||
Nonnegative rank of the OT plan | ||
alpha: int, default "auto" (1e-10) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is more conventional in POT API, to set a None
value as default such as: rank: int, optional. Default is None. Nonnegative rank of the OT plan. If None, min(ns, nt) is considered.
Same template for alpha
.
NB: the rank actually has to be strictly positive.
ot/lowrank.py
Outdated
|
||
# Compute rank (see Section 3.1, def 1) | ||
r = rank | ||
if rank == "auto": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if rank is None:
.
ot/lowrank.py
Outdated
if rank == "auto": | ||
r = min(ns, nt) | ||
|
||
if alpha == "auto": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
ot/lowrank.py
Outdated
|
||
References | ||
---------- | ||
.. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update ref
Hello Rémi,
Here is a PR for the new version of lowrank_sinkhorn