format
Vaibhavdixit02 committed Oct 1, 2024
1 parent c850508 commit 2aadf02
Showing 5 changed files with 106 additions and 58 deletions.
6 changes: 4 additions & 2 deletions ext/OptimizationEnzymeExt.jl
@@ -58,7 +58,8 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
end

function cons_f2(x, dx, fcons, p, num_cons, i)
-Enzyme.autodiff_deferred(Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
+Enzyme.autodiff_deferred(
+    Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
Const(fcons), Const(p), Const(num_cons), Const(i))
return nothing
end
@@ -83,7 +84,8 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
end

function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
-Enzyme.autodiff_deferred(Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
+Enzyme.autodiff_deferred(
+    Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
return nothing
end
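For context on the calls being rewrapped in this file: Enzyme's reverse mode accumulates the gradient of a scalar-valued function into the shadow buffer passed via `Duplicated`. A minimal standalone sketch, assuming a hypothetical `rosenbrock` objective — note it uses the top-level `autodiff`, whereas the extension uses `autodiff_deferred` because these calls are themselves differentiated again in an outer forward pass:

```julia
# A minimal sketch, not the extension's code. `rosenbrock` is a
# hypothetical objective introduced for illustration.
using Enzyme

rosenbrock(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

x = [0.0, 0.0]
dx = zeros(2)   # shadow buffer; Enzyme *accumulates* the gradient into it
Enzyme.autodiff(Enzyme.Reverse, Const(rosenbrock), Active, Enzyme.Duplicated(x, dx))
@show dx        # gradient of rosenbrock at x, here [-2.0, 0.0]
```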
61 changes: 40 additions & 21 deletions ext/OptimizationZygoteExt.jl
@@ -24,7 +24,6 @@ function OptimizationBase.instantiate_function(
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
-
adtype, soadtype = OptimizationBase.generate_adtype(adtype)

if g == true && f.grad === nothing
@@ -83,12 +82,14 @@

if fgh == true && f.fgh === nothing
function fgh!(G, H, θ)
-(y, _, _) = value_derivative_and_second_derivative!(f.f, G, H, prep_hess, soadtype, θ, Constant(p))
+(y, _, _) = value_derivative_and_second_derivative!(
+    f.f, G, H, prep_hess, soadtype, θ, Constant(p))
return y
end
if p !== SciMLBase.NullParameters() && p !== nothing
function fgh!(G, H, θ, p)
-(y, _, _) = value_derivative_and_second_derivative!(f.f, G, H, prep_hess, soadtype, θ, Constant(p))
+(y, _, _) = value_derivative_and_second_derivative!(
+    f.f, G, H, prep_hess, soadtype, θ, Constant(p))
return y
end
end
@@ -180,7 +181,8 @@
conshess_sparsity = f.cons_hess_prototype
conshess_colors = f.cons_hess_colorvec
if cons !== nothing && cons_h == true && f.cons_h === nothing
-prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i)) for i in 1:num_cons]
+prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
+    for i in 1:num_cons]

function cons_h!(H, θ)
for i in 1:num_cons
@@ -197,20 +199,23 @@

if f.lag_h === nothing && cons !== nothing && lag_h == true
lag_extras = prepare_hessian(
-lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p))
+lagrangian, soadtype, x, Constant(one(eltype(x))),
+    Constant(ones(eltype(x), num_cons)), Constant(p))
lag_hess_prototype = zeros(Bool, num_cons, length(x))

function lag_h!(H::AbstractMatrix, θ, σ, λ)
if σ == zero(eltype(θ))
cons_h!(H, θ)
H *= λ
else
-hessian!(lagrangian, H, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+hessian!(lagrangian, H, lag_extras, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
end
end

function lag_h!(h::AbstractVector, θ, σ, λ)
-H = hessian(lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+H = hessian(
+    lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
k = 0
for i in 1:length(θ)
for j in 1:i
@@ -226,12 +231,14 @@
cons_h(H, θ)
H *= λ
else
-hessian!(lagrangian, H, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+hessian!(lagrangian, H, lag_extras, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
end
end

function lag_h!(h::AbstractVector, θ, σ, λ, p)
-H = hessian(lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+H = hessian(lagrangian, lag_extras, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
k = 0
for i in 1:length(θ)
for j in 1:i
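The loops cut off by the collapsed region above pack the dense Lagrangian Hessian into the flat vector `h`. A sketch of that packing, assuming the usual row-by-row lower-triangle layout (the loop body here is a reconstruction, not copied from the file):

```julia
# Hypothetical standalone version of the lower-triangle packing done by the
# vector form of lag_h!. Solver interfaces that take the Hessian as a flat
# vector (e.g. Ipopt) expect exactly the entries H[i, j] with j <= i.
function pack_lower_triangle!(h::AbstractVector, H::AbstractMatrix)
    k = 0
    for i in 1:size(H, 1)
        for j in 1:i
            k += 1
            h[k] = H[i, j]   # row-by-row walk over the lower triangle
        end
    end
    return h
end

H = [1.0 2.0 3.0; 2.0 4.0 5.0; 3.0 5.0 6.0]
h = zeros(6)                 # n * (n + 1) / 2 entries for n = 3
pack_lower_triangle!(h, H)   # h == [1.0, 2.0, 4.0, 3.0, 5.0, 6.0]
```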
@@ -303,12 +310,14 @@
extras_grad = prepare_gradient(_f, adtype.dense_ad, x, Constant(p))
end
function fg!(res, θ)
-(y, _) = value_and_gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+(y, _) = value_and_gradient!(
+    _f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
return y
end
if p !== SciMLBase.NullParameters() && p !== nothing
function fg!(res, θ, p)
-(y, _) = value_and_gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+(y, _) = value_and_gradient!(
+    _f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
return y
end
end
@@ -341,13 +350,15 @@

if fgh == true && f.fgh === nothing
function fgh!(G, H, θ)
-(y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, prep_hess, soadtype, Constant(p))
+(y, _, _) = value_derivative_and_second_derivative!(
+    _f, G, H, θ, prep_hess, soadtype, Constant(p))
return y
end

if p !== SciMLBase.NullParameters() && p !== nothing
function fgh!(G, H, θ, p)
-(y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, prep_hess, soadtype, Constant(p))
+(y, _, _) = value_derivative_and_second_derivative!(
+    _f, G, H, θ, prep_hess, soadtype, Constant(p))
return y
end
end
@@ -419,7 +430,8 @@
extras_pullback = prepare_pullback(
cons_oop, adtype.dense_ad, x, (ones(eltype(x), num_cons),), Constant(p))
function cons_vjp!(J, θ, v)
-pullback!(cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,), Constant(p))
+pullback!(
+    cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,), Constant(p))
end
elseif cons_vjp == true
cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
@@ -431,7 +443,8 @@
extras_pushforward = prepare_pushforward(
cons_oop, adtype.dense_ad, x, (ones(eltype(x), length(x)),), Constant(p))
function cons_jvp!(J, θ, v)
-pushforward!(cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,), Constant(p))
+pushforward!(
+    cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,), Constant(p))
end
elseif cons_jvp == true
cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
@@ -442,7 +455,8 @@
conshess_sparsity = f.cons_hess_prototype
conshess_colors = f.cons_hess_colorvec
if cons !== nothing && f.cons_h === nothing && cons_h == true
-prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i)) for i in 1:num_cons]
+prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
+    for i in 1:num_cons]
colores = getfield.(prep_cons_hess, :coloring_result)
conshess_sparsity = getfield.(colores, :S)
conshess_colors = getfield.(colores, :color)
@@ -461,7 +475,8 @@
lag_hess_colors = f.lag_hess_colorvec
if cons !== nothing && f.lag_h === nothing && lag_h == true
lag_extras = prepare_hessian(
-lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p))
+lagrangian, soadtype, x, Constant(one(eltype(x))),
+    Constant(ones(eltype(x), num_cons)), Constant(p))
lag_hess_prototype = lag_extras.coloring_result.S
lag_hess_colors = lag_extras.coloring_result.color

@@ -470,12 +485,14 @@
cons_h(H, θ)
H *= λ
else
-hessian!(lagrangian, H, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+hessian!(lagrangian, H, lag_extras, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
end
end

function lag_h!(h, θ, σ, λ)
-H = hessian(lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+H = hessian(
+    lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
k = 0
rows, cols, _ = findnz(H)
for (i, j) in zip(rows, cols)
@@ -492,12 +509,14 @@
cons_h!(H, θ)
H *= λ
else
-hessian!(lagrangian, H, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+hessian!(lagrangian, H, lag_extras, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
end
end

function lag_h!(h::AbstractVector, θ, σ, λ, p)
-H = hessian(lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+H = hessian(lagrangian, lag_extras, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
k = 0
for i in 1:length(θ)
for j in 1:i
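Most of the rewrapped calls in this file follow DifferentiationInterface.jl's prepare-then-evaluate pattern, where `Constant(...)` marks arguments that are held fixed during differentiation but may take new values at each call. A minimal sketch with assumed names (a toy `lagrangian` and ForwardDiff as the backend; the extension itself uses the Zygote-derived adtype):

```julia
# A minimal sketch, not the extension's code: prepare once, then reuse the
# preparation while the Constant-context values (σ, λ, p here) change.
using DifferentiationInterface
import ForwardDiff   # assumed backend for this sketch

# Toy Lagrangian with the same call shape as above: σ * f(x) + λ' * cons(x).
lagrangian(x, σ, λ, p) = σ * sum(abs2, x .- p) + λ[1] * (sum(x) - 1)

x = [0.5, 0.5]
p = [1.0, 2.0]
backend = AutoForwardDiff()
prep = prepare_hessian(lagrangian, backend, x,
    Constant(1.0), Constant([1.0]), Constant(p))

# Same prep, different multipliers: Constant values are not differentiated
# against, but they are free to change between calls.
H = hessian(lagrangian, prep, backend, x,
    Constant(0.5), Constant([2.0]), Constant(p))   # here 2σ*I == I
```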
42 changes: 27 additions & 15 deletions src/OptimizationDIExt.jl
@@ -31,11 +31,10 @@ function instantiate_function(
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
-
adtype, soadtype = generate_adtype(adtype)

if g == true && f.grad === nothing
-prep_grad = prepare_gradient(f.f, adtype, x, Constant(p))
+prep_grad = prepare_gradient(f.f, adtype, x, Constant(p))
function grad(res, θ)
gradient!(f.f, res, prep_grad, adtype, θ, Constant(p))
end
@@ -183,7 +182,8 @@
conshess_sparsity = f.cons_hess_prototype
conshess_colors = f.cons_hess_colorvec
if f.cons !== nothing && f.cons_h === nothing && cons_h == true
-prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i)) for i in 1:num_cons]
+prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
+    for i in 1:num_cons]

function cons_h!(H, θ)
for i in 1:num_cons
@@ -200,20 +200,23 @@

if f.cons !== nothing && lag_h == true && f.lag_h === nothing
lag_prep = prepare_hessian(
-lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p))
+lagrangian, soadtype, x, Constant(one(eltype(x))),
+    Constant(ones(eltype(x), num_cons)), Constant(p))
lag_hess_prototype = zeros(Bool, num_cons, length(x))

function lag_h!(H::AbstractMatrix, θ, σ, λ)
if σ == zero(eltype(θ))
cons_h!(H, θ)
H *= λ
else
-hessian!(lagrangian, H, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+hessian!(lagrangian, H, lag_prep, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
end
end

function lag_h!(h::AbstractVector, θ, σ, λ)
-H = hessian(lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+H = hessian(
+    lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
k = 0
for i in 1:length(θ)
for j in 1:i
@@ -229,12 +232,14 @@
cons_h!(H, θ)
H *= λ
else
-hessian!(lagrangian, H, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+hessian!(lagrangian, H, lag_prep, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
end
end

function lag_h!(h::AbstractVector, θ, σ, λ, p)
-H = hessian(lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+H = hessian(lagrangian, lag_prep, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
k = 0
for i in 1:length(θ)
for j in 1:i
@@ -341,12 +346,14 @@

if fgh == true && f.fgh === nothing
function fgh!(θ)
-(y, G, H) = value_derivative_and_second_derivative(f.f, prep_hess, adtype, θ, Constant(p))
+(y, G, H) = value_derivative_and_second_derivative(
+    f.f, prep_hess, adtype, θ, Constant(p))
return y, G, H
end
if p !== SciMLBase.NullParameters() && p !== nothing
function fgh!(θ, p)
-(y, G, H) = value_derivative_and_second_derivative(f.f, prep_hess, adtype, θ, Constant(p))
+(y, G, H) = value_derivative_and_second_derivative(
+    f.f, prep_hess, adtype, θ, Constant(p))
return y, G, H
end
end
@@ -396,7 +403,8 @@
end

if f.cons_vjp === nothing && cons_vjp == true && f.cons !== nothing
-prep_pullback = prepare_pullback(f.cons, adtype, x, (ones(eltype(x), num_cons),), Constant(p))
+prep_pullback = prepare_pullback(
+    f.cons, adtype, x, (ones(eltype(x), num_cons),), Constant(p))
function cons_vjp!(θ, v)
return only(pullback(f.cons, prep_pullback, adtype, θ, (v,), Constant(p)))
end
@@ -424,7 +432,8 @@
function cons_i(x, i)
return f.cons(x, p)[i]
end
-prep_cons_hess = [prepare_hessian(cons_i, soadtype, x, Constant(i)) for i in 1:num_cons]
+prep_cons_hess = [prepare_hessian(cons_i, soadtype, x, Constant(i))
+    for i in 1:num_cons]

function cons_h!(θ)
H = map(1:num_cons) do i
@@ -442,14 +451,16 @@

if f.cons !== nothing && lag_h == true && f.lag_h === nothing
lag_prep = prepare_hessian(
-lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p))
+lagrangian, soadtype, x, Constant(one(eltype(x))),
+    Constant(ones(eltype(x), num_cons)), Constant(p))
lag_hess_prototype = zeros(Bool, num_cons, length(x))

function lag_h!(θ, σ, λ)
if σ == zero(eltype(θ))
return λ .* cons_h(θ)
else
-return hessian(lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+return hessian(lagrangian, lag_prep, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
end
end

@@ -458,7 +469,8 @@
if σ == zero(eltype(θ))
return λ .* cons_h(θ)
else
-return hessian(lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
+return hessian(lagrangian, lag_prep, soadtype, θ,
+    Constant(σ), Constant(λ), Constant(p))
end
end
end
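The constraint VJP set up in this file via `prepare_pullback`/`pullback` contracts the constraint Jacobian with a seed vector without materializing the Jacobian. A sketch with assumed names (a toy two-constraint `cons_oop` and ForwardDiff as the backend):

```julia
# A minimal sketch, not the package's code. `cons_oop` and its constraints
# are hypothetical; the package passes the user's constraint function.
using DifferentiationInterface
import ForwardDiff

cons_oop(x, p) = [x[1]^2 + x[2]^2 - p, x[1] * x[2]]   # two toy constraints

x = [1.0, 2.0]
p = 1.0
v = [1.0, 1.0]   # one seed entry per constraint
prep = prepare_pullback(cons_oop, AutoForwardDiff(), x, (v,), Constant(p))
vjp = only(pullback(cons_oop, prep, AutoForwardDiff(), x, (v,), Constant(p)))
# vjp == J(x)' * v with J the constraint Jacobian:
# J = [2x1 2x2; x2 x1] => vjp == [2.0 + 2.0, 4.0 + 1.0] == [4.0, 5.0]
```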
(Diff for the remaining two changed files not shown.)
