Skip to content

Commit

Permalink
Merge pull request #112 from SciML/DIv6
Browse files Browse the repository at this point in the history
Move iterator checking here and make symbolics stuff extension
  • Loading branch information
Vaibhavdixit02 authored Oct 2, 2024
2 parents 1ebca6f + a25ebfb commit 0edc12b
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 127 deletions.
13 changes: 7 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationBase"
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
version = "2.1.0"
version = "2.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -17,24 +17,27 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
OptimizationEnzymeExt = "Enzyme"
OptimizationFiniteDiffExt = "FiniteDiff"
OptimizationForwardDiffExt = "ForwardDiff"
OptimizationMLDataDevicesExt = "MLDataDevices"
OptimizationMLUtilsExt = "MLUtils"
OptimizationMTKExt = "ModelingToolkit"
OptimizationReverseDiffExt = "ReverseDiff"
OptimizationSymbolicAnalysisExt = "SymbolicAnalysis"
OptimizationZygoteExt = "Zygote"

[compat]
Expand All @@ -56,8 +59,6 @@ SciMLBase = "2"
SparseConnectivityTracer = "0.6"
SparseMatrixColorings = "0.4"
SymbolicAnalysis = "0.3"
SymbolicIndexingInterface = "0.3"
Symbolics = "5.12, 6"
Zygote = "0.6.67"
julia = "1.10"

Expand Down
8 changes: 8 additions & 0 deletions ext/OptimizationMLDataDevicesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module OptimizationMLDataDevicesExt

using MLDataDevices
using OptimizationBase

OptimizationBase.isa_dataiterator(::DeviceIterator) = true

end
8 changes: 8 additions & 0 deletions ext/OptimizationMLUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module OptimizationMLUtilsExt

using MLUtils
using OptimizationBase

OptimizationBase.isa_dataiterator(::MLUtils.DataLoader) = true

end
109 changes: 109 additions & 0 deletions ext/OptimizationSymbolicAnalysisExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
module OptimizationSymbolicAnalysisExt

using OptimizationBase, SciMLBase, SymbolicAnalysis, SymbolicAnalysis.Symbolics
using SymbolicAnalysis: AnalysisResult
import Symbolics: variable, Equation, Inequality, unwrap, @variables

