-
-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Enzyme extension #377
Merged
Merged
Add Enzyme extension #377
Changes from 13 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
bb6d623
Add Enzyme extension
wsmoses 9f8d18f
Add actual file
wsmoses 391b602
fix typo
ChrisRackauckas ce7ffc0
more v1.9
ChrisRackauckas a08386d
add a test for Enzyme rule correctness
ChrisRackauckas 9273a20
Extend
wsmoses 84c5196
add some batch tests
ChrisRackauckas bb93d68
Fix
wsmoses 9d19db2
Cache before LU in place
wsmoses f9b0784
simplify test
wsmoses 3b39753
fix multiple solve handling
ChrisRackauckas cbb5f1d
fix multiple solve handling
ChrisRackauckas 9630121
fix other algorithms
ChrisRackauckas b0d228d
getting very close
ChrisRackauckas c2ad2db
push batch test updates
ChrisRackauckas 54f0722
type stable
wsmoses e4f0785
fix mutated db
wsmoses d69af77
More caching
wsmoses be91ba2
Remove batch test
ChrisRackauckas 89e10df
Remove Krylov test for now
ChrisRackauckas File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
module LinearSolveEnzymeExt | ||
|
||
using LinearSolve | ||
using LinearSolve.LinearAlgebra | ||
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) | ||
|
||
|
||
using Enzyme | ||
|
||
using EnzymeCore | ||
|
||
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} | ||
res = func.val(prob.val, alg.val; kwargs...) | ||
dres = if EnzymeRules.width(config) == 1 | ||
func.val(prob.dval, alg.val; kwargs...) | ||
else | ||
(func.val(dval, alg.val; kwargs...) for dval in prob.dval) | ||
end | ||
d_A = if EnzymeRules.width(config) == 1 | ||
dres.A | ||
else | ||
(dval.A for dval in dres) | ||
end | ||
d_b = if EnzymeRules.width(config) == 1 | ||
dres.b | ||
else | ||
(dval.b for dval in dres) | ||
end | ||
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b)) | ||
end | ||
|
||
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} | ||
d_A, d_b = cache | ||
|
||
if EnzymeRules.width(config) == 1 | ||
if d_A !== prob.dval.A | ||
prob.dval.A .+= d_A | ||
d_A .= 0 | ||
end | ||
if d_b !== prob.dval.b | ||
prob.dval.b .+= d_b | ||
d_b .= 0 | ||
end | ||
else | ||
for i in 1:EnzymeRules.width(config) | ||
if d_A !== prob.dval.A | ||
prob.dval.A[i] .+= d_A[i] | ||
d_A[i] .= 0 | ||
end | ||
if d_b !== prob.dval.b | ||
prob.dval.b[i] .+= d_b[i] | ||
d_b[i] .= 0 | ||
end | ||
end | ||
end | ||
|
||
return (nothing, nothing) | ||
end | ||
|
||
# y=inv(A) B | ||
# dA −= z y^T | ||
# dB += z, where z = inv(A^T) dy | ||
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} | ||
res = func.val(linsolve.val; kwargs...) | ||
|
||
dres = if EnzymeRules.width(config) == 1 | ||
deepcopy(res) | ||
else | ||
(deepcopy(res) for dval in linsolve.dval) | ||
end | ||
|
||
if EnzymeRules.width(config) == 1 | ||
dres.u .= 0 | ||
else | ||
for dr in dres | ||
dr.u .= 0 | ||
end | ||
end | ||
|
||
resvals = if EnzymeRules.width(config) == 1 | ||
dres.u | ||
else | ||
(dr.u for dr in dres) | ||
end | ||
|
||
cache = (res, resvals) | ||
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) | ||
end | ||
|
||
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} | ||
y, dys = cache | ||
_linsolve = linsolve.val | ||
|
||
@assert !(typeof(linsolve) <: Const) | ||
@assert !(typeof(linsolve) <: Active) | ||
|
||
if EnzymeRules.width(config) == 1 | ||
dys = (dys,) | ||
end | ||
|
||
dAs = if EnzymeRules.width(config) == 1 | ||
(linsolve.dval.A,) | ||
else | ||
(dval.A for dval in linsolve.dval) | ||
end | ||
|
||
dbs = if EnzymeRules.width(config) == 1 | ||
(linsolve.dval.b,) | ||
else | ||
(dval.b for dval in linsolve.dval) | ||
end | ||
|
||
for (dA, db, dy) in zip(dAs, dbs, dys) | ||
z = if _linsolve.cacheval isa Factorization | ||
_linsolve.cacheval' \ dy | ||
elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization | ||
_linsolve.cacheval[1]' \ dy | ||
elseif linsolve.alg isa AbstractKrylovSubspaceMethod | ||
# Doesn't modify `A`, so it's safe to just reuse it | ||
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy) | ||
solve(invprob; | ||
abstol = _linsolve.val.abstol, | ||
reltol = _linsolve.val.reltol, | ||
verbose = _linsolve.val.verbose) | ||
else | ||
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") | ||
end | ||
|
||
dA .-= z * transpose(y) | ||
db .+= z | ||
dy .= eltype(dy)(0) | ||
end | ||
|
||
return (nothing,) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
using Enzyme, FiniteDiff | ||
using LinearSolve, LinearAlgebra, Test | ||
|
||
n = 4 | ||
A = rand(n, n); | ||
dA = zeros(n, n); | ||
b1 = rand(n); | ||
db1 = zeros(n); | ||
|
||
function f(A, b1; alg = LUFactorization()) | ||
prob = LinearProblem(A, b1) | ||
|
||
sol1 = solve(prob, alg) | ||
|
||
s1 = sol1.u | ||
norm(s1) | ||
end | ||
|
||
f(A, b1) # Uses BLAS | ||
|
||
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1)) | ||
|
||
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A)) | ||
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) | ||
|
||
@test dA ≈ dA2 | ||
@test db1 ≈ db12 | ||
|
||
A = rand(n, n); | ||
dA = zeros(n, n); | ||
dA2 = zeros(n, n); | ||
b1 = rand(n); | ||
db1 = zeros(n); | ||
db12 = zeros(n); | ||
|
||
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12))) | ||
|
||
dA_2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A)) | ||
db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) | ||
|
||
@test_broken dA ≈ dA_2 | ||
@test_broken dA2 ≈ dA_2 | ||
@test_broken db1 ≈ db1_2 | ||
@test_broken db12 ≈ db1_2 | ||
|
||
function f(A, b1, b2; alg = LUFactorization()) | ||
prob = LinearProblem(A, b1) | ||
|
||
cache = init(prob, alg) | ||
s1 = solve!(cache).u | ||
cache.b = b2 | ||
s2 = solve!(cache).u | ||
norm(s1 + s2) | ||
end | ||
|
||
A = rand(n, n); | ||
dA = zeros(n, n); | ||
b1 = rand(n); | ||
db1 = zeros(n); | ||
b2 = rand(n); | ||
db2 = zeros(n); | ||
|
||
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) | ||
|
||
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1,b2), copy(A)) | ||
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x,b2), copy(b1)) | ||
db22 = FiniteDiff.finite_difference_gradient(x->f(A,b1,x), copy(b2)) | ||
|
||
@test dA ≈ dA2 atol=5e-5 | ||
@test db1 ≈ db12 | ||
@test db2 ≈ db22 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is still wrong, because linsolve still couldve been overwritten from forward to reverse. You need to cache it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay was just about to ask that, thanks. I think with that this may be completed. Though check the batch syntax in the test: the test still errors with BatchDuplicated and I'm not sure what to do there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the error log from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh thats an easy one [which we sohuld fix]. You can't use an active return right now in batch mode (which also makes little sense here since you'd back propagate the same value to each). Just wrap that func in a closure that stores it to a vector or something
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, yeah the test was a bit dumb but just a quick sanity check 😓. Fixing that gives: