-
Notifications
You must be signed in to change notification settings - Fork 90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement QR pullback #306
base: main
Are you sure you want to change the base?
Conversation
if size(F.R, 2) != size(F.R, 1) | ||
throw(ArgumentError("Pullback for QR decomposition is only supported for m × n matrices with m >= n")) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because I have not found a reference for that case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Section 3.2 of https://arxiv.org/pdf/2009.10071.pdf gives an rrule
for qr
for wide matrices (m<n
). I haven't tested it, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will look into adding that.
|
||
# Explicitely convert to Matrix since FiniteDifferences seem to | ||
# be broken for LinearAlgebra.QRCompactWYQ (infinite to_vec | ||
# recursion) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can an issue openned for this an linked here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done here JuliaDiff/FiniteDifferences.jl#114
|
||
F, dX_pullback = rrule(qr, X) | ||
for p in [:Q, :R] | ||
Y, dF_pullback = rrule(getproperty, F, p) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does rrule_test
not work on this?
If it doesn't an issue needs to be openned on ChainRulesTestUtils.jl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does not work on getproperty
because of the failure of FiniteDifferences.jl
mentionned elsewhere on this PR.
It does not work directly on qr
either, because the object returned by qr
can not be collect
ed. See JuliaDiff/ChainRulesTestUtils.jl#74
@@ -138,4 +138,40 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo | |||
@test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 10, true) | |||
end | |||
end | |||
@testset "qr" begin | |||
@testset "the thing" begin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please explain
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just reused the same structure as the test for cholesky. I can give it a more explicit name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please do.
C = Composite{T} | ||
∂F = if x === :Q | ||
C(Q=Ȳ,) | ||
elseif x === :R | ||
C(R=Ȳ,) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it worth defining C
?
C = Composite{T} | |
∂F = if x === :Q | |
C(Q=Ȳ,) | |
elseif x === :R | |
C(R=Ȳ,) | |
end | |
∂F = if x === :Q | |
Composite{T}(Q=Ȳ,) | |
elseif x === :R | |
Composite{T}(R=Ȳ,) | |
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did it analog to the svd
case. I think I could go even simpler with
∂F = Composite{T}(; x => Ȳ)
Should I port that kind of cleanup to svd
and cholesky
as well?
src/rulesets/LinearAlgebra/utils.jl
Outdated
@@ -27,3 +27,17 @@ function _eyesubx!(X::AbstractMatrix) | |||
end | |||
|
|||
_extract_imag(x) = complex(0, imag(x)) | |||
|
|||
# Lower triangle of X - X' overwrite X if possible |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems to always overwrite?
if x === :Q | ||
# Return thing Q for consistency | ||
n = size(F.R, 1) | ||
return F.Q[:, 1:n], getproperty_qr_pullback |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this on to do to the primal value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was because the anything else does not contribute, and the tests used the result of this to infer dimensions. However after pondering on it I figured out it is probably better to let the getproperty
pullback do just the obvious and have the qr
pullback make sure to only take in account what is relevant and make sure the dimensions match.
return NO_FIELDS, ∂F, DoesNotExist() | ||
end | ||
if x === :Q | ||
# Return thing Q for consistency |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was very confused
# Return thing Q for consistency | |
# Return thin Q for consistency |
I wonder if the QR rule implied by Seeger et al in https://arxiv.org/pdf/1710.08717.pdf is more performant than the one in Walter and Lehmann? (they actually define an LQ rule, but the same approach produces a QR rule). The below reimplementation of this PR's function qr_rev2(QR_::ChainRules.QR_TYPE, Q̄, R̄)
Q, R = QR_
Q = Matrix(Q)
Q̄ = Q̄ isa Zero ? Q̄ : @view Q̄[:, axes(Q, 2)]
V = R̄*R' - Q'*Q̄
Ā = (Q̄ + Q * Hermitian(V)) / R'
return Ā
end
julia> A = randn(4, 4);
julia> F = qr(A);
julia> ΔF = Composite{typeof(F)}(Q = randn(eltype(F.Q), size(Matrix(F.Q))), R = randn(eltype(F.R), size(F.R)));
julia> @btime ChainRules.qr_rev($F, $(ΔF.Q), $(ΔF.R));
16.374 μs (58 allocations: 7.11 KiB)
julia> @btime qr_rev2($F, $(ΔF.Q), $(ΔF.R));
4.323 μs (12 allocations: 2.27 KiB)
julia> ChainRules.qr_rev(F, ΔF.Q, ΔF.R) ≈ qr_rev2(F, ΔF.Q, ΔF.R)
true
julia> A = randn(10, 4);
julia> F = qr(A);
julia> ΔF = Composite{typeof(F)}(Q = randn(eltype(F.Q), size(Matrix(F.Q))), R = randn(eltype(F.R), size(F.R)));
julia> @btime ChainRules.qr_rev($F, $(ΔF.Q), $(ΔF.R));
56.299 μs (134 allocations: 20.86 KiB)
julia> @btime qr_rev2($F, $(ΔF.Q), $(ΔF.R));
5.112 μs (12 allocations: 3.39 KiB)
julia> ChainRules.qr_rev(F, ΔF.Q, ΔF.R) ≈ qr_rev2(F, ΔF.Q, ΔF.R)
true Or is the rule you have implemented expected to be more numerically stable? |
I have no idea. The two references we are using do not directly compare each other, and I do not know how to determine this myself. |
The article I linked in #306 (comment) (https://arxiv.org/pdf/2009.10071.pdf), which covers wide and tall matrices as well, also uses the simpler rule from the Seeger et al paper. While I didn't implement their rules for wide and tall matrices, I ended up using a similar approach for the LU decomposition of wide and tall matrices in #354. For these reasons, I'm thinking the Seeger approach is preferable. |
I finally had time to come back to this, and it has been kind of a nightmare, because QR decompositions are represented in a weird way that do not play nicely with the tests and the comparison with FiniteDifferences. After quite a lot of experimentations, I gave up on trying to make everything work with the default type returned by I hope this is sufficient. Otherwise I must admit I am out of idea about what should be done to test I am aware of #469 that has a somewhat different approach. I am not currently sure which is better. Alos I implemented the algorithm suggested by @sethaxen. So provided my way of testing is okay, this should be ready. |
What I like about this approach is that it completely sidesteps much of the complexity of the objects returned by the What I don't like is that the object being returned by the A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)
y = Q*w + Q*v
@assert size(y) == (10,) This works because |
I tried the following using ChainRules: rrule, ExplicitQR
using LinearAlgebra
A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)
F, F_pullback = rrule(qr, A)
Q, Q_pullback = rrule(getproperty, F, :Q)
y1, y1_pullback = rrule(*, Q, v)
y2, y2_pullback = rrule(*, Q, w)
ȳ1 = rand(10)
_, Q̄1 = y1_pullback(ȳ1)
_, F̄1 = Q_pullback(Q̄1)
_, Ā1 = F_pullback(F̄1)
ȳ2 = rand(10)
_, Q̄2 = y2_pullback(ȳ2)
_, F̄2 = Q_pullback(Q̄2)
_, Ā2 = F_pullback(F̄2) and everything seems to be fine (i.e. nothing error, I haven't tested correctness). Is this the correct way to test your point? Whether this PR or #469 is used, adding proper test for quirks of Note that the As far as I understand, [*] The most common error I had was:
Using a custom struct with explicit fields |
Implement the pullback for the QR decomposition, following:
Walter and Lehmann, 2018, Algorithmic Differentiation of Linear Algebra Functions with Application in Optimum Experimental Design