function OptimizationBase.symify_cache(
f::OptimizationFunction{iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP,
CJP, CHP, O, EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV},
prob) where {iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
EX <: Nothing, CEX <: Nothing, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV}
try
vars = if prob.u0 isa Matrix
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
else
ArrayInterface.restructure(
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
end
params = if prob.p isa SciMLBase.NullParameters
[]
elseif prob.p isa MTK.MTKParameters
[variable(, i) for i in eachindex(vcat(p...))]
else
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
end

if prob.u0 isa Matrix
vars = vars[1]
end

obj_expr = f.f(vars, params)

if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
lhs = Array{Symbolics.Num}(undef, num_cons)
f.cons(lhs, vars)
cons = Union{Equation, Inequality}[]

if !isnothing(prob.lcons)
for i in 1:num_cons
if !isinf(prob.lcons[i])
if prob.lcons[i] != prob.ucons[i]
push!(cons, prob.lcons[i] lhs[i])
else
push!(cons, lhs[i] ~ prob.ucons[i])
end
end
end
end

if !isnothing(prob.ucons)
for i in 1:num_cons
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
push!(cons, lhs[i] prob.ucons[i])
end
end
end
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
(isnothing(prob.ucons) || all(isinf, prob.ucons))
throw(ArgumentError("Constraints passed have no proper bounds defined.
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
or pass the lower and upper bounds for inequality constraints."))
end
cons_expr = lhs
elseif !isnothing(prob.f.cons)
cons_expr = f.cons(vars, params)
else
cons_expr = nothing
end
catch err
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
end
return obj_expr, cons_expr
end

function analysis(obj_expr, cons_expr)
if obj_expr !== nothing
obj_expr = obj_expr |> Symbolics.unwrap
if manifold === nothing
obj_res = analyze(obj_expr)
else
obj_res = analyze(obj_expr, manifold)
end
@info "Objective Euclidean curvature: $(obj_res.curvature)"
if obj_res.gcurvature !== nothing
@info "Objective Geodesic curvature: $(obj_res.gcurvature)"
end
end

if cons_expr !== nothing
cons_expr = cons_expr .|> Symbolics.unwrap
if manifold === nothing
cons_res = analyze.(cons_expr)
else
cons_res = analyze.(cons_expr, Ref(manifold))
end
for i in 1:num_cons
@info "Constraints Euclidean curvature: $(cons_res[i].curvature)"

if cons_res[i].gcurvature !== nothing
@info "Constraints Geodesic curvature: $(cons_res[i].gcurvature)"
end
end
end

return obj_res, cons_res
end

end
8 changes: 4 additions & 4 deletions ext/OptimizationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ function OptimizationBase.instantiate_function(
function hess(res, θ)
hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
end
hess_sparsity = prep_hess.coloring_result.S
hess_sparsity = prep_hess.coloring_result.A
hess_colors = prep_hess.coloring_result.color

if p !== SciMLBase.NullParameters() && p !== nothing
Expand Down Expand Up @@ -415,7 +415,7 @@ function OptimizationBase.instantiate_function(
J = vec(J)
end
end
cons_jac_prototype = prep_jac.coloring_result.S
cons_jac_prototype = prep_jac.coloring_result.A
cons_jac_colorvec = prep_jac.coloring_result.color
elseif cons !== nothing && cons_j == true
cons_j! = (J, θ) -> f.cons_j(J, θ, p)
Expand Down Expand Up @@ -455,7 +455,7 @@ function OptimizationBase.instantiate_function(
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
for i in 1:num_cons]
colores = getfield.(prep_cons_hess, :coloring_result)
conshess_sparsity = getfield.(colores, :S)
conshess_sparsity = getfield.(colores, :A)
conshess_colors = getfield.(colores, :color)
function cons_h!(H, θ)
for i in 1:num_cons
Expand All @@ -474,7 +474,7 @@ function OptimizationBase.instantiate_function(
lag_extras = prepare_hessian(
lagrangian, soadtype, x, Constant(one(eltype(x))),
Constant(ones(eltype(x), num_cons)), Constant(p))
lag_hess_prototype = lag_extras.coloring_result.S
lag_hess_prototype = lag_extras.coloring_result.A
lag_hess_colors = lag_extras.coloring_result.color

function lag_h!(H::AbstractMatrix, θ, σ, λ)
Expand Down
6 changes: 1 addition & 5 deletions src/OptimizationBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ if !isdefined(Base, :get_extension)
end

using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra
using SymbolicIndexingInterface
using SymbolicAnalysis
using SymbolicAnalysis: AnalysisResult
import Symbolics
import Symbolics: variable, Equation, Inequality, unwrap, @variables
import SciMLBase: OptimizationProblem,
OptimizationFunction, ObjSense,
MaxSense, MinSense, OptimizationStats
Expand All @@ -31,6 +26,7 @@ Base.iterate(::NullData, i = 1) = nothing
Base.length(::NullData) = 0

include("adtypes.jl")
include("symify.jl")
include("cache.jl")
include("OptimizationDIExt.jl")
include("OptimizationDISparseExt.jl")
Expand Down
16 changes: 8 additions & 8 deletions src/OptimizationDISparseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function instantiate_function(
function hess(res, θ)
hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
end
hess_sparsity = prep_hess.coloring_result.S
hess_sparsity = prep_hess.coloring_result.A
hess_colors = prep_hess.coloring_result.color

if p !== SciMLBase.NullParameters() && p !== nothing
Expand Down Expand Up @@ -147,7 +147,7 @@ function instantiate_function(
J = vec(J)
end
end
cons_jac_prototype = prep_jac.coloring_result.S
cons_jac_prototype = prep_jac.coloring_result.A
cons_jac_colorvec = prep_jac.coloring_result.color
elseif cons_j === true && f.cons !== nothing
cons_j! = (J, θ) -> f.cons_j(J, θ, p)
Expand Down Expand Up @@ -185,7 +185,7 @@ function instantiate_function(
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
for i in 1:num_cons]
colores = getfield.(prep_cons_hess, :coloring_result)
conshess_sparsity = getfield.(colores, :S)
conshess_sparsity = getfield.(colores, :A)
conshess_colors = getfield.(colores, :color)
function cons_h!(H, θ)
for i in 1:num_cons
Expand All @@ -204,7 +204,7 @@ function instantiate_function(
lag_prep = prepare_hessian(
lagrangian, soadtype, x, Constant(one(eltype(x))),
Constant(ones(eltype(x), num_cons)), Constant(p))
lag_hess_prototype = lag_prep.coloring_result.S
lag_hess_prototype = lag_prep.coloring_result.A
lag_hess_colors = lag_prep.coloring_result.color

function lag_h!(H::AbstractMatrix, θ, σ, λ)
Expand Down Expand Up @@ -357,7 +357,7 @@ function instantiate_function(
function hess(θ)
hessian(f.f, prep_hess, soadtype, θ, Constant(p))
end
hess_sparsity = prep_hess.coloring_result.S
hess_sparsity = prep_hess.coloring_result.A
hess_colors = prep_hess.coloring_result.color

if p !== SciMLBase.NullParameters() && p !== nothing
Expand Down Expand Up @@ -410,7 +410,7 @@ function instantiate_function(
end
return J
end
cons_jac_prototype = prep_jac.coloring_result.S
cons_jac_prototype = prep_jac.coloring_result.A
cons_jac_colorvec = prep_jac.coloring_result.color
elseif cons_j === true && f.cons !== nothing
cons_j! = (θ) -> f.cons_j(θ, p)
Expand Down Expand Up @@ -459,7 +459,7 @@ function instantiate_function(
return H
end
colores = getfield.(prep_cons_hess, :coloring_result)
conshess_sparsity = getfield.(colores, :S)
conshess_sparsity = getfield.(colores, :A)
conshess_colors = getfield.(colores, :color)
elseif cons_h == true && f.cons !== nothing
cons_h! = (res, θ) -> f.cons_h(res, θ, p)
Expand All @@ -482,7 +482,7 @@ function instantiate_function(
return hess
end
end
lag_hess_prototype = lag_prep.coloring_result.S
lag_hess_prototype = lag_prep.coloring_result.A
lag_hess_colors = lag_prep.coloring_result.color

if p !== SciMLBase.NullParameters() && p !== nothing
Expand Down
Loading

0 comments on commit 0edc12b

Please sign in to comment.