-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CoLA integration #370
CoLA integration #370
Conversation
return self.scale.diagonal() | ||
return cola.diag(self.scale) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check: This probably has edge case behaviour as diag
switches between diagonal and (dense) diagonal matrix, while diagonal
is strictly to a diagonal array.
# TODO: Once this functionality is supported in CoLA, remove this. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to open issue.
Hey @mfinzi, would be grateful if you could glance over this integration! For context its pretty minimal, and only aims to remove our ancient linear operators and revamp them to CoLA. Undoubtedly there may be more efficient things we could do e.g., within our Nice work on wilson-labs/cola/#33, but perhaps really what I more would like is access to the (lower) Cholesky root itself like I have done in https://github.com/JaxGaussianProcesses/GPJax/blob/2c9ebce5a110b73a54ee9a38f408bd1950912026/gpjax/lower_cholesky.py. I thought the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This all looks good to me! Very nice :D
Resolve rebase from branch.
Type of changes
Opening draft PR to get 🪩 rolling on migrating GPJax linops to CoLA 🥤.
Major outstanding work left to do are (1) API considerations, (2) comprehensive testing.
Checklist
poetry run pre-commit run --all-files --show-diff-on-failure
before committing. (Yes but there might be an issue with coverage).Description
The
gpjax.linops
module has been removed. All linear operators in GPJax have been replaced with their analogue in CoLA e.g.,:gpjax.linops.DenseLinearOperator
->cola.ops.Dense
(wrapped aroundcola.PSD
where appropriate).gpjax.linops.DiagonalLinearOperator
->cola.ops.Diagonal
(wrapped aroundcola.PSD
where appropriate).With minimal modification to the code, such that tests pass locally.
Plum dispatch is dropped as a direct dependancy and we use
singedispatch
forcitations
to avoid clashes with CoLA.Outstanding issues **[edit:] Have opened issues for these!
solve
androot
operations on linear operators.cross_covaraince
be a LinearOperator rather than a dense array (this would match the signature ofgram
). This is beneficial in sparse situations.