diff --git a/Project.toml b/Project.toml
index 2cabbc4f3..a9c2f44b2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -30,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
 
 [weakdeps]
 BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
 IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
@@ -42,6 +44,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
 [extensions]
 LinearSolveBlockDiagonalsExt = "BlockDiagonals"
 LinearSolveCUDAExt = "CUDA"
+LinearSolveEnzymeExt = "Enzyme"
 LinearSolveHYPREExt = "HYPRE"
 LinearSolveIterativeSolversExt = "IterativeSolvers"
 LinearSolveKernelAbstractionsExt = "KernelAbstractions"
@@ -78,6 +81,8 @@ julia = "1.6"
 
 [extras]
 BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -95,4 +100,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
+test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl
new file mode 100644
index 000000000..607189e10
--- /dev/null
+++ b/ext/LinearSolveEnzymeExt.jl
@@ -0,0 +1,178 @@
+module LinearSolveEnzymeExt
+
+using LinearSolve
+using LinearSolve.LinearAlgebra
+isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
+
+using Enzyme
+
+using EnzymeCore
+
+# Forward sweep of the reverse-mode rule for `LinearSolve.init`.
+# Runs `init` on the primal problem and once per shadow problem, and tapes the
+# `A`/`b` of each shadow cache together with the `A`/`b` of each shadow problem.
+function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
+    res = func.val(prob.val, alg.val; kwargs...)
+    dres = if EnzymeRules.width(config) == 1
+        func.val(prob.dval, alg.val; kwargs...)
+    else
+        ntuple(Val(EnzymeRules.width(config))) do i
+            Base.@_inline_meta
+            func.val(prob.dval[i], alg.val; kwargs...)
+        end
+    end
+    # Batched shadows are collected into tuples rather than lazy generators because
+    # the reverse rule indexes them per batch lane.
+    d_A = if EnzymeRules.width(config) == 1
+        dres.A
+    else
+        map(dval -> dval.A, dres)
+    end
+    d_b = if EnzymeRules.width(config) == 1
+        dres.b
+    else
+        map(dval -> dval.b, dres)
+    end
+
+    prob_d_A = if EnzymeRules.width(config) == 1
+        prob.dval.A
+    else
+        map(dval -> dval.A, prob.dval)
+    end
+    prob_d_b = if EnzymeRules.width(config) == 1
+        prob.dval.b
+    else
+        map(dval -> dval.b, prob.dval)
+    end
+
+    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
+end
+
+# Reverse sweep of the rule for `LinearSolve.init`.
+# Adds what accumulated in each shadow cache's `A`/`b` onto the shadow problem's
+# `A`/`b` and zeroes the cache side; skipped when both are the same array.
+function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
+    d_A, d_b, prob_d_A, prob_d_b = cache
+
+    if EnzymeRules.width(config) == 1
+        if d_A !== prob_d_A
+            prob_d_A .+= d_A
+            d_A .= 0
+        end
+        if d_b !== prob_d_b
+            prob_d_b .+= d_b
+            d_b .= 0
+        end
+    else
+        for i in 1:EnzymeRules.width(config)
+            # Compare lane by lane: the taped values are tuples of arrays here.
+            if d_A[i] !== prob_d_A[i]
+                prob_d_A[i] .+= d_A[i]
+                d_A[i] .= 0
+            end
+            if d_b[i] !== prob_d_b[i]
+                prob_d_b[i] .+= d_b[i]
+                d_b[i] .= 0
+            end
+        end
+    end
+
+    return (nothing, nothing)
+end
+
+# y=inv(A) B
+# dA −= z y^T
+# dB += z, where z = inv(A^T) dy
+function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
+    res = func.val(linsolve.val; kwargs...)
+
+    dres = if EnzymeRules.width(config) == 1
+        deepcopy(res)
+    else
+        ntuple(Val(EnzymeRules.width(config))) do i
+            Base.@_inline_meta
+            deepcopy(res)
+        end
+    end
+
+    if EnzymeRules.width(config) == 1
+        dres.u .= 0
+    else
+        for dr in dres
+            dr.u .= 0
+        end
+    end
+
+    resvals = if EnzymeRules.width(config) == 1
+        dres.u
+    else
+        ntuple(Val(EnzymeRules.width(config))) do i
+            Base.@_inline_meta
+            dres[i].u
+        end
+    end
+
+    dAs = if EnzymeRules.width(config) == 1
+        (linsolve.dval.A,)
+    else
+        ntuple(Val(EnzymeRules.width(config))) do i
+            Base.@_inline_meta
+            linsolve.dval[i].A
+        end
+    end
+
+    dbs = if EnzymeRules.width(config) == 1
+        (linsolve.dval.b,)
+    else
+        ntuple(Val(EnzymeRules.width(config))) do i
+            Base.@_inline_meta
+            linsolve.dval[i].b
+        end
+    end
+
+    # Snapshot of the solved primal cache for the adjoint solve in the reverse sweep.
+    cachesolve = deepcopy(linsolve.val)
+
+    cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
+    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
+end
+
+# Reverse sweep of the rule for `LinearSolve.solve!`: for each batch lane, solve the
+# adjoint system for `z`, update `dA`/`db` as described above, then zero `dy`.
+function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
+    y, dys, _linsolve, dAs, dbs = cache
+
+    @assert !(typeof(linsolve) <: Const)
+    @assert !(typeof(linsolve) <: Active)
+
+    if EnzymeRules.width(config) == 1
+        dys = (dys,)
+    end
+
+    for (dA, db, dy) in zip(dAs, dbs, dys)
+        z = if _linsolve.cacheval isa Factorization
+            _linsolve.cacheval' \ dy
+        elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
+            _linsolve.cacheval[1]' \ dy
+        elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
+            # Doesn't modify `A`, so it's safe to just reuse it
+            invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
+            # `_linsolve` is the taped `LinearCache` itself, so tolerances are read
+            # from it directly; `.u` extracts the solution vector.
+            solve(invprob, _linsolve.alg;
+                abstol = _linsolve.abstol,
+                reltol = _linsolve.reltol,
+                verbose = _linsolve.verbose).u
+        else
+            error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
+        end
+
+        dA .-= z * transpose(y)
+        db .+= z
+        dy .= eltype(dy)(0)
+    end
+
+    return (nothing,)
+end
+
+end
\ No newline at end of file
diff --git a/src/init.jl b/src/init.jl
index 2dccda626..360a2c86e 100644
--- a/src/init.jl
+++ b/src/init.jl
@@ -15,5 +15,8 @@ function __init__()
         @require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
             include("../ext/LinearSolveMKLExt.jl")
         end
