Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme extension #377

Merged
merged 20 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -30,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Expand All @@ -42,6 +44,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
[extensions]
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = "Enzyme"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand Down Expand Up @@ -78,6 +81,8 @@ julia = "1.6"

[extras]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -95,4 +100,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
137 changes: 137 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
module LinearSolveEnzymeExt

using LinearSolve
using LinearSolve.LinearAlgebra
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)


using Enzyme

using EnzymeCore

function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
func.val(prob.dval, alg.val; kwargs...)

Check warning on line 15 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L12-L15

Added lines #L12 - L15 were not covered by tests
else
(func.val(dval, alg.val; kwargs...) for dval in prob.dval)

Check warning on line 17 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L17

Added line #L17 was not covered by tests
end
d_A = if EnzymeRules.width(config) == 1
dres.A

Check warning on line 20 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L19-L20

Added lines #L19 - L20 were not covered by tests
else
(dval.A for dval in dres)

Check warning on line 22 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L22

Added line #L22 was not covered by tests
end
d_b = if EnzymeRules.width(config) == 1
dres.b

Check warning on line 25 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
else
(dval.b for dval in dres)

Check warning on line 27 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L27

Added line #L27 was not covered by tests
end
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b))

Check warning on line 29 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L29

Added line #L29 was not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
d_A, d_b = cache

Check warning on line 33 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L32-L33

Added lines #L32 - L33 were not covered by tests

if EnzymeRules.width(config) == 1
if d_A !== prob.dval.A
prob.dval.A .+= d_A
d_A .= 0

Check warning on line 38 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L35-L38

Added lines #L35 - L38 were not covered by tests
end
if d_b !== prob.dval.b
prob.dval.b .+= d_b
d_b .= 0

Check warning on line 42 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L40-L42

Added lines #L40 - L42 were not covered by tests
end
else
for i in 1:EnzymeRules.width(config)
if d_A !== prob.dval.A
prob.dval.A[i] .+= d_A[i]
d_A[i] .= 0

Check warning on line 48 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L45-L48

Added lines #L45 - L48 were not covered by tests
end
if d_b !== prob.dval.b
prob.dval.b[i] .+= d_b[i]
d_b[i] .= 0

Check warning on line 52 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L50-L52

Added lines #L50 - L52 were not covered by tests
end
end

Check warning on line 54 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L54

Added line #L54 was not covered by tests
end

return (nothing, nothing)

Check warning on line 57 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L57

Added line #L57 was not covered by tests
end

# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
res = func.val(linsolve.val; kwargs...)

Check warning on line 64 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L63-L64

Added lines #L63 - L64 were not covered by tests

dres = if EnzymeRules.width(config) == 1
deepcopy(res)

Check warning on line 67 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L66-L67

Added lines #L66 - L67 were not covered by tests
else
(deepcopy(res) for dval in linsolve.dval)

Check warning on line 69 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L69

Added line #L69 was not covered by tests
end

if EnzymeRules.width(config) == 1
dres.u .= 0

Check warning on line 73 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L72-L73

Added lines #L72 - L73 were not covered by tests
else
for dr in dres
dr.u .= 0
end

Check warning on line 77 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L75-L77

Added lines #L75 - L77 were not covered by tests
end

resvals = if EnzymeRules.width(config) == 1
dres.u

Check warning on line 81 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L80-L81

Added lines #L80 - L81 were not covered by tests
else
(dr.u for dr in dres)

Check warning on line 83 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L83

Added line #L83 was not covered by tests
end

cache = (res, resvals)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)

Check warning on line 87 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L86-L87

Added lines #L86 - L87 were not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys = cache
_linsolve = linsolve.val

Check warning on line 92 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L90-L92

Added lines #L90 - L92 were not covered by tests
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still wrong, because linsolve still couldve been overwritten from forward to reverse. You need to cache it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay was just about to ask that, thanks. I think with that this may be completed. Though check the batch syntax in the test: the test still errors with BatchDuplicated and I'm not sure what to do there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the error log from?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ERROR: TypeError: in ccall argument 6, expected Tuple{Float64, Float64}, got a value of type Float64
Stacktrace:
 [1] macro expansion
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9774 [inlined]
 [2] enzyme_call
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9452 [inlined]
 [3] CombinedAdjointThunk
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9415 [inlined]
 [4] autodiff
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:213 [inlined]
 [5] autodiff
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:236 [inlined]
 [6] autodiff(::ReverseMode{false, FFIABI}, ::typeof(f), ::BatchDuplicated{Matrix{Float64}, 2}, ::BatchDuplicated{Vector{Float64}, 2})
   @ Enzyme C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:222
 [7] top-level scope
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:36

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh thats an easy one [which we sohuld fix]. You can't use an active return right now in batch mode (which also makes little sense here since you'd back propagate the same value to each). Just wrap that func in a closure that stores it to a vector or something

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, yeah the test was a bit dumb but just a quick sanity check 😓. Fixing that gives:

