From b83a251a33434dd2ccdd6aea73a69959119720b1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 11 Oct 2024 14:39:56 +0530 Subject: [PATCH] fix: make `vars` search through arrays of symbolics --- src/utils.jl | 6 ++++++ test/variable_utils.jl | 10 +++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 29e61ede69..71b7bc3cd4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -385,6 +385,12 @@ function vars!(vars, O; op = Differential) if isvariable(O) return push!(vars, O) end + if symbolic_type(O) == NotSymbolic() && O isa AbstractArray + for arg in O + vars!(vars, arg; op) + end + return vars + end !iscall(O) && return vars operation(O) isa op && return push!(vars, O) diff --git a/test/variable_utils.jl b/test/variable_utils.jl index 8f3178f453..d76f2f1209 100644 --- a/test/variable_utils.jl +++ b/test/variable_utils.jl @@ -1,5 +1,5 @@ using ModelingToolkit, Test -using ModelingToolkit: value +using ModelingToolkit: value, vars using SymbolicUtils: <ₑ @parameters α β δ expr = (((1 / β - 1) + δ) / α)^(1 / (α - 1)) @@ -33,3 +33,11 @@ aov = ModelingToolkit.collect_applied_operators(eq, Differential) ts = collect_ivs([eq]) @test ts == Set([t]) + +@testset "vars searching through array of symbolics" begin + fn(x, y) = sum(x) + y + @register_symbolic fn(x::AbstractArray, y) + @variables x y z + res = vars(fn([x, y], z)) + @test length(res) == 3 +end