Skip to content

Commit

Permalink
More caching
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 24, 2023
1 parent e4f0785 commit d69af77
Showing 1 changed file with 37 additions and 14 deletions.
51 changes: 37 additions & 14 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,41 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
(dval.b for dval in dres)

Check warning on line 30 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L30

Added line #L30 was not covered by tests
end

return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b))

prob_d_A = if EnzymeRules.width(config) == 1
prob.dval.A

Check warning on line 35 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
else
(dval.A for dval in prob.dval)

Check warning on line 37 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L37

Added line #L37 was not covered by tests
end
prob_d_b = if EnzymeRules.width(config) == 1
prob.dval.b

Check warning on line 40 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L39-L40

Added lines #L39 - L40 were not covered by tests
else
(dval.b for dval in prob.dval)

Check warning on line 42 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L42

Added line #L42 was not covered by tests
end

return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))

Check warning on line 45 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L45

Added line #L45 was not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
d_A, d_b = cache
d_A, d_b, prob_d_A, prob_d_b = cache

Check warning on line 49 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L48-L49

Added lines #L48 - L49 were not covered by tests

if EnzymeRules.width(config) == 1
if d_A !== prob.dval.A
prob.dval.A .+= d_A
if d_A !== prob_d_A
prob_d_A .+= d_A
d_A .= 0

Check warning on line 54 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L51-L54

Added lines #L51 - L54 were not covered by tests
end
if d_b !== prob.dval.b
prob.dval.b .+= d_b
if d_b !== prob_d_b
prob_d_b .+= d_b
d_b .= 0

Check warning on line 58 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L56-L58

Added lines #L56 - L58 were not covered by tests
end
else
for i in 1:EnzymeRules.width(config)
if d_A !== prob.dval.A
prob.dval.A[i] .+= d_A[i]
if d_A !== prob_d_A[i]
prob_d_A[i] .+= d_A[i]
d_A[i] .= 0

Check warning on line 64 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L61-L64

Added lines #L61 - L64 were not covered by tests
end
if d_b !== prob.dval.b
prob.dval.b[i] .+= d_b[i]
if d_b !== prob_d_b[i]
prob_d_b[i] .+= d_b[i]
d_b[i] .= 0

Check warning on line 68 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L66-L68

Added lines #L66 - L68 were not covered by tests
end
end

Check warning on line 70 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L70

Added line #L70 was not covered by tests
Expand Down Expand Up @@ -87,22 +99,33 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
resvals = if EnzymeRules.width(config) == 1
dres.u

Check warning on line 100 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L99-L100

Added lines #L99 - L100 were not covered by tests
else
(dr.u for dr in dres)
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
dres[i].u

Check warning on line 104 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L102-L104

Added lines #L102 - L104 were not covered by tests
end
end

dAs = if EnzymeRules.width(config) == 1
(linsolve.dval.A,)

Check warning on line 109 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L108-L109

Added lines #L108 - L109 were not covered by tests
else
(dval.A for dval in linsolve.dval)
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
linsolve.dval[i].A

Check warning on line 113 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L111-L113

Added lines #L111 - L113 were not covered by tests
end
end

dbs = if EnzymeRules.width(config) == 1
(linsolve.dval.b,)

Check warning on line 118 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L117-L118

Added lines #L117 - L118 were not covered by tests
else
(dval.b for dval in linsolve.dval)
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
linsolve.dval[i].b

Check warning on line 122 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L120-L122

Added lines #L120 - L122 were not covered by tests
end
end

cache = (res, resvals, deepcopy(linsolve.val), dAs, dbs)
cachesolve = deepcopy(linsolve.val)

Check warning on line 126 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L126

Added line #L126 was not covered by tests

cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)

Check warning on line 129 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L128-L129

Added lines #L128 - L129 were not covered by tests
end

Expand Down

0 comments on commit d69af77

Please sign in to comment.