diff --git a/docs/src/basics/MTKLanguage.md b/docs/src/basics/MTKLanguage.md index 685c549429..7e1f6abb98 100644 --- a/docs/src/basics/MTKLanguage.md +++ b/docs/src/basics/MTKLanguage.md @@ -73,7 +73,7 @@ end v_array(t)[1:N, 1:M] v_for_defaults(t) end - @extend ModelB(; p1) + @extend ModelB(p1 = 1) @components begin model_a = ModelA(; k_array) model_array_a = [ModelA(; k = i) for i in 1:N] @@ -149,14 +149,18 @@ julia> ModelingToolkit.getdefault(model_c1.v) #### `@extend` begin block - - Partial systems can be extended in a higher system as `@extend PartialSystem(; kwargs)`. - - Keyword arguments pf partial system in the `@extend` definition are added as the keyword arguments of the base system. - - Note that in above example, `p1` is promoted as an argument of `ModelC`. Users can set the value of `p1`. However, as `p2` isn't listed in the model definition, its initial guess can't be specified while creating an instance of `ModelC`. +Partial systems can be extended in a higher system in two ways: -```julia -julia> @mtkbuild model_c2 = ModelC(; p1 = 2.0) + - `@extend PartialSystem(var1 = value1)` + + + This is the recommended way of extending a base system. + + The default values for the arguments of the base system can be declared in the `@extend` statement. + + Note that all keyword arguments of the base system are added as the keyword arguments of the main system. -``` + - `@extend var_to_unpack1, var_to_unpack2 = partial_sys = PartialSystem(var1 = value1)` + + + In this method: explicitly list the variables that should be unpacked, provide a name for the partial system and declare the base system. + + Note that only the arguments listed out in the declaration of the base system and unpacked variables (here: `var1`, `var_to_unpack1`, `var_to_unpack2`) are added as the keyword arguments of the higher system. #### `@components` begin block @@ -325,11 +329,11 @@ For example, the structure of `ModelC` is: julia> ModelC.structure Dict{Symbol, Any} with 10 entries: :components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]] - :variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_for_defaults=>Dict(:type=>Real)) + :variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:value=>nothing, :type=>Real, :size=>(:N, :M)), :v_for_defaults=>Dict(:type=>Real)) :icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png") - :kwargs => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3), :v => Dict{Symbol, Any}(:value => :v_var, :type => Real), :v_for_defaults => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real), :p1 => Dict(:value => nothing)), - :structural_parameters => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3)) - :independent_variable => t + :kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :p2=>Dict(:value=>NoValue()), :N=>Dict(:value=>2), :M=>Dict(:value=>3), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Any}(:value=>nothing, :type=>Real, :size=>(:N, :M)), :v_for_defaults=>Dict{Symbol, Union{Nothing, DataType}}(:value=>nothing, :type=>Real), :p1=>Dict(:value=>1)) + :structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :M=>Dict(:value=>3)) + :independent_variable => :t :constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant.")) :extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB] :defaults => Dict{Symbol, Any}(:v_for_defaults=>2.0) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 1298d72506..61749f944e 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -717,7 +717,7 @@ function parse_structural_parameters!(exprs, sps, dict, mod, body, kwargs) end end -function extend_args!(a, b, dict, expr, kwargs, varexpr, has_param = false) +function extend_args!(a, b, dict, expr, kwargs, has_param = false) # Whenever `b` is a function call, skip the first arg aka the function name. # Whenever it is a kwargs list, include it. start = b.head == :call ? 2 : 1 @@ -734,22 +734,18 @@ function extend_args!(a, b, dict, expr, kwargs, varexpr, has_param = false) b.args[i] = Expr(:parameters, x) end end - push!(kwargs, Expr(:kw, x, nothing)) dict[:kwargs][x] = Dict(:value => nothing) end Expr(:kw, x) => begin - push!(kwargs, Expr(:kw, x, nothing)) dict[:kwargs][x] = Dict(:value => nothing) end Expr(:kw, x, y) => begin - b.args[i] = Expr(:kw, x, x) - push!(varexpr.args, :($x = $x === nothing ? $y : $x)) - push!(kwargs, Expr(:kw, x, nothing)) - dict[:kwargs][x] = Dict(:value => nothing) + push!(kwargs, Expr(:kw, x, y)) + dict[:kwargs][x] = Dict(:value => y) end Expr(:parameters, x...) => begin has_param = true - extend_args!(a, arg, dict, expr, kwargs, varexpr, has_param) + extend_args!(a, arg, dict, expr, kwargs, has_param) end _ => error("Could not parse $arg of component $a") end @@ -758,17 +754,31 @@ end const EMPTY_DICT = Dict() const EMPTY_VoVoSYMBOL = Vector{Symbol}[] +const EMPTY_VoVoVoSYMBOL = Vector{Symbol}[[]] -function Base.names(model::Model) +function _arguments(model::Model) vars = keys(get(model.structure, :variables, EMPTY_DICT)) vars = union(vars, keys(get(model.structure, :parameters, EMPTY_DICT))) - vars = union(vars, - map(first, get(model.structure, :components, EMPTY_VoVoSYMBOL))) + vars = union(vars, first(get(model.structure, :extend, EMPTY_VoVoVoSYMBOL))) collect(vars) end -function _parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars) - extend_args!(a, b, dict, expr, kwargs, varexpr) +function Base.names(model::Model) + collect(union(_arguments(model), + map(first, get(model.structure, :components, EMPTY_VoVoSYMBOL)))) +end + +function _parse_extend!(ext, a, b, dict, expr, kwargs, vars, additional_args) + extend_args!(a, b, dict, expr, kwargs) + b.args = [b.args[1]] + allvars = union(vars.args, additional_args.args) + for var in allvars + push!(b.args, Expr(:kw, var, var)) + if !haskey(dict[:kwargs], var) + push!(dict[:kwargs], var => Dict(:value => NO_VALUE)) + push!(kwargs, Expr(:kw, var, NO_VALUE)) + end + end ext[] = a push!(b.args, Expr(:kw, :name, Meta.quot(a))) push!(expr.args, :($a = $b)) @@ -780,8 +790,6 @@ end function parse_extend!(exprs, ext, dict, mod, body, kwargs) expr = Expr(:block) - varexpr = Expr(:block) - push!(exprs, varexpr) push!(exprs, expr) body = deepcopy(body) MLStyle.@match body begin @@ -792,7 +800,7 @@ function parse_extend!(exprs, ext, dict, mod, body, kwargs) error("`@extend` destructuring only takes an tuple as LHS. Got $body") end a, b = b.args - _parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars) + _parse_extend!(ext, a, b, dict, expr, kwargs, vars, Expr(:tuple)) else error("When explicitly destructing in `@extend` please use the syntax: `@extend a, b = oneport = OnePort()`.") end @@ -802,8 +810,11 @@ function parse_extend!(exprs, ext, dict, mod, body, kwargs) b = body if (model = getproperty(mod, b.args[1])) isa Model vars = Expr(:tuple) - append!(vars.args, names(model)) - _parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars) + append!(vars.args, _arguments(model)) + additional_args = Expr(:tuple) + append!(additional_args.args, + keys(get(model.structure, :structural_parameters, EMPTY_DICT))) + _parse_extend!(ext, a, b, dict, expr, kwargs, vars, additional_args) else error("Cannot infer the exact `Model` that `@extend $(body)` refers." * " Please specify the names that it brings into scope by:" * @@ -1104,7 +1115,7 @@ function parse_icon!(body::String, dict, icon, mod) icon_dir = get(ENV, "MTK_ICONS_DIR", joinpath(DEPOT_PATH[1], "mtk_icons")) dict[:icon] = icon[] = if isfile(body) URI("file:///" * abspath(body)) - elseif (iconpath = joinpath(icon_dir, body); isfile(iconpath)) + elseif (iconpath = abspath(joinpath(icon_dir, body)); isfile(iconpath)) URI("file:///" * abspath(iconpath)) elseif try Base.isvalid(URI(body)) @@ -1115,6 +1126,7 @@ function parse_icon!(body::String, dict, icon, mod) elseif (_body = lstrip(body); startswith(_body, r"<\?xml| Dict(:type => Real)) @test A.structure[:extend] == [[:e], :extended_e, :E] @test A.structure[:equations] == ["e ~ 0"] - @test A.structure[:kwargs] == - Dict{Symbol, Dict}(:p => Dict(:value => nothing, :type => Real), - :v => Dict(:value => nothing, :type => Real)) + @test A.structure[:kwargs] == Dict{Symbol, Dict}( + :p => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real), + :e => Dict(:value => ModelingToolkit.NoValue()), + :v => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real)) @test A.structure[:components] == [[:cc, :C]] end @@ -910,3 +911,25 @@ end end), false) end + +@mtkmodel BaseSys begin + @parameters begin + p1 + p2 + end + @variables begin + v1(t) + end +end + +@testset "Arguments of base system" begin + @mtkmodel MainSys begin + @extend BaseSys(p1 = 1) + end + + @test names(MainSys) == [:p2, :p1, :v1] + @named main_sys = MainSys(p1 = 11, p2 = 12, v1 = 13) + @test getdefault(main_sys.p1) == 11 + @test getdefault(main_sys.p2) == 12 + @test getdefault(main_sys.v1) == 13 +end