Skip to content

Commit

Permalink
cons again
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Oct 1, 2024
1 parent 75be96e commit 3038abf
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 116 deletions.
11 changes: 4 additions & 7 deletions ext/OptimizationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ function OptimizationBase.instantiate_function(
if f.cons === nothing
cons = nothing
else
function cons(res, θ)
return f.cons(res, θ, p)
end
cons = (res, θ) -> f.cons(res, θ, p)

function cons_oop(x)
_res = Zygote.Buffer(x, num_cons)
Expand Down Expand Up @@ -369,7 +367,8 @@ function OptimizationBase.instantiate_function(
end

if hv == true && f.hv === nothing
prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
prep_hvp = prepare_hvp(
f.f, soadtype.dense_ad, x, (zeros(eltype(x), size(x)),), Constant(p))
function hv!(H, θ, v)
hvp!(f.f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
end
Expand All @@ -387,9 +386,7 @@ function OptimizationBase.instantiate_function(
if f.cons === nothing
cons = nothing
else
function cons(res, θ)
f.cons(res, θ, p)
end
cons = (res, θ) -> f.cons(res, θ, p)

function cons_oop(x)
_res = Zygote.Buffer(x, num_cons)
Expand Down
26 changes: 11 additions & 15 deletions src/OptimizationDIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,6 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
hvp, jacobian, Constant
using ADTypes, SciMLBase

function generate_adtype(adtype)
if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
else
soadtype = adtype
end
return adtype, soadtype
end

function instantiate_function(
f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType,
p = SciMLBase.NullParameters(), num_cons = 0;
Expand Down Expand Up @@ -122,7 +111,10 @@ function instantiate_function(
hv! = nothing
end

if !(f.cons === nothing)
if f.cons === nothing
cons = nothing
else
cons = (res, x) -> f.cons(res, x, p)
function cons_oop(x)
_res = zeros(eltype(x), num_cons)
f.cons(_res, x, p)
Expand Down Expand Up @@ -257,7 +249,7 @@ function instantiate_function(

return OptimizationFunction{true}(f.f, adtype;
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
cons = (res, x) -> f.cons(res, x, p), cons_j = cons_j!, cons_h = cons_h!,
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
hess_prototype = hess_sparsity,
hess_colorvec = hess_colors,
Expand Down Expand Up @@ -379,7 +371,11 @@ function instantiate_function(
hv! = nothing
end

if !(f.cons === nothing)
if f.cons === nothing
cons = nothing
else
cons = Base.Fix2(f.cons, p)

function lagrangian(θ, σ, λ, p)
return σ * f.f(θ, p) + dot(λ, f.cons(θ, p))
end
Expand Down Expand Up @@ -482,7 +478,7 @@ function instantiate_function(

return OptimizationFunction{false}(f.f, adtype;
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
cons = Base.Fix2(f.cons, p), cons_j = cons_j!, cons_h = cons_h!,
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
hess_prototype = hess_sparsity,
hess_colorvec = hess_colors,
Expand Down
106 changes: 12 additions & 94 deletions src/OptimizationDISparseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,96 +12,6 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
using ADTypes
using SparseConnectivityTracer, SparseMatrixColorings

function generate_sparse_adtype(adtype)
if adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm())
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm())
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm()) #make zygote?
elseif !(adtype isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm())
end
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
!(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = adtype.coloring_algorithm)
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = adtype.coloring_algorithm)
end
elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) &&
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = GreedyColoringAlgorithm())
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = GreedyColoringAlgorithm())
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = GreedyColoringAlgorithm())
elseif !(adtype isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = GreedyColoringAlgorithm())
end
else
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = adtype.coloring_algorithm)
end
end
return adtype, soadtype
end

function instantiate_function(
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType},
p = SciMLBase.NullParameters(), num_cons = 0;
Expand Down Expand Up @@ -205,7 +115,11 @@ function instantiate_function(
hv! = nothing
end

if !(f.cons === nothing)
if f.cons === nothing
cons = nothing
else
cons = (res, θ) -> f.cons(res, θ, p)

function cons_oop(x)
_res = zeros(eltype(x), num_cons)
f.cons(_res, x, p)
Expand Down Expand Up @@ -347,7 +261,7 @@ function instantiate_function(
end
return OptimizationFunction{true}(f.f, adtype;
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
cons = (res, x) -> f.cons(res, x, p), cons_j = cons_j!, cons_h = cons_h!,
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
hess_prototype = hess_sparsity,
hess_colorvec = hess_colors,
Expand Down Expand Up @@ -475,7 +389,11 @@ function instantiate_function(
hv! = nothing
end

if !(f.cons === nothing)
if f.cons === nothing
cons = nothing
else
cons = Base.Fix2(f.cons, p)

function lagrangian(θ, σ, λ, p)
return σ * f.f(θ, p) + dot(λ, f.cons(θ, p))
end
Expand Down Expand Up @@ -585,7 +503,7 @@ function instantiate_function(
end
return OptimizationFunction{false}(f.f, adtype;
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
cons = Base.Fix2(f.cons, p), cons_j = cons_j!, cons_h = cons_h!,
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
hess_prototype = hess_sparsity,
hess_colorvec = hess_colors,
Expand Down
101 changes: 101 additions & 0 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,104 @@ if a `hess` function is supplied to the `OptimizationFunction`, then the
Hessian is not defined via Zygote.
"""
AutoZygote

function generate_adtype(adtype)
if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
else
soadtype = adtype
end
return adtype, soadtype
end

function generate_sparse_adtype(adtype)
if adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm())
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm())
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm()) #make zygote?
elseif !(adtype isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm())
end
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
!(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = adtype.coloring_algorithm)
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = adtype.coloring_algorithm)
end
elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) &&
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = GreedyColoringAlgorithm())
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = GreedyColoringAlgorithm())
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = GreedyColoringAlgorithm())
elseif !(adtype isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = GreedyColoringAlgorithm())
end
else
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = adtype.coloring_algorithm)
end
end
return adtype, soadtype
end

0 comments on commit 3038abf

Please sign in to comment.