From 9f8d18fbb6a13fbb592579c7b809b1ff14acc282 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 21 Sep 2023 20:49:13 -0500 Subject: [PATCH] Add actual file --- ext/LinearSolveEnzymeExt.jl | 38 +++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 ext/LinearSolveEnzymeExt.jl diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl new file mode 100644 index 000000000..f38cf56e2 --- /dev/null +++ b/ext/LinearSolveEnzymeExt.jl @@ -0,0 +1,38 @@ +module LinearSolveEnzymeExt + +using LinearSolve +isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) + + +using Enzyme + +using EnzymeCore + +# y=inv(A) B +# dA −= z y^T +# dB += z, where z = inv(A^T) dy +function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem} + res = func.val(prob.val, alg.val; kwargs...) + dres = deepcopy(res) + dres.u .= 0 + cache = (copy(prob.val.A), res, dres.u) + return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, typeof(cache)}(res, dres, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, cache, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem} + A, y, dy = cache + + dA = prob.dval.A + db = prob.dval.b + + invprob = LinearProblem(transpose(A), dy) + + z = func.val(invprob, alg; kwargs...) + + dA .-= z * transpose(y) + db .+= z + dy .= 0 + return (nothing, nothing) +end + +end \ No newline at end of file