Commit 75be96e

Actually test sparse zygote

Vaibhavdixit02 committed Oct 1, 2024
1 parent 2aadf02 commit 75be96e
Showing 3 changed files with 26 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "OptimizationBase"
 uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
 authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
-version = "2.0.4"
+version = "2.1.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
48 changes: 24 additions & 24 deletions ext/OptimizationZygoteExt.jl
@@ -290,13 +290,13 @@ function OptimizationBase.instantiate_function(
     adtype, soadtype = OptimizationBase.generate_sparse_adtype(adtype)
 
     if g == true && f.grad === nothing
-        extras_grad = prepare_gradient(_f, adtype.dense_ad, x, Constant(p))
+        extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p))
         function grad(res, θ)
-            gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+            gradient!(f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
         end
         if p !== SciMLBase.NullParameters() && p !== nothing
             function grad(res, θ, p)
-                gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+                gradient!(f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
             end
         end
     elseif g == true
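Note on the hunk above: the preparation and the gradient calls now differentiate the raw two-argument objective f.f(x, p) directly, passing the parameters through DifferentiationInterface's Constant context rather than through the locally defined wrapper _f. A minimal sketch of this prepare/execute pattern, assuming DifferentiationInterface >= 0.6 with Zygote loaded; the objective obj is illustrative, not from this repository:

    using DifferentiationInterface
    using ADTypes: AutoZygote
    import Zygote  # activates DifferentiationInterface's Zygote extension

    obj(x, p) = sum(abs2, x .- p)  # illustrative two-argument objective
    backend = AutoZygote()
    x, p = rand(3), rand(3)

    prep = prepare_gradient(obj, backend, x, Constant(p))  # one-time preparation
    res = similar(x)
    gradient!(obj, res, prep, backend, x, Constant(p))     # reuses prep, fills res in place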
@@ -307,17 +307,17 @@ function OptimizationBase.instantiate_function(
 
     if fg == true && f.fg === nothing
         if g == false
-            extras_grad = prepare_gradient(_f, adtype.dense_ad, x, Constant(p))
+            extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p))
         end
         function fg!(res, θ)
             (y, _) = value_and_gradient!(
-                _f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+                f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
             return y
         end
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fg!(res, θ, p)
                 (y, _) = value_and_gradient!(
-                    _f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+                    f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
                 return y
             end
         end
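The fg! path reuses the same gradient preparation for a fused value-plus-gradient query; continuing the sketch above:

    # objective value and in-place gradient in a single sweep
    y, _ = value_and_gradient!(obj, res, prep, backend, x, Constant(p))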
@@ -330,16 +330,16 @@ function OptimizationBase.instantiate_function(
     hess_sparsity = f.hess_prototype
     hess_colors = f.hess_colorvec
     if h == true && f.hess === nothing
-        prep_hess = prepare_hessian(_f, soadtype, x, Constant(p))
+        prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p))
         function hess(res, θ)
-            hessian!(_f, res, prep_hess, soadtype, θ, Constant(p))
+            hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
         end
-        hess_sparsity = extras_hess.coloring_result.S
-        hess_colors = extras_hess.coloring_result.color
+        hess_sparsity = prep_hess.coloring_result.S
+        hess_colors = prep_hess.coloring_result.color
 
         if p !== SciMLBase.NullParameters() && p !== nothing
             function hess(res, θ, p)
-                hessian!(_f, res, prep_hess, soadtype, θ, Constant(p))
+                hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
             end
         end
     elseif h == true
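The last two changed lines in this hunk are the substantive fix: the sparsity pattern and colors were read from extras_hess, a name that no longer exists after the preparation object was renamed prep_hess, so the sparse-Hessian path would have raised an UndefVarError. Both now come off the preparation's coloring result (an internal field of the prep object, used exactly as in the hunk). A sketch of assembling a sparse second-order backend, assuming SparseConnectivityTracer and SparseMatrixColorings provide the detector and coloring; the composition mirrors what generate_sparse_adtype plausibly produces:

    using ADTypes: AutoSparse, AutoForwardDiff
    using DifferentiationInterface: SecondOrder
    using SparseConnectivityTracer: TracerSparsityDetector
    using SparseMatrixColorings: GreedyColoringAlgorithm

    soadtype = AutoSparse(SecondOrder(AutoForwardDiff(), AutoZygote());  # forward-over-reverse
        sparsity_detector = TracerSparsityDetector(),
        coloring_algorithm = GreedyColoringAlgorithm())

    prep_hess = prepare_hessian(obj, soadtype, x, Constant(p))
    H = hessian(obj, prep_hess, soadtype, x, Constant(p))  # sparse result
    hess_sparsity = prep_hess.coloring_result.S            # as in the hunk
    hess_colors = prep_hess.coloring_result.color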
@@ -351,14 +351,14 @@ function OptimizationBase.instantiate_function(
     if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
             (y, _, _) = value_derivative_and_second_derivative!(
-                _f, G, H, θ, prep_hess, soadtype, Constant(p))
+                f.f, G, H, θ, prep_hess, soadtype, Constant(p))
             return y
         end
 
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(G, H, θ, p)
                 (y, _, _) = value_derivative_and_second_derivative!(
-                    _f, G, H, θ, prep_hess, soadtype, Constant(p))
+                    f.f, G, H, θ, prep_hess, soadtype, Constant(p))
                 return y
             end
         end
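fgh! fuses all three queries: it returns the objective value while filling G and H in place. A sketch with DifferentiationInterface's combined Hessian-family operator, assuming the prepare_hessian preparation is reusable for it (the hunk routes through a similarly named fused call):

    G = similar(x)
    y, _, _ = value_gradient_and_hessian!(obj, G, H, prep_hess, soadtype, x, Constant(p))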
@@ -371,11 +371,11 @@ function OptimizationBase.instantiate_function(
     if hv == true && f.hv === nothing
         prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
         function hv!(H, θ, v)
-            hvp!(_f, H, prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
+            hvp!(f.f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
         end
         if p !== SciMLBase.NullParameters() && p !== nothing
             function hv!(H, θ, v, p)
-                hvp!(_f, H, prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
+                hvp!(f.f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
             end
         end
     elseif hv == true
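The fix in this hunk is the tuple wrapping: DifferentiationInterface's hvp! works on batches of tangents passed as tuples, so a single Hessian-vector product seeds with (v,) and writes into (Hv,). A sketch over a dense second-order backend:

    v = rand(3)
    Hv = similar(x)
    dense_so = SecondOrder(AutoForwardDiff(), AutoZygote())
    prep_hvp = prepare_hvp(obj, dense_so, x, (v,), Constant(p))
    hvp!(obj, (Hv,), prep_hvp, dense_so, x, (v,), Constant(p))  # Hv ≈ ∇²f(x) * v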
@@ -411,15 +411,15 @@ function OptimizationBase.instantiate_function(
     cons_jac_prototype = f.cons_jac_prototype
     cons_jac_colorvec = f.cons_jac_colorvec
     if cons !== nothing && cons_j == true && f.cons_j === nothing
-        prep_jac = prepare_jacobian(cons_oop, adtype, x, Constant(p))
+        prep_jac = prepare_jacobian(cons_oop, adtype, x)
         function cons_j!(J, θ)
-            jacobian!(cons_oop, J, prep_jac, adtype, θ, Constant(p))
+            jacobian!(cons_oop, J, prep_jac, adtype, θ)
             if size(J, 1) == 1
                 J = vec(J)
             end
         end
-        cons_jac_prototype = extras_jac.coloring_result.S
-        cons_jac_colorvec = extras_jac.coloring_result.color
+        cons_jac_prototype = prep_jac.coloring_result.S
+        cons_jac_colorvec = prep_jac.coloring_result.color
     elseif cons !== nothing && cons_j == true
         cons_j! = (J, θ) -> f.cons_j(J, θ, p)
     else
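Constant(p) is dropped throughout this hunk because cons_oop is a single-argument closure that already captures the parameters; supplying an extra context at preparation time would not match the later calls. A sketch of the sparse constraint Jacobian, with an illustrative two-constraint function cons2:

    cons2(θ) = [θ[1]^2 + θ[2], θ[2]^3 - θ[3]]  # 2 constraints, 3 inputs
    sparse_ad = AutoSparse(AutoZygote();
        sparsity_detector = TracerSparsityDetector(),
        coloring_algorithm = GreedyColoringAlgorithm())
    prep_jac = prepare_jacobian(cons2, sparse_ad, x)
    J = jacobian(cons2, prep_jac, sparse_ad, x)      # sparse 2x3 matrix
    cons_jac_prototype = prep_jac.coloring_result.S  # as in the hunk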
@@ -428,10 +428,10 @@ function OptimizationBase.instantiate_function(
 
     if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
         extras_pullback = prepare_pullback(
-            cons_oop, adtype.dense_ad, x, (ones(eltype(x), num_cons),), Constant(p))
+            cons_oop, adtype.dense_ad, x, (ones(eltype(x), num_cons),))
         function cons_vjp!(J, θ, v)
             pullback!(
-                cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,), Constant(p))
+                cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,))
         end
     elseif cons_vjp == true
         cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
@@ -441,10 +441,10 @@ function OptimizationBase.instantiate_function(
 
     if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
         extras_pushforward = prepare_pushforward(
-            cons_oop, adtype.dense_ad, x, (ones(eltype(x), length(x)),), Constant(p))
+            cons_oop, adtype.dense_ad, x, (ones(eltype(x), length(x)),))
         function cons_jvp!(J, θ, v)
             pushforward!(
-                cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,), Constant(p))
+                cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,))
         end
     elseif cons_jvp == true
         cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
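Both seeding hunks above drop Constant(p) for the same closure reason. The seed spaces differ: pullback! seeds live in the constraint (output) space, pushforward! seeds in the input space. A sketch, using a forward-mode backend for the pushforward since Zygote is reverse-mode (in the extension, adtype.dense_ad supplies the backend):

    # VJP: seed has num_cons entries, result has length(x) entries
    prep_pb = prepare_pullback(cons2, AutoZygote(), x, (ones(2),))
    vjp = similar(x)
    pullback!(cons2, (vjp,), prep_pb, AutoZygote(), x, (ones(2),))

    # JVP: seed has length(x) entries, result has num_cons entries
    prep_pf = prepare_pushforward(cons2, AutoForwardDiff(), x, (ones(3),))
    jvp = zeros(2)
    pushforward!(cons2, (jvp,), prep_pf, AutoForwardDiff(), x, (ones(3),))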
@@ -482,7 +482,7 @@ function OptimizationBase.instantiate_function(
 
     function lag_h!(H::AbstractMatrix, θ, σ, λ)
         if σ == zero(eltype(θ))
-            cons_h(H, θ)
+            cons_h!(H, θ)
             H *= λ
         else
             hessian!(lagrangian, H, lag_extras, soadtype, θ,
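The one-character fix above (cons_h becomes cons_h!) points at the in-place constraint-Hessian helper defined earlier in instantiate_function; the bare name cons_h is undefined at that point. The branch itself is the standard Lagrangian-Hessian shortcut: when σ = 0 the objective term vanishes, leaving only the λ-weighted constraint Hessians, so no AD sweep over the objective is needed. For orientation, the Lagrangian differentiated in the else branch has roughly this shape (illustrative, not the repository's exact definition):

    using LinearAlgebra: dot
    lag(θ, σ, λ) = σ * obj(θ, p) + dot(λ, cons2(θ))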
2 changes: 1 addition & 1 deletion test/adtests.jl
@@ -739,7 +739,7 @@ end
 
 # Instantiate the optimization problem
 optprob = OptimizationBase.instantiate_function(optf, x0,
-    OptimizationBase.AutoSparseForwardDiff(),
+    OptimizationBase.AutoSparseZygote(),
     nothing, 2, g = true, h = true, cons_j = true, cons_h = true, lag_h = true)
 # Test gradient
 G = zeros(3)
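This one-line swap is what the commit title means by "actually" testing sparse Zygote: the test was labeled as covering AutoSparseZygote but instantiated the problem with AutoSparseForwardDiff, so the OptimizationZygoteExt paths changed above were never exercised. Conceptually the adtype corresponds to a sparsity-annotated Zygote backend, roughly as below; this construction is an assumption about OptimizationBase's convenience type, not a quote from the package:

    AutoSparse(AutoZygote();
        sparsity_detector = TracerSparsityDetector(),
        coloring_algorithm = GreedyColoringAlgorithm())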
