Make zygote second order FD over Zygote #121

Merged: 2 commits, Oct 7, 2024
ext/OptimizationZygoteExt.jl (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
    gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
    hvp, jacobian, Constant
using ADTypes, SciMLBase
-import Zygote
+import Zygote, Zygote.ForwardDiff

function OptimizationBase.instantiate_function(
        f::OptimizationFunction{true}, x,
src/adtypes.jl (13 changes: 11 additions & 2 deletions)
@@ -220,8 +220,11 @@ Hessian is not defined via Zygote.
AutoZygote

function generate_adtype(adtype)
-    if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder)
+    if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder ||
+            adtype isa AutoZygote)
        soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
+    elseif adtype isa AutoZygote
+        soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
    elseif adtype isa DifferentiationInterface.SecondOrder
        soadtype = adtype
        adtype = adtype.inner
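
For orientation, a minimal sketch (not part of this diff) of the backend pairing the new `AutoZygote` branch constructs: ForwardDiff as the outer backend differentiating Zygote gradients, i.e. forward-over-reverse Hessians. The objective `f` and the input point are made up for illustration.

import ForwardDiff, Zygote  # load both AD packages so DifferentiationInterface can use them
using DifferentiationInterface
using ADTypes: AutoForwardDiff, AutoZygote

f(x) = sum(abs2, x)                                      # toy objective
soadtype = SecondOrder(AutoForwardDiff(), AutoZygote())  # outer = ForwardDiff, inner = Zygote
H = hessian(f, soadtype, [1.0, 2.0])                     # ≈ [2.0 0.0; 0.0 2.0]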
@@ -234,11 +237,17 @@ end

function spadtype_to_spsoadtype(adtype)
    if !(adtype.dense_ad isa SciMLBase.NoAD ||
-         adtype.dense_ad isa DifferentiationInterface.SecondOrder)
+         adtype.dense_ad isa DifferentiationInterface.SecondOrder ||
+         adtype.dense_ad isa AutoZygote)
        soadtype = AutoSparse(
            DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
            sparsity_detector = adtype.sparsity_detector,
            coloring_algorithm = adtype.coloring_algorithm)
+    elseif adtype.dense_ad isa AutoZygote
+        soadtype = AutoSparse(
+            DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
+            sparsity_detector = adtype.sparsity_detector,
+            coloring_algorithm = adtype.coloring_algorithm)
    else
        soadtype = adtype
    end
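
The sparse path mirrors the dense one: only the `dense_ad` field is swapped for the forward-over-Zygote `SecondOrder`, while the sparsity detector and coloring algorithm are carried over. A hedged sketch of what the new `elseif` branch produces (the detector and coloring choices below are assumptions for illustration, not required by this change):

using ADTypes: AutoSparse, AutoForwardDiff, AutoZygote
using DifferentiationInterface: SecondOrder
using SparseConnectivityTracer: TracerSparsityDetector
using SparseMatrixColorings: GreedyColoringAlgorithm

spadtype = AutoSparse(AutoZygote();
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())

# What spadtype_to_spsoadtype now builds from `spadtype`:
spsoadtype = AutoSparse(
    SecondOrder(AutoForwardDiff(), spadtype.dense_ad);
    sparsity_detector = spadtype.sparsity_detector,
    coloring_algorithm = spadtype.coloring_algorithm)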
src/cache.jl (11 changes: 9 additions & 2 deletions)
@@ -42,11 +42,18 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;

num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)

-    if !(prob.f.adtype isa DifferentiationInterface.SecondOrder) &&
+    if !(prob.f.adtype isa DifferentiationInterface.SecondOrder ||
+            prob.f.adtype isa AutoZygote) &&
       (SciMLBase.requireshessian(opt) || SciMLBase.requiresconshess(opt) ||
        SciMLBase.requireslagh(opt))
        @warn "The selected optimization algorithm requires second order derivatives, but `SecondOrder` ADtype was not provided.
-        So a `SecondOrder` with $adtype for both inner and outer will be created, this can be suboptimal and not work in some cases so
+        So a `SecondOrder` with $(prob.f.adtype) for both inner and outer will be created, this can be suboptimal and not work in some cases so
        an explicit `SecondOrder` ADtype is recommended."
+    elseif prob.f.adtype isa AutoZygote &&
+           (SciMLBase.requiresconshess(opt) || SciMLBase.requireslagh(opt) ||
+            SciMLBase.requireshessian(opt))
+        @warn "The selected optimization algorithm requires second order derivatives, but `AutoZygote` ADtype was provided.
+        So a `SecondOrder` with `AutoZygote` for inner and `AutoForwardDiff` for outer will be created, for choosing another pair
+        an explicit `SecondOrder` ADtype is recommended."
    end
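
In practice this means a plain `AutoZygote()` now works out of the box with Hessian-based optimizers, at the cost of the warning above. A hedged end-to-end sketch (the solver choice is an example, not taken from this PR):

using Optimization, OptimizationOptimJL, Zygote

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# Newton requires Hessians, so the new warning fires and the Hessian is
# computed with SecondOrder(AutoForwardDiff(), AutoZygote()).
sol = solve(prob, Optim.Newton())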

test/adtests.jl (16 changes: 8 additions & 8 deletions)
@@ -223,9 +223,9 @@ optprob.cons_h(H3, x0)
H2 = Array{Float64}(undef, 2, 2)

optf = OptimizationFunction(
-    rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = cons)
+    rosenbrock, AutoZygote(), cons = cons)
optprob = OptimizationBase.instantiate_function(
-    optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
+    optf, x0, AutoZygote(),
    nothing, 1, g = true, h = true, hv = true,
    cons_j = true, cons_h = true, cons_vjp = true,
    cons_jvp = true, lag_h = true)
@@ -456,9 +456,9 @@ end
H2 = Array{Float64}(undef, 2, 2)

optf = OptimizationFunction(
-    rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = con2_c)
+    rosenbrock, AutoZygote(), cons = con2_c)
optprob = OptimizationBase.instantiate_function(
-    optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
+    optf, x0, AutoZygote(),
    nothing, 2, g = true, h = true, hv = true,
    cons_j = true, cons_h = true, cons_vjp = true,
    cons_jvp = true, lag_h = true)
@@ -1080,10 +1080,10 @@ end

cons = (x, p) -> [x[1]^2 + x[2]^2]
optf = OptimizationFunction{false}(rosenbrock,
-    SecondOrder(AutoForwardDiff(), AutoZygote()),
+    AutoZygote(),
    cons = cons)
optprob = OptimizationBase.instantiate_function(
-    optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
+    optf, x0, AutoZygote(),
    nothing, 1, g = true, h = true, cons_j = true, cons_h = true)

@test optprob.grad(x0) == G1
@@ -1096,10 +1096,10 @@ end

cons = (x, p) -> [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
optf = OptimizationFunction{false}(rosenbrock,
-    SecondOrder(AutoForwardDiff(), AutoZygote()),
+    AutoZygote(),
    cons = cons)
optprob = OptimizationBase.instantiate_function(
-    optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
+    optf, x0, AutoZygote(),
    nothing, 2, g = true, h = true, cons_j = true, cons_h = true)

@test optprob.grad(x0) == G1