Skip to content

Commit

Permalink
Start symbolicanalysis integration
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Apr 11, 2024
1 parent b220e47 commit 1050fd7
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 4 deletions.
12 changes: 10 additions & 2 deletions ext/OptimizationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
hess_prototype = f.hess_prototype,
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
lag_h, f.lag_hess_prototype)
lag_h = lag_h,
lag_hess_prototype = f.lag_hess_prototype,
sys = f.sys,
expr = f.expr,
cons_expr = f.cons_expr)
end

function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x,
Expand Down Expand Up @@ -327,7 +331,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
hess_prototype = f.hess_prototype,
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
lag_h, f.lag_hess_prototype)
lag_h = lag_h,
lag_hess_prototype = f.lag_hess_prototype,
sys = f.sys,
expr = f.expr,
cons_expr = f.cons_expr)
end

end
7 changes: 6 additions & 1 deletion src/OptimizationBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ if !isdefined(Base, :get_extension)
end

using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra

using SymbolicIndexingInterface
using SymbolicAnalysis
import ModelingToolkit as MTK
import Symbolics
import Manifolds
import Symbolics: variable, Equation, Inequality, unwrap
import SciMLBase: OptimizationProblem,
OptimizationFunction, ObjSense,
MaxSense, MinSense, OptimizationStats
Expand Down
89 changes: 88 additions & 1 deletion src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,104 @@ struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C} <:
solver_args::NamedTuple
end

function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data;
function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFAULT_DATA;
callback = DEFAULT_CALLBACK,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
progress = false,
mtkize = true,
kwargs...)
reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p)
num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
f = OptimizationBase.instantiate_function(prob.f, reinit_cache, prob.f.adtype, num_cons)

if (f.sys === nothing || f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) && mtkize
try
if prob.u0 isa Matrix
n = size(prob.u0, 1)
m = size(prob.u0, 2)
vars = Symbolics.variables(:x, 1:n, 1:m)
else
vars = [variable(:x, i) for i in 1:length(prob.u0)]
end
@show typeof(vars)
params = if prob.p isa SciMLBase.NullParameters
[]
elseif prob.p isa MTK.MTKParameters
[variable(, i) for i in eachindex(vcat(p...))]
else
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
end
@show f.f
obj_expr = f.f(vars, params)

if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
lhs = Array{Num}(undef, num_cons)
f.cons(lhs, vars, params)
cons = Union{Equation, Inequality}[]

if !isnothing(prob.lcons)
for i in 1:num_cons
if !isinf(prob.lcons[i])
if prob.lcons[i] != prob.ucons[i]
push!(cons, prob.lcons[i] lhs[i])
else
push!(cons, lhs[i] ~ prob.ucons[i])
end
end
end
end

if !isnothing(prob.ucons)
for i in 1:num_cons
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
push!(cons, lhs[i] prob.ucons[i])
end
end
end
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
(isnothing(prob.ucons) || all(isinf, prob.ucons))
throw(ArgumentError("Constraints passed have no proper bounds defined.
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
or pass the lower and upper bounds for inequality constraints."))
end
elseif !isnothing(prob.f.cons)
cons_expr = f.cons(vars, params)
else
cons_expr = nothing
end
catch err
throw(ArgumentError("Automatic symbolic expression generation with ModelingToolkit failed with error: $err.
Try by setting `mtkize = false` instead if the solver doesn't require symbolic expressions."))
end
else
sys = f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing} ?
nothing : f.sys
obj_expr = f.expr
cons_expr = f.cons_expr
end
try
obj_expr = obj_expr |> Symbolics.unwrap
obj_expr = propagate_curvature(propagate_sign(obj_expr))
@show getcurvature(obj_expr)
catch
@info "No euclidean atom available"
end

try
obj_expr = SymbolicAnalysis.propagate_gcurvature(propagate_sign(obj_expr))
@show SymbolicAnalysis.getgcurvature(obj_expr)
catch
@info "No SPD atom available"
end

if !isnothing(cons_expr)
cons_expr = propagate_curvature(propagate_sign(cons_expr))
@show getcurvature(cons_expr)
end

return OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons,
prob.ucons, prob.sense,
opt, data, progress, callback,
Expand Down
10 changes: 10 additions & 0 deletions test/cvxtest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using Optimization, OptimizationBase, ForwardDiff

function f(x, p = nothing)
return exp(x[1])
end

optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, [0.4, 0.1, 0.3])

sol = solve(prob, Optimization.LBFGS(), maxiters = 1000)

0 comments on commit 1050fd7

Please sign in to comment.