Skip to content

Commit

Permalink
feat: extend all arguments of a base sys to sys
Browse files Browse the repository at this point in the history
  • Loading branch information
ven-k committed Oct 15, 2024
1 parent 28a5af3 commit ef3a7f2
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 34 deletions.
26 changes: 15 additions & 11 deletions docs/src/basics/MTKLanguage.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ end
v_array(t)[1:N, 1:M]
v_for_defaults(t)
end
@extend ModelB(; p1)
@extend ModelB(p1 = 1)
@components begin
model_a = ModelA(; k_array)
model_array_a = [ModelA(; k = i) for i in 1:N]
Expand Down Expand Up @@ -149,14 +149,18 @@ julia> ModelingToolkit.getdefault(model_c1.v)

#### `@extend` begin block

- Partial systems can be extended in a higher system as `@extend PartialSystem(; kwargs)`.
- Keyword arguments pf partial system in the `@extend` definition are added as the keyword arguments of the base system.
- Note that in above example, `p1` is promoted as an argument of `ModelC`. Users can set the value of `p1`. However, as `p2` isn't listed in the model definition, its initial guess can't be specified while creating an instance of `ModelC`.
Partial systems can be extended in a higher system in two ways:

```julia
julia> @mtkbuild model_c2 = ModelC(; p1 = 2.0)
- `@extend PartialSystem(var1 = value1)`

+ This is the recommended way of extending a base system.
+ The default values for the arguments of the base system can be declared in the `@extend` statement.
+ Note that all keyword arguments of the base system are added as the keyword arguments of the main system.

```
- `@extend var_to_unpack1, var_to_unpack2 = partial_sys = PartialSystem(var1 = value1)`

+ In this method: explicitly list the variables that should be unpacked, provide a name for the partial system and declare the base system.
+ Note that only the arguments listed out in the declaration of the base system and unpacked variables (here: `var1`, `var_to_unpack1`, `var_to_unpack2`) are added as the keyword arguments of the higher system.

#### `@components` begin block

Expand Down Expand Up @@ -325,11 +329,11 @@ For example, the structure of `ModelC` is:
julia> ModelC.structure
Dict{Symbol, Any} with 10 entries:
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]]
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_for_defaults=>Dict(:type=>Real))
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:value=>nothing, :type=>Real, :size=>(:N, :M)), :v_for_defaults=>Dict(:type=>Real))
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
:kwargs => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3), :v => Dict{Symbol, Any}(:value => :v_var, :type => Real), :v_for_defaults => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real), :p1 => Dict(:value => nothing)),
:structural_parameters => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3))
:independent_variable => t
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :p2=>Dict(:value=>NoValue()), :N=>Dict(:value=>2), :M=>Dict(:value=>3), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Any}(:value=>nothing, :type=>Real, :size=>(:N, :M)), :v_for_defaults=>Dict{Symbol, Union{Nothing, DataType}}(:value=>nothing, :type=>Real), :p1=>Dict(:value=>1))
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :M=>Dict(:value=>3))
:independent_variable => :t
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
:defaults => Dict{Symbol, Any}(:v_for_defaults=>2.0)
Expand Down
50 changes: 31 additions & 19 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ function parse_structural_parameters!(exprs, sps, dict, mod, body, kwargs)
end
end

function extend_args!(a, b, dict, expr, kwargs, varexpr, has_param = false)
function extend_args!(a, b, dict, expr, kwargs, has_param = false)
# Whenever `b` is a function call, skip the first arg aka the function name.
# Whenever it is a kwargs list, include it.
start = b.head == :call ? 2 : 1
Expand All @@ -734,22 +734,18 @@ function extend_args!(a, b, dict, expr, kwargs, varexpr, has_param = false)
b.args[i] = Expr(:parameters, x)
end
end
push!(kwargs, Expr(:kw, x, nothing))
dict[:kwargs][x] = Dict(:value => nothing)
end
Expr(:kw, x) => begin
push!(kwargs, Expr(:kw, x, nothing))
dict[:kwargs][x] = Dict(:value => nothing)
end
Expr(:kw, x, y) => begin
b.args[i] = Expr(:kw, x, x)
push!(varexpr.args, :($x = $x === nothing ? $y : $x))
push!(kwargs, Expr(:kw, x, nothing))
dict[:kwargs][x] = Dict(:value => nothing)
push!(kwargs, Expr(:kw, x, y))
dict[:kwargs][x] = Dict(:value => y)
end
Expr(:parameters, x...) => begin
has_param = true
extend_args!(a, arg, dict, expr, kwargs, varexpr, has_param)
extend_args!(a, arg, dict, expr, kwargs, has_param)
end
_ => error("Could not parse $arg of component $a")
end
Expand All @@ -758,17 +754,31 @@ end

const EMPTY_DICT = Dict()
const EMPTY_VoVoSYMBOL = Vector{Symbol}[]
const EMPTY_VoVoVoSYMBOL = Vector{Symbol}[[]]

function Base.names(model::Model)
function _arguments(model::Model)
vars = keys(get(model.structure, :variables, EMPTY_DICT))
vars = union(vars, keys(get(model.structure, :parameters, EMPTY_DICT)))
vars = union(vars,
map(first, get(model.structure, :components, EMPTY_VoVoSYMBOL)))
vars = union(vars, first(get(model.structure, :extend, EMPTY_VoVoVoSYMBOL)))
collect(vars)
end

