Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Low rank sinkhorn algorithm #568

Merged
merged 40 commits into from
Dec 5, 2023
Merged

Conversation

laudavid
Copy link
Contributor

@laudavid laudavid commented Nov 9, 2023

Hello Rémi,

Here is a PR for the new version of lowrank_sinkhorn

  • The function returns now: value, value_linear, lazy_plan (with LazyTensor), Q, R and g
  • I also added some test functions for the lowrank.py file (WIP)

@rflamary rflamary changed the title Low rank sinkhorn algorithm v2 [WIP] Low rank sinkhorn algorithm Nov 9, 2023
Copy link

codecov bot commented Nov 17, 2023

Codecov Report

Merging #568 (218c54a) into master (659cde8) will increase coverage by 0.01%.
The diff coverage is 98.23%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #568      +/-   ##
==========================================
+ Coverage   96.71%   96.73%   +0.01%     
==========================================
  Files          75       77       +2     
  Lines       15651    15820     +169     
==========================================
+ Hits        15137    15303     +166     
- Misses        514      517       +3     

Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @laudavid I have a few more comments after code review.

Also please add the paper to the references in README.md (with a number that should be the same as teh ref in the docstring of the function) at the PR in the new features in RELEASE.md, and your name in the contributors lists in CONTRIBUTORS.md

ot/lowrank.py Outdated Show resolved Hide resolved
ot/lowrank.py Outdated Show resolved Hide resolved
ot/lowrank.py Outdated Show resolved Hide resolved
ot/lowrank.py Outdated Show resolved Hide resolved
ot/lowrank.py Outdated Show resolved Hide resolved
ot/lowrank.py Outdated Show resolved Hide resolved
ot/lowrank.py Outdated Show resolved Hide resolved
ot/lowrank.py Outdated Show resolved Hide resolved
ot/lowrank.py Outdated Show resolved Hide resolved
test/test_lowrank.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some importnat changes to have proper gradients

ot/lowrank.py Outdated

# First low rank decomposition of the cost matrix (A)
M1 = nx.zeros((ns,(d+2)), type_as=X_s)
norm_M1 = list_to_array([nx.norm(X_s[i,:])**2 for i in range(ns)])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
norm_M1 = list_to_array([nx.norm(X_s[i,:])**2 for i in range(ns)])
norm_M1 = nx.sum(X_s**2, 1)

ot/lowrank.py Outdated
# First low rank decomposition of the cost matrix (A)
M1 = nx.zeros((ns,(d+2)), type_as=X_s)
norm_M1 = list_to_array([nx.norm(X_s[i,:])**2 for i in range(ns)])
M1[:,0] = nx.from_numpy(norm_M1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
M1[:,0] = nx.from_numpy(norm_M1)
M1[:,0] = norm_M1

ot/lowrank.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@cedricvincentcuaz cedricvincentcuaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @laudavid ,

Thank you for this PR. Here are some remarks to improve the code.
Note also that you will have to fix some conflicts after that you pull the current master branch that got updated with other PRs.

ot/lowrank.py Outdated

References
----------
.. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to put the reference of the published paper.

Scetbon, M., Cuturi, M., & Peyré, G. (2021). "Low-rank Sinkhorn factorization". In International Conference on Machine Learning.

Also the number [63] should be updated after the merge, as new references have been included and [63] relates to another paper.

ot/lowrank.py Outdated

def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
"""
Compute low rank decomposition of a sqeuclidean cost matrix.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

squared euclidean distance matrix* in the documentation. Plus could you specify the parameters as in lowrank_sinkhorn ?


def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None):
"""
Implementation of the Dykstra algorithm for the Low Rank sinkhorn OT solver.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you specify the parameters here too ?

ot/lowrank.py Outdated
return M1, M2


def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename to _LR_Dysktra as it seems specific to lowrank_sinkhorn

ot/lowrank.py Outdated
Regularization term >0
rank: int, default "auto"
Nonnegative rank of the OT plan
alpha: int, default "auto" (1e-10)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is more conventional in POT API, to set a None value as default such as: rank: int, optional. Default is None. Nonnegative rank of the OT plan. If None, min(ns, nt) is considered.

Same template for alpha.

NB: the rank actually has to be strictly positive.

ot/lowrank.py Outdated

# Compute rank (see Section 3.1, def 1)
r = rank
if rank == "auto":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if rank is None:.

ot/lowrank.py Outdated
if rank == "auto":
r = min(ns, nt)

if alpha == "auto":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

ot/lowrank.py Outdated

References
----------
.. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update ref

@rflamary rflamary changed the title [WIP] Low rank sinkhorn algorithm [MRG] Low rank sinkhorn algorithm Dec 5, 2023
@rflamary rflamary merged commit 0024d07 into PythonOT:master Dec 5, 2023
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants