Skip to content

Commit

Permalink
Add actual file
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored and ChrisRackauckas committed Sep 22, 2023
1 parent bb6d623 commit 9f8d18f
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module LinearSolveEnzymeExt

using LinearSolve
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)


using Enzyme

using EnzymeCore

# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = deepcopy(res)
dres.u .= 0
cache = (copy(prob.val.A), res, dres.u)
return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, typeof(cache)}(res, dres, cache)

Check warning on line 19 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L14-L19

Added lines #L14 - L19 were not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, cache, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
A, y, dy = cache

Check warning on line 23 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L22-L23

Added lines #L22 - L23 were not covered by tests

dA = prob.dval.A
db = prob.dval.b

Check warning on line 26 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L25-L26

Added lines #L25 - L26 were not covered by tests

invprob = LinearProblem(transpose(A), dy)

Check warning on line 28 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L28

Added line #L28 was not covered by tests

z = func.val(invprob, alg; kwargs...)

Check warning on line 30 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L30

Added line #L30 was not covered by tests

dA .-= z * transpose(y)
db .+= z
dy .= 0
return (nothing, nothing)

Check warning on line 35 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L32-L35

Added lines #L32 - L35 were not covered by tests
end

end

0 comments on commit 9f8d18f

Please sign in to comment.