diff --git a/ext/OptimizationForwardDiffExt.jl b/ext/OptimizationForwardDiffExt.jl index 256cebc..eb361c1 100644 --- a/ext/OptimizationForwardDiffExt.jl +++ b/ext/OptimizationForwardDiffExt.jl @@ -165,7 +165,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, hess_prototype = f.hess_prototype, cons_jac_prototype = f.cons_jac_prototype, cons_hess_prototype = f.cons_hess_prototype, - lag_h, f.lag_hess_prototype) + lag_h = lag_h, + lag_hess_prototype = f.lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, @@ -327,7 +331,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, hess_prototype = f.hess_prototype, cons_jac_prototype = f.cons_jac_prototype, cons_hess_prototype = f.cons_hess_prototype, - lag_h, f.lag_hess_prototype) + lag_h = lag_h, + lag_hess_prototype = f.lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) end end diff --git a/src/OptimizationBase.jl b/src/OptimizationBase.jl index cf4cd0c..72713b9 100644 --- a/src/OptimizationBase.jl +++ b/src/OptimizationBase.jl @@ -9,7 +9,12 @@ if !isdefined(Base, :get_extension) end using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra - +using SymbolicIndexingInterface +using SymbolicAnalysis +import ModelingToolkit as MTK +import Symbolics +import Manifolds +import Symbolics: variable, Equation, Inequality, unwrap import SciMLBase: OptimizationProblem, OptimizationFunction, ObjSense, MaxSense, MinSense, OptimizationStats diff --git a/src/cache.jl b/src/cache.jl index ead31e2..24e9713 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -14,17 +14,98 @@ struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C} <: solver_args::NamedTuple end -function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data; +function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFAULT_DATA; callback = DEFAULT_CALLBACK, maxiters::Union{Number, Nothing} = nothing, maxtime::Union{Number, Nothing} = nothing, abstol::Union{Number, Nothing} = nothing, reltol::Union{Number, Nothing} = nothing, progress = false, + mtkize = true, kwargs...) reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p) num_cons = prob.ucons === nothing ? 0 : length(prob.ucons) f = OptimizationBase.instantiate_function(prob.f, reinit_cache, prob.f.adtype, num_cons) + + if (f.sys === nothing || f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) && mtkize + try + vars = ArrayInterface.restructure(prob.u0, [variable(:x, i) for i in eachindex(prob.u0)]) + @show typeof(vars) + params = if prob.p isa SciMLBase.NullParameters + [] + elseif prob.p isa MTK.MTKParameters + [variable(:α, i) for i in eachindex(vcat(p...))] + else + ArrayInterface.restructure(p, [variable(:α, i) for i in eachindex(p)]) + end + @show f.f + obj_expr = f.f(vars, params) + + if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons) + lhs = Array{Num}(undef, num_cons) + f.cons(lhs, vars, params) + cons = Union{Equation, Inequality}[] + + if !isnothing(prob.lcons) + for i in 1:num_cons + if !isinf(prob.lcons[i]) + if prob.lcons[i] != prob.ucons[i] + push!(cons, prob.lcons[i] ≲ lhs[i]) + else + push!(cons, lhs[i] ~ prob.ucons[i]) + end + end + end + end + + if !isnothing(prob.ucons) + for i in 1:num_cons + if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i] + push!(cons, lhs[i] ≲ prob.ucons[i]) + end + end + end + if (isnothing(prob.lcons) || all(isinf, prob.lcons)) && + (isnothing(prob.ucons) || all(isinf, prob.ucons)) + throw(ArgumentError("Constraints passed have no proper bounds defined. + Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints + or pass the lower and upper bounds for inequality constraints.")) + end + elseif !isnothing(prob.f.cons) + cons_expr = f.cons(vars, params) + else + cons_expr = nothing + end + catch err + throw(ArgumentError("Automatic symbolic expression generation with ModelingToolkit failed with error: $err. + Try by setting `mtkize = false` instead if the solver doesn't require symbolic expressions.")) + end + else + sys = f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing} ? + nothing : f.sys + obj_expr = f.expr + cons_expr = f.cons_expr + end + try + obj_expr = obj_expr |> Symbolics.unwrap + obj_expr = propagate_curvature(propagate_sign(obj_expr)) + @show getcurvature(obj_expr) + catch + @info "No euclidean atom available" + end + + try + obj_expr = SymbolicAnalysis.propagate_gcurvature(propagate_sign(obj_expr)) + @show SymbolicAnalysis.getgcurvature(obj_expr) + catch + @info "No SPD atom available" + end + + if !isnothing(cons_expr) + cons_expr = propagate_curvature(propagate_sign(cons_expr)) + @show getcurvature(cons_expr) + end + return OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons, prob.ucons, prob.sense, opt, data, progress, callback, diff --git a/test/cvxtest.jl b/test/cvxtest.jl new file mode 100644 index 0000000..54920eb --- /dev/null +++ b/test/cvxtest.jl @@ -0,0 +1,10 @@ +using Optimization, OptimizationBase, ForwardDiff + +function f(x, p = nothing) + return exp(x[1]) + x[1]^2 +end + +optf = OptimizationFunction(f, Optimization.AutoForwardDiff()) +prob = OptimizationProblem(optf, [0.4]) + +sol = solve(prob, Optimization.LBFGS(), maxiters = 1000) \ No newline at end of file