diff --git a/Project.toml b/Project.toml
index 1fcfffa..3deeafb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "OptimizationBase"
 uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
 authors = ["Vaibhav Dixit and contributors"]
-version = "2.0.4"
+version = "2.1.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl
index f31be4b..d0ec969 100644
--- a/ext/OptimizationZygoteExt.jl
+++ b/ext/OptimizationZygoteExt.jl
@@ -290,13 +290,13 @@ function OptimizationBase.instantiate_function(
     adtype, soadtype = OptimizationBase.generate_sparse_adtype(adtype)
 
     if g == true && f.grad === nothing
-        extras_grad = prepare_gradient(_f, adtype.dense_ad, x, Constant(p))
+        extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p))
         function grad(res, θ)
-            gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+            gradient!(f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
         end
         if p !== SciMLBase.NullParameters() && p !== nothing
             function grad(res, θ, p)
-                gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+                gradient!(f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
             end
         end
     elseif g == true
@@ -307,17 +307,17 @@ function OptimizationBase.instantiate_function(
 
     if fg == true && f.fg === nothing
         if g == false
-            extras_grad = prepare_gradient(_f, adtype.dense_ad, x, Constant(p))
+            extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p))
         end
         function fg!(res, θ)
             (y, _) = value_and_gradient!(
-                _f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+                f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
             return y
         end
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fg!(res, θ, p)
                 (y, _) = value_and_gradient!(
-                    _f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+                    f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
                 return y
             end
         end
@@ -330,16 +330,16 @@ function OptimizationBase.instantiate_function(
     hess_sparsity = f.hess_prototype
     hess_colors = f.hess_colorvec
     if h == true && f.hess === nothing
-        prep_hess = prepare_hessian(_f, soadtype, x, Constant(p))
+        prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p))
         function hess(res, θ)
-            hessian!(_f, res, prep_hess, soadtype, θ, Constant(p))
+            hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
         end
-        hess_sparsity = extras_hess.coloring_result.S
-        hess_colors = extras_hess.coloring_result.color
+        hess_sparsity = prep_hess.coloring_result.S
+        hess_colors = prep_hess.coloring_result.color
 
         if p !== SciMLBase.NullParameters() && p !== nothing
             function hess(res, θ, p)
-                hessian!(_f, res, prep_hess, soadtype, θ, Constant(p))
+                hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
             end
         end
     elseif h == true
@@ -351,14 +351,14 @@ function OptimizationBase.instantiate_function(
 
     if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
             (y, _, _) = value_derivative_and_second_derivative!(
-                _f, G, H, θ, prep_hess, soadtype, Constant(p))
+                f.f, G, H, θ, prep_hess, soadtype, Constant(p))
             return y
         end
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(G, H, θ, p)
                 (y, _, _) = value_derivative_and_second_derivative!(
-                    _f, G, H, θ, prep_hess, soadtype, Constant(p))
+                    f.f, G, H, θ, prep_hess, soadtype, Constant(p))
                 return y
             end
         end
@@ -371,11 +371,11 @@ function OptimizationBase.instantiate_function(
     if hv == true && f.hv === nothing
         prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
         function hv!(H, θ, v)
-            hvp!(_f, H, prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
+            hvp!(f.f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
        end
        if p !== SciMLBase.NullParameters() && p !== nothing
            function hv!(H, θ, v, p)
-                hvp!(_f, H, prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
+                hvp!(f.f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
            end
        end
    elseif hv == true
@@ -411,15 +411,15 @@ function OptimizationBase.instantiate_function(
     cons_jac_prototype = f.cons_jac_prototype
     cons_jac_colorvec = f.cons_jac_colorvec
     if cons !== nothing && cons_j == true && f.cons_j === nothing
-        prep_jac = prepare_jacobian(cons_oop, adtype, x, Constant(p))
+        prep_jac = prepare_jacobian(cons_oop, adtype, x)
         function cons_j!(J, θ)
-            jacobian!(cons_oop, J, prep_jac, adtype, θ, Constant(p))
+            jacobian!(cons_oop, J, prep_jac, adtype, θ)
             if size(J, 1) == 1
                 J = vec(J)
             end
         end
-        cons_jac_prototype = extras_jac.coloring_result.S
-        cons_jac_colorvec = extras_jac.coloring_result.color
+        cons_jac_prototype = prep_jac.coloring_result.S
+        cons_jac_colorvec = prep_jac.coloring_result.color
     elseif cons !== nothing && cons_j == true
         cons_j! = (J, θ) -> f.cons_j(J, θ, p)
     else
@@ -428,10 +428,10 @@ function OptimizationBase.instantiate_function(
 
     if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
         extras_pullback = prepare_pullback(
-            cons_oop, adtype.dense_ad, x, (ones(eltype(x), num_cons),), Constant(p))
+            cons_oop, adtype.dense_ad, x, (ones(eltype(x), num_cons),))
         function cons_vjp!(J, θ, v)
             pullback!(
-                cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,), Constant(p))
+                cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,))
         end
     elseif cons_vjp == true
         cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
@@ -441,10 +441,10 @@ function OptimizationBase.instantiate_function(
 
     if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
         extras_pushforward = prepare_pushforward(
-            cons_oop, adtype.dense_ad, x, (ones(eltype(x), length(x)),), Constant(p))
+            cons_oop, adtype.dense_ad, x, (ones(eltype(x), length(x)),))
         function cons_jvp!(J, θ, v)
             pushforward!(
-                cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,), Constant(p))
+                cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,))
         end
     elseif cons_jvp == true
         cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
@@ -482,7 +482,7 @@ function OptimizationBase.instantiate_function(
 
     function lag_h!(H::AbstractMatrix, θ, σ, λ)
         if σ == zero(eltype(θ))
-            cons_h(H, θ)
+            cons_h!(H, θ)
             H *= λ
         else
             hessian!(lagrangian, H, lag_extras, soadtype, θ,
diff --git a/test/adtests.jl b/test/adtests.jl
index 2872cf1..7b11a09 100644
--- a/test/adtests.jl
+++ b/test/adtests.jl
@@ -739,7 +739,7 @@ end
 
     # Instantiate the optimization problem
    optprob = OptimizationBase.instantiate_function(optf, x0,
-        OptimizationBase.AutoSparseForwardDiff(),
+        OptimizationBase.AutoSparseZygote(),
        nothing, 2, g = true, h = true, cons_j = true, cons_h = true, lag_h = true)
    # Test gradient
    G = zeros(3)
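
Note (editorial addition, not part of the patch): throughout the extension the fix replaces the undefined closure `_f` with the user objective `f.f`, keeping DifferentiationInterface's prepare-then-call pattern in which the parameters ride along as a `Constant` context. Below is a minimal sketch of that pattern, assuming DifferentiationInterface.jl (v0.6+) and Zygote.jl are installed; `rosenbrock` is a hypothetical stand-in objective, not code from this PR.

    # Sketch: prepare once, then reuse the preparation for repeated in-place
    # gradient calls; p is wrapped in Constant so it is carried along as data
    # rather than being differentiated.
    using DifferentiationInterface
    import Zygote

    rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

    x = zeros(2)
    p = [1.0, 100.0]
    backend = AutoZygote()

    prep = prepare_gradient(rosenbrock, backend, x, Constant(p))
    G = similar(x)
    gradient!(rosenbrock, G, prep, backend, x, Constant(p))  # G ≈ [-2.0, 0.0]

Because the preparation is keyed to the point type and the context, it can be reused across solver iterations that only change `x`, which is why the patch prepares once outside the `grad`/`fg!` closures.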
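The `hv!` change additionally wraps the result buffer as `(H,)`: in DifferentiationInterface, HVP seeds and results travel as tuples, one entry per seed vector. A minimal sketch under the same assumptions (`quadratic` is a hypothetical objective; `SecondOrder` composes forward-over-reverse, analogous to the `soadtype.dense_ad` used above):

    # Sketch: Hessian-vector product with tuple-wrapped seed (v,) and
    # result (H,), mirroring the patched hv! closures.
    using DifferentiationInterface
    import ForwardDiff, Zygote

    quadratic(x, p) = p * sum(abs2, x)

    x = ones(3)
    v = [1.0, 0.0, 0.0]   # seed vector
    H = similar(x)        # receives the Hessian-vector product
    backend = SecondOrder(AutoForwardDiff(), AutoZygote())

    prep = prepare_hvp(quadratic, backend, x, (v,), Constant(2.0))
    hvp!(quadratic, (H,), prep, backend, x, (v,), Constant(2.0))
    # H now holds ∇²f(x) * v = 2p .* v = [4.0, 0.0, 0.0]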