Skip to content

Commit

Permalink
Merge pull request #43 from SciML/ChrisRackauckas-patch-1
Browse files Browse the repository at this point in the history
Add Const activity for trailing data arguments
  • Loading branch information
ChrisRackauckas authored May 30, 2024
2 parents c8418c2 + e440023 commit b2251ec
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Const(f.f),
Enzyme.Duplicated(θ, res),
Const(p),
args...)
Const.(args)...)
end
end
else
Expand All @@ -43,7 +43,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
args...),
Const.(args)...),
return nothing
end
function hess(res, θ, args...)
Expand Down Expand Up @@ -77,13 +77,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
f,
Enzyme.Duplicated(x, dx),
Const(p),
args...)
Const.(args)...)
return dx
end
hv = function (H, θ, v, args...)
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
Const(_f), Const(f.f), Const(p),
args...)[1]
Const.(args)...)[1]
end
else
hv = f.hv
Expand Down Expand Up @@ -168,7 +168,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
Const(f.f),
Enzyme.Duplicated(θ, res),
Const(p),
args...)
Const.(args)...)
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
Expand All @@ -179,7 +179,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(firstapply), Active, Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
args...)
Const.(args)...)
return nothing
end
function hess(res, θ, args...)
Expand All @@ -194,7 +194,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Const(p),
args...)
Const.(args)...)

for i in eachindex(θ)
res[i, :] .= vdbθ[i]
Expand All @@ -211,13 +211,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
f,
Enzyme.Duplicated(x, dx),
Const(p),
args...)
Const.(args)...)
return dx
end
hv = function (H, θ, v, args...)
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
Const(f.f), Const(p),
args...)[1]
Const.(args)...)[1]
end
else
hv = f.hv
Expand Down Expand Up @@ -285,7 +285,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Const(f.f),
Enzyme.Duplicated(θ, res),
Const(p),
args...)
Const.(args)...)
return res
end
end
Expand All @@ -301,7 +301,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
args...),
Const.(args)...),
return nothing
end
function hess(θ, args...)
Expand All @@ -316,7 +316,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Const(p),
args...)
Const.(args)...)

reduce(vcat, [reshape(vdbθ[i], (1, length(vdbθ[i]))) for i in eachindex(θ)])
end
Expand All @@ -334,13 +334,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
f,
Enzyme.Duplicated(x, dx),
Const(p),
args...)
Const.(args)...)
return dx
end
hv = function (θ, v, args...)
Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
Const(_f), Const(f.f), Const(p),
args...)[1]
Const.(args)...)[1]
end
else
hv = f.hv
Expand Down Expand Up @@ -425,7 +425,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
Const(f.f),
Enzyme.Duplicated(θ, res),
Const(p),
args...)
Const.(args)...)
return res
end
end
Expand All @@ -441,7 +441,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
args...),
Const.(args)...),
return nothing
end
function hess(θ, args...)
Expand All @@ -456,7 +456,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Const(p),
args...)
Const.(args)...)

reduce(vcat, [reshape(vdbθ[i], (1, length(vdbθ[i]))) for i in eachindex(θ)])
end
Expand All @@ -474,13 +474,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
f,
Enzyme.Duplicated(x, dx),
Const(p),
args...)
Const.(args)...)
return dx
end
hv = function (θ, v, args...)
Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
Const(_f), Const(f.f), Const(p),
args...)[1]
Const.(args)...)[1]
end
else
hv = f.hv
Expand Down

0 comments on commit b2251ec

Please sign in to comment.