diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 9db9a1770..607189e10 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -30,29 +30,41 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line (dval.b for dval in dres) end - return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b)) + + prob_d_A = if EnzymeRules.width(config) == 1 + prob.dval.A + else + (dval.A for dval in prob.dval) + end + prob_d_b = if EnzymeRules.width(config) == 1 + prob.dval.b + else + (dval.b for dval in prob.dval) + end + + return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b)) end function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} - d_A, d_b = cache + d_A, d_b, prob_d_A, prob_d_b = cache if EnzymeRules.width(config) == 1 - if d_A !== prob.dval.A - prob.dval.A .+= d_A + if d_A !== prob_d_A + prob_d_A .+= d_A d_A .= 0 end - if d_b !== prob.dval.b - prob.dval.b .+= d_b + if d_b !== prob_d_b + prob_d_b .+= d_b d_b .= 0 end else for i in 1:EnzymeRules.width(config) - if d_A !== prob.dval.A - prob.dval.A[i] .+= d_A[i] + if d_A !== prob_d_A[i] + prob_d_A[i] .+= d_A[i] d_A[i] .= 0 end - if d_b !== prob.dval.b - prob.dval.b[i] .+= d_b[i] + if d_b !== prob_d_b[i] + prob_d_b[i] .+= d_b[i] d_b[i] .= 0 end end @@ -87,22 +99,33 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line resvals = if EnzymeRules.width(config) == 1 dres.u else - (dr.u for dr in dres) + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + dres[i].u + end end dAs = if EnzymeRules.width(config) == 1 (linsolve.dval.A,) else - (dval.A for dval in linsolve.dval) + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + linsolve.dval[i].A + end end dbs = if EnzymeRules.width(config) == 1 (linsolve.dval.b,) else - (dval.b for dval in linsolve.dval) + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + linsolve.dval[i].b + end end - cache = (res, resvals, deepcopy(linsolve.val), dAs, dbs) + cachesolve = deepcopy(linsolve.val) + + cache = (copy(res.u), resvals, cachesolve, dAs, dbs) return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) end