diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index d0ec969..899b231 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -118,9 +118,7 @@ function OptimizationBase.instantiate_function( if f.cons === nothing cons = nothing else - function cons(res, θ) - return f.cons(res, θ, p) - end + cons = (res, θ) -> f.cons(res, θ, p) function cons_oop(x) _res = Zygote.Buffer(x, num_cons) @@ -369,7 +367,8 @@ function OptimizationBase.instantiate_function( end if hv == true && f.hv === nothing - prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x))) + prep_hvp = prepare_hvp( + f.f, soadtype.dense_ad, x, (zeros(eltype(x), size(x)),), Constant(p)) function hv!(H, θ, v) hvp!(f.f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p)) end @@ -387,9 +386,7 @@ function OptimizationBase.instantiate_function( if f.cons === nothing cons = nothing else - function cons(res, θ) - f.cons(res, θ, p) - end + cons = (res, θ) -> f.cons(res, θ, p) function cons_oop(x) _res = Zygote.Buffer(x, num_cons) diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index 1a19570..5acc1f7 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -14,17 +14,6 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, hvp, jacobian, Constant using ADTypes, SciMLBase -function generate_adtype(adtype) - if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode - soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode - soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) - else - soadtype = adtype - end - return adtype, soadtype -end - function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0; @@ -122,7 +111,10 @@ function instantiate_function( hv! = nothing end - if !(f.cons === nothing) + if f.cons === nothing + cons = nothing + else + cons = (res, x) -> f.cons(res, x, p) function cons_oop(x) _res = zeros(eltype(x), num_cons) f.cons(_res, x, p) @@ -257,7 +249,7 @@ function instantiate_function( return OptimizationFunction{true}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, - cons = (res, x) -> f.cons(res, x, p), cons_j = cons_j!, cons_h = cons_h!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, @@ -379,7 +371,11 @@ function instantiate_function( hv! = nothing end - if !(f.cons === nothing) + if f.cons === nothing + cons = nothing + else + cons = Base.Fix2(f.cons, p) + function lagrangian(θ, σ, λ, p) return σ * f.f(θ, p) + dot(λ, f.cons(θ, p)) end @@ -482,7 +478,7 @@ function instantiate_function( return OptimizationFunction{false}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, - cons = Base.Fix2(f.cons, p), cons_j = cons_j!, cons_h = cons_h!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 5736012..128f545 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -12,96 +12,6 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, using ADTypes using SparseConnectivityTracer, SparseMatrixColorings -function generate_sparse_adtype(adtype) - if adtype.sparsity_detector isa ADTypes.NoSparsityDetector && - adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm - adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), - coloring_algorithm = GreedyColoringAlgorithm()) - if adtype.dense_ad isa ADTypes.AutoFiniteDiff - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), - sparsity_detector = TracerSparsityDetector(), - coloring_algorithm = GreedyColoringAlgorithm()) - elseif !(adtype.dense_ad isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), - sparsity_detector = TracerSparsityDetector(), - coloring_algorithm = GreedyColoringAlgorithm()) #make zygote? - elseif !(adtype isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), - sparsity_detector = TracerSparsityDetector(), - coloring_algorithm = GreedyColoringAlgorithm()) - end - elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && - !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm) - adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), - coloring_algorithm = adtype.coloring_algorithm) - if adtype.dense_ad isa ADTypes.AutoFiniteDiff - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), - sparsity_detector = TracerSparsityDetector(), - coloring_algorithm = adtype.coloring_algorithm) - elseif !(adtype.dense_ad isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), - sparsity_detector = TracerSparsityDetector(), - coloring_algorithm = adtype.coloring_algorithm) - elseif !(adtype isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), - sparsity_detector = TracerSparsityDetector(), - coloring_algorithm = adtype.coloring_algorithm) - end - elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) && - adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm - adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, - coloring_algorithm = GreedyColoringAlgorithm()) - if adtype.dense_ad isa ADTypes.AutoFiniteDiff - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), - sparsity_detector = adtype.sparsity_detector, - coloring_algorithm = GreedyColoringAlgorithm()) - elseif !(adtype.dense_ad isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), - sparsity_detector = adtype.sparsity_detector, - coloring_algorithm = GreedyColoringAlgorithm()) - elseif !(adtype isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), - sparsity_detector = adtype.sparsity_detector, - coloring_algorithm = GreedyColoringAlgorithm()) - end - else - if adtype.dense_ad isa ADTypes.AutoFiniteDiff - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), - sparsity_detector = adtype.sparsity_detector, - coloring_algorithm = adtype.coloring_algorithm) - elseif !(adtype.dense_ad isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), - sparsity_detector = adtype.sparsity_detector, - coloring_algorithm = adtype.coloring_algorithm) - elseif !(adtype isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse( - DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), - sparsity_detector = adtype.sparsity_detector, - coloring_algorithm = adtype.coloring_algorithm) - end - end - return adtype, soadtype -end - function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0; @@ -205,7 +115,11 @@ function instantiate_function( hv! = nothing end - if !(f.cons === nothing) + if f.cons === nothing + cons = nothing + else + cons = (res, θ) -> f.cons(res, θ, p) + function cons_oop(x) _res = zeros(eltype(x), num_cons) f.cons(_res, x, p) @@ -347,7 +261,7 @@ function instantiate_function( end return OptimizationFunction{true}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, - cons = (res, x) -> f.cons(res, x, p), cons_j = cons_j!, cons_h = cons_h!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, @@ -475,7 +389,11 @@ function instantiate_function( hv! = nothing end - if !(f.cons === nothing) + if f.cons === nothing + cons = nothing + else + cons = Base.Fix2(f.cons, p) + function lagrangian(θ, σ, λ, p) return σ * f.f(θ, p) + dot(λ, f.cons(θ, p)) end @@ -585,7 +503,7 @@ function instantiate_function( end return OptimizationFunction{false}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, - cons = Base.Fix2(f.cons, p), cons_j = cons_j!, cons_h = cons_h!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, diff --git a/src/adtypes.jl b/src/adtypes.jl index bfc0f2e..21e0025 100644 --- a/src/adtypes.jl +++ b/src/adtypes.jl @@ -218,3 +218,104 @@ if a `hess` function is supplied to the `OptimizationFunction`, then the Hessian is not defined via Zygote. """ AutoZygote + +function generate_adtype(adtype) + if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode + soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode + soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) + else + soadtype = adtype + end + return adtype, soadtype +end + +function generate_sparse_adtype(adtype) + if adtype.sparsity_detector isa ADTypes.NoSparsityDetector && + adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm()) + if adtype.dense_ad isa ADTypes.AutoFiniteDiff + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype.dense_ad isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm()) #make zygote? + elseif !(adtype isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm()) + end + elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && + !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm) + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = adtype.coloring_algorithm) + if adtype.dense_ad isa ADTypes.AutoFiniteDiff + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype.dense_ad isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = adtype.coloring_algorithm) + end + elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) && + adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = GreedyColoringAlgorithm()) + if adtype.dense_ad isa ADTypes.AutoFiniteDiff + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype.dense_ad isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = GreedyColoringAlgorithm()) + end + else + if adtype.dense_ad isa ADTypes.AutoFiniteDiff + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype.dense_ad isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = adtype.coloring_algorithm) + end + end + return adtype, soadtype +end