function _parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars)
extend_args!(a, b, dict, expr, kwargs, varexpr)
function Base.names(model::Model)
collect(union(_arguments(model),
map(first, get(model.structure, :components, EMPTY_VoVoSYMBOL))))
end

function _parse_extend!(ext, a, b, dict, expr, kwargs, vars, additional_args)
extend_args!(a, b, dict, expr, kwargs)
b.args = [b.args[1]]
allvars = union(vars.args, additional_args.args)
for var in allvars
push!(b.args, Expr(:kw, var, var))
if !haskey(dict[:kwargs], var)
push!(dict[:kwargs], var => Dict(:value => NO_VALUE))
push!(kwargs, Expr(:kw, var, NO_VALUE))
end
end
ext[] = a
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
push!(expr.args, :($a = $b))
Expand All @@ -780,8 +790,6 @@ end

function parse_extend!(exprs, ext, dict, mod, body, kwargs)
expr = Expr(:block)
varexpr = Expr(:block)
push!(exprs, varexpr)
push!(exprs, expr)
body = deepcopy(body)
MLStyle.@match body begin
Expand All @@ -792,7 +800,7 @@ function parse_extend!(exprs, ext, dict, mod, body, kwargs)
error("`@extend` destructuring only takes an tuple as LHS. Got $body")
end
a, b = b.args
_parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars)
_parse_extend!(ext, a, b, dict, expr, kwargs, vars, Expr(:tuple))
else
error("When explicitly destructing in `@extend` please use the syntax: `@extend a, b = oneport = OnePort()`.")
end
Expand All @@ -802,8 +810,11 @@ function parse_extend!(exprs, ext, dict, mod, body, kwargs)
b = body
if (model = getproperty(mod, b.args[1])) isa Model
vars = Expr(:tuple)
append!(vars.args, names(model))
_parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars)
append!(vars.args, _arguments(model))
additional_args = Expr(:tuple)
append!(additional_args.args,
keys(get(model.structure, :structural_parameters, EMPTY_DICT)))
_parse_extend!(ext, a, b, dict, expr, kwargs, vars, additional_args)
else
error("Cannot infer the exact `Model` that `@extend $(body)` refers." *
" Please specify the names that it brings into scope by:" *
Expand Down Expand Up @@ -1104,7 +1115,7 @@ function parse_icon!(body::String, dict, icon, mod)
icon_dir = get(ENV, "MTK_ICONS_DIR", joinpath(DEPOT_PATH[1], "mtk_icons"))
dict[:icon] = icon[] = if isfile(body)
URI("file:///" * abspath(body))
elseif (iconpath = joinpath(icon_dir, body); isfile(iconpath))
elseif (iconpath = abspath(joinpath(icon_dir, body)); isfile(iconpath))
URI("file:///" * abspath(iconpath))
elseif try
Base.isvalid(URI(body))
Expand All @@ -1115,6 +1126,7 @@ function parse_icon!(body::String, dict, icon, mod)
elseif (_body = lstrip(body); startswith(_body, r"<\?xml|<svg"))
String(_body) # With Julia-1.10 promoting `SubString{String}` to `String` can be dropped.
else
@info iconpath=joinpath(icon_dir, body) isfile(iconpath) body
error("\n$body is not a valid icon")
end
end
Expand Down
31 changes: 27 additions & 4 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ R_val = 20u"Ω"
res__R = 100u"Ω"
@mtkbuild rc = RC(; C_val, R_val, resistor.R = res__R)
prob = ODEProblem(rc, [], (0, 1e9))
sol = solve(prob, Rodas5P())
sol = solve(prob)
defs = ModelingToolkit.defaults(rc)
@test sol[rc.capacitor.v, end] defs[rc.constant.k]
resistor = getproperty(rc, :resistor; namespace = false)
Expand Down Expand Up @@ -459,9 +459,10 @@ end
@test A.structure[:parameters] == Dict(:p => Dict(:type => Real))
@test A.structure[:extend] == [[:e], :extended_e, :E]
@test A.structure[:equations] == ["e ~ 0"]
@test A.structure[:kwargs] ==
Dict{Symbol, Dict}(:p => Dict(:value => nothing, :type => Real),
:v => Dict(:value => nothing, :type => Real))
@test A.structure[:kwargs] == Dict{Symbol, Dict}(
:p => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real),
:e => Dict(:value => ModelingToolkit.NoValue()),
:v => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real))
@test A.structure[:components] == [[:cc, :C]]
end

Expand Down Expand Up @@ -910,3 +911,25 @@ end
end),
false)
end

@mtkmodel BaseSys begin
@parameters begin
p1
p2
end
@variables begin
v1(t)
end
end

@testset "Arguments of base system" begin
@mtkmodel MainSys begin
@extend BaseSys(p1 = 1)
end

@test names(MainSys) == [:p2, :p1, :v1]
@named main_sys = MainSys(p1 = 11, p2 = 12, v1 = 13)
@test getdefault(main_sys.p1) == 11
@test getdefault(main_sys.p2) == 12
@test getdefault(main_sys.v1) == 13
end

0 comments on commit ef3a7f2

Please sign in to comment.