ERROR: Enzyme execution failed.
Enzyme: Augmented forward pass custom rule Tuple{EnzymeCore.EnzymeRules.ConfigWidth{2, true, true, (false, false, false)}, Const{typeof(init)}, Type{BatchDuplicated{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, 2}}, BatchDuplicated{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, 2}, Const{LUFactorization{RowMaximum}}} return type mismatch, expected EnzymeCore.EnzymeRules.AugmentedReturn{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, Tuple{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}}, Any} found EnzymeCore.EnzymeRules.AugmentedReturn{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, Tuple{Base.Generator{Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, LinearSolveEnzymeExt.var"#3#6"}, Base.Generator{Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, LinearSolveEnzymeExt.var"#4#7"}}}
Stacktrace:
 [1] #solve#5
   @ C:\Users\accou\.julia\dev\LinearSolve\src\common.jl:193
 [2] solve
   @ C:\Users\accou\.julia\dev\LinearSolve\src\common.jl:190
 [3] #fbatch#207
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:39
 [4] fbatch
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:36
 [5] fbatch
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:0

Stacktrace:
 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:3066


@assert !(typeof(linsolve) <: Const)
@assert !(typeof(linsolve) <: Active)

Check warning on line 95 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L94-L95

Added lines #L94 - L95 were not covered by tests

if EnzymeRules.width(config) == 1
dys = (dys,)

Check warning on line 98 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L97-L98

Added lines #L97 - L98 were not covered by tests
end

dAs = if EnzymeRules.width(config) == 1
(linsolve.dval.A,)

Check warning on line 102 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L101-L102

Added lines #L101 - L102 were not covered by tests
else
(dval.A for dval in linsolve.dval)

Check warning on line 104 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L104

Added line #L104 was not covered by tests
end

dbs = if EnzymeRules.width(config) == 1
(linsolve.dval.b,)

Check warning on line 108 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
else
(dval.b for dval in linsolve.dval)

Check warning on line 110 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L110

Added line #L110 was not covered by tests
end

for (dA, db, dy) in zip(dAs, dbs, dys)
z = if _linsolve.cacheval isa Factorization
_linsolve.cacheval' \ dy
elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization
_linsolve.cacheval[1]' \ dy
elseif linsolve.alg isa AbstractKrylovSubspaceMethod

Check warning on line 118 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L113-L118

Added lines #L113 - L118 were not covered by tests
# Doesn't modify `A`, so it's safe to just reuse it
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
solve(invprob;

Check warning on line 121 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L120-L121

Added lines #L120 - L121 were not covered by tests
abstol = _linsolve.val.abstol,
reltol = _linsolve.val.reltol,
verbose = _linsolve.val.verbose)
else
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")

Check warning on line 126 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L126

Added line #L126 was not covered by tests
end

dA .-= z * transpose(y)
db .+= z
dy .= eltype(dy)(0)
end

Check warning on line 132 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L129-L132

Added lines #L129 - L132 were not covered by tests

return (nothing,)

Check warning on line 134 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L134

Added line #L134 was not covered by tests
end

end
3 changes: 3 additions & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
include("../ext/LinearSolveMKLExt.jl")
end
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("../ext/LinearSolveEnzymeExt.jl")

Check warning on line 19 in src/init.jl

View check run for this annotation

Codecov / codecov/patch

src/init.jl#L19

Added line #L19 was not covered by tests
end
end
end
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ end
end
end

test_algs = if VERISON >= v"1.9"
test_algs = if VERSION >= v"1.9"
(LUFactorization(),
QRFactorization(),
SVDFactorization(),
Expand Down
71 changes: 71 additions & 0 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using Enzyme, FiniteDiff
using LinearSolve, LinearAlgebra, Test

n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);

function f(A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

s1 = sol1.u
norm(s1)
end

f(A, b1) # Uses BLAS

Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))

dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))

@test dA ≈ dA2
@test db1 ≈ db12

A = rand(n, n);
dA = zeros(n, n);
dA2 = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
db12 = zeros(n);

@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))

dA_2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))

@test_broken dA ≈ dA_2
@test_broken dA2 ≈ dA_2
@test_broken db1 ≈ db1_2
@test_broken db12 ≈ db1_2

function f(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)

cache = init(prob, alg)
s1 = solve!(cache).u
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
b2 = rand(n);
db2 = zeros(n);

Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1,b2), copy(A))
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x,b2), copy(b1))
db22 = FiniteDiff.finite_difference_gradient(x->f(A,b1,x), copy(b2))

@test dA ≈ dA2 atol=5e-5
@test db1 ≈ db12
@test db2 ≈ db22
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ const HAS_EXTENSIONS = isdefined(Base, :get_extension)

if GROUP == "All" || GROUP == "Core"
@time @safetestset "Basic Tests" include("basictests.jl")
@time @safetestset "Re-solve" include("resolve.jl")
VERSION >= v"1.9" && @time @safetestset "Re-solve" include("resolve.jl")
@time @safetestset "Zero Initialization Tests" include("zeroinittests.jl")
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
@time @safetestset "Traits" include("traits.jl")
end

Expand Down
Loading