+        @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
+            include("../ext/LinearSolveEnzymeExt.jl")
+        end
     end
 end
diff --git a/test/basictests.jl b/test/basictests.jl
index 42f283173..888f6322e 100644
--- a/test/basictests.jl
+++ b/test/basictests.jl
@@ -202,7 +202,7 @@ end
         end
     end
 
-    test_algs = if VERISON >= v"1.9"
+    test_algs = if VERSION >= v"1.9"
         (LUFactorization(),
             QRFactorization(),
             SVDFactorization(),
diff --git a/test/enzyme.jl b/test/enzyme.jl
new file mode 100644
index 000000000..62904c055
--- /dev/null
+++ b/test/enzyme.jl
@@ -0,0 +1,125 @@
+using Enzyme, ForwardDiff
+using LinearSolve, LinearAlgebra, Test
+
+n = 4
+A = rand(n, n);
+dA = zeros(n, n);
+b1 = rand(n);
+db1 = zeros(n);
+
+function f(A, b1; alg = LUFactorization())
+    prob = LinearProblem(A, b1)
+
+    sol1 = solve(prob, alg)
+
+    s1 = sol1.u
+    norm(s1)
+end
+
+f(A, b1) # Uses BLAS
+
+Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))
+
+dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
+db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
+
+@test dA ≈ dA2
+@test db1 ≈ db12
+
+A = rand(n, n);
+dA = zeros(n, n);
+dA2 = zeros(n, n);
+b1 = rand(n);
+db1 = zeros(n);
+db12 = zeros(n);
+
+#=
+# Batch test fails
+# Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075
+
+function fbatch(y, A, b1; alg = LUFactorization())
+    prob = LinearProblem(A, b1)
+
+    sol1 = solve(prob, alg)
+
+    s1 = sol1.u
+    y[1] = norm(s1)
+    nothing
+end
+
+y = [0.0]
+dy1 = [1.0]
+dy2 = [1.0]
+Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))
+
+dA_2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
+db1_2 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
+
+@test_broken dA ≈ dA_2
+@test_broken dA2 ≈ dA_2
+@test_broken db1 ≈ db1_2
+@test_broken db12 ≈ db1_2
+=#
+
+function f(A, b1, b2; alg = LUFactorization())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    s1 = copy(solve!(cache).u)
+    cache.b = b2
+    s2 = solve!(cache).u
+    norm(s1 + s2)
+end
+
+A = rand(n, n);
+dA = zeros(n, n);
+b1 = rand(n);
+db1 = zeros(n);
+b2 = rand(n);
+db2 = zeros(n);
+
+f(A, b1, b2)
+Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
+
+dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1),eltype(x).(b2)), copy(A))
+db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x,eltype(x).(b2)), copy(b1))
+db22 = ForwardDiff.gradient(x->f(eltype(x).(A),eltype(x).(b1),x), copy(b2))
+
+@test dA ≈ dA2
+@test db1 ≈ db12
+@test db2 ≈ db22
+
+function f2(A, b1, b2; alg = RFLUFactorization())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    s1 = copy(solve!(cache).u)
+    cache.b = b2
+    s2 = solve!(cache).u
+    norm(s1 + s2)
+end
+
+f2(A, b1, b2)
+dA = zeros(n, n);
+db1 = zeros(n);
+db2 = zeros(n);
+Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
+
+@test dA ≈ dA2
+@test db1 ≈ db12
+@test db2 ≈ db22
+
+#=
+function f3(A, b1, b2; alg = KrylovJL_GMRES())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    s1 = copy(solve!(cache).u)
+    cache.b = b2
+    s2 = solve!(cache).u
+    norm(s1 + s2)
+end
+
+Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
+
+@test dA ≈ dA2 atol=5e-5
+@test db1 ≈ db12
+@test db2 ≈ db22
+=#
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 036bcf97e..4f2e78feb 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -8,11 +8,12 @@ const HAS_EXTENSIONS = isdefined(Base, :get_extension)
 
 if GROUP == "All" || GROUP == "Core"
     @time @safetestset "Basic Tests" include("basictests.jl")
-    @time @safetestset "Re-solve" include("resolve.jl")
+    VERSION >= v"1.9" && @time @safetestset "Re-solve" include("resolve.jl")
     @time @safetestset "Zero Initialization Tests" include("zeroinittests.jl")
     @time @safetestset "Non-Square Tests" include("nonsquare.jl")
     @time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
     @time @safetestset "Default Alg Tests" include("default_algs.jl")
+    VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
     @time @safetestset "Traits" include("traits.jl")
 end
 