Skip to content

Commit

Permalink
fix: improve hack supporting unscalarized usage of array observed var…
Browse files Browse the repository at this point in the history
…iables
  • Loading branch information
AayushSabharwal committed Oct 16, 2024
1 parent a87eb46 commit 74d69d8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 20 deletions.
50 changes: 30 additions & 20 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -574,35 +574,35 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
# TODO: compute the dependency correctly so that we don't have to do this
obs = [fast_substitute(observed(sys), obs_sub); subeqs]

# HACK: Substitute non-scalarized symbolic arrays of observed variables
# E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
# ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
# by the topological sorting and dependency identification pieces
obs_arr_subs = Dict()
# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
# are equations, add an equation `p ~ [p[1], p[2], ...]`
# allow topsort to reorder them

handled_obs_arr = Set()
obs_arr_eqs = Equation[]
for eq in obs
lhs = eq.lhs
iscall(lhs) || continue
operation(lhs) === getindex || continue
Symbolics.shape(lhs) !== Symbolics.Unknown() || continue
arg1 = arguments(lhs)[1]
haskey(obs_arr_subs, arg1) && continue
obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)] # e.g. p => [p[1], p[2]]
index_first = eachindex(arg1)[1]

arg1 in handled_obs_arr && continue
# firstindex returns 1 for multidimensional array symbolics
firstind = first(eachindex(arg1))
scal = [arg1[i] for i in eachindex(arg1)]
# respect non-1-indexed arrays
# TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency
obs_arr_subs[arg1] = Origin(index_first)(obs_arr_subs[arg1])
end
for i in eachindex(neweqs)
neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator)
end
for i in eachindex(obs)
obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator)
end
for i in eachindex(subeqs)
subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator)
end
# `change_origin` is required because `Origin(firstind)(scal)` makes codegen
# try to `create_array(OffsetArray{...}, ...)` which errors.
# `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
# of `scal`.
push!(obs_arr_eqs, arg1 ~ change_origin(Origin(firstind), scal))
push!(handled_obs_arr, arg1)
end
append!(obs, obs_arr_eqs)
append!(subeqs, obs_arr_eqs)
# need to re-sort subeqs
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])

@set! sys.eqs = neweqs
@set! sys.observed = obs
Expand All @@ -629,6 +629,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
return invalidate_cache!(sys)
end

function change_origin(origin, arr)
return origin(arr)
end

@register_array_symbolic change_origin(origin::Origin, arr::AbstractArray) begin
size = size(arr)
eltype = eltype(arr)
ndims = ndims(arr)
end

function tearing(state::TearingState; kwargs...)
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
complete!(state.structure)
Expand Down
13 changes: 13 additions & 0 deletions test/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,16 @@ end
@test ModelingToolkit.𝑠neighbors(g, 1) == [2]
@test ModelingToolkit.𝑑neighbors(g, 2) == [1]
end

@testset "array observed used unscalarized in another observed" begin
@variables x(t) y(t)[1:2] z(t)[1:2]
@parameters foo(::AbstractVector)[1:2]
_tmp_fn(x) = 2x
@mtkbuild sys = ODESystem([D(x) ~ z[1] + z[2], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
@test length(equations(sys)) == 1
@test length(observed(sys)) == 6
@test any(eq -> isequal(eq.lhs, y), observed(sys))
@test any(eq -> isequal(eq.lhs, z), observed(sys))
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn])
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
end

0 comments on commit 74d69d8

Please sign in to comment.