diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 367257c3c..83c26db8f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -14,6 +14,7 @@ jobs: test: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: group: - Core @@ -47,4 +48,4 @@ jobs: with: file: lcov.info token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true + fail_ci_if_error: false diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index aa6abbdf2..4f5ff06b8 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -7,8 +7,10 @@ using SciMLBase using SciMLBase: ODESolution, remake, getobserved, build_solution, EnsembleSolution, NonlinearSolution, AbstractTimeseriesSolution -using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index +using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, + observed, parameter_values, state_values, current_time using RecursiveArrayTools +import SciMLStructures # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt @@ -109,7 +111,18 @@ end @adjoint function Base.getindex(VA::ODESolution, sym) function ODESolution_getindex_pullback(Δ) i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym - if i === nothing + if is_observed(VA, sym) + f = observed(VA, sym) + p = parameter_values(VA) + tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p) + u = state_values(VA) + t = current_time(VA) + y, back = Zygote.pullback(u, tunables) do u, tunables + f.(u, Ref(tunables), t) + end + gs = back(Δ) + (gs[1], nothing) + elseif i === nothing throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) else Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] @@ -120,26 +133,49 @@ end VA[sym], ODESolution_getindex_pullback end +function obs_grads(VA, sym, obs_idx, Δ) + y, back = Zygote.pullback(VA) do sol + getindex.(Ref(sol), sym[obs_idx]) + end + Δreduced = reduce(hcat, Δ) + Δobs = eachrow(Δreduced[obs_idx, :]) + back(Δobs) +end + +function obs_grads(VA, sym, ::Nothing, Δ) + Zygote.nt_nothing(VA) +end + +function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T} + Δ′ = map(enumerate(VA.u)) do (t_idx, us) + map(enumerate(us)) do (u_idx, u) + if u_idx in i + idx = findfirst(isequal(u_idx), i) + Δ[t_idx][idx] + else + zero(T) + end + end + end + + Δ′ +end + @adjoint function Base.getindex( VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T} function ODESolution_getindex_pullback(Δ) sym = sym isa Tuple ? collect(sym) : sym i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) - if i === nothing - throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) - else - Δ′ = map(enumerate(VA.u)) do (t_idx, us) - map(enumerate(us)) do (u_idx, u) - if u_idx in i - idx = findfirst(isequal(u_idx), i) - Δ[t_idx][idx] - else - zero(T) - end - end - end - (Δ′, nothing) - end + + obs_idx = findall(s -> is_observed(VA, s), sym) + not_obs_idx = setdiff(1:length(sym), obs_idx) + + gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ) + gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ) + + a = Zygote.accum(gs_obs[1], gs_not_obs) + + (a, nothing) end VA[sym], ODESolution_getindex_pullback end diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index ae9d41b80..11ce0ff96 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -3,6 +3,7 @@ BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1" @@ -22,6 +23,7 @@ BoundaryValueDiffEq = "5" ForwardDiff = "0.10" JumpProcesses = "9.10" ModelingToolkit = "8.37, 9" +ModelingToolkitStandardLibrary = "2.7" NonlinearSolve = "2, 3" Optimization = "3" OptimizationMOI = "0.4" diff --git a/test/downstream/observables_autodiff.jl b/test/downstream/observables_autodiff.jl new file mode 100644 index 000000000..974afa150 --- /dev/null +++ b/test/downstream/observables_autodiff.jl @@ -0,0 +1,102 @@ +using ModelingToolkit, OrdinaryDiffEq +using Zygote +using ModelingToolkit: t_nounits as t, D_nounits as D +import SymbolicIndexingInterface as SII +import SciMLStructures as SS +using ModelingToolkitStandardLibrary +import ModelingToolkitStandardLibrary as MSL + +@parameters σ ρ β +@variables x(t) y(t) z(t) w(t) + +eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z, + w ~ x + y + z + 2 * β] + +@mtkbuild sys = ODESystem(eqs, t) + +u0 = [D(x) => 2.0, + x => 1.0, + y => 0.0, + z => 0.0] + +p = [σ => 28.0, + ρ => 10.0, + β => 8 / 3] + +tspan = (0.0, 100.0) +prob = ODEProblem(sys, u0, tspan, p, jac = true) +sol = solve(prob, Tsit5()) + +@testset "AutoDiff Observable Functions" begin + gs, = gradient(sol) do sol + sum(sol[sys.w]) + end + du_ = [0.0, 1.0, 1.0, 1.0] + du = [du_ for _ in sol.u] + @test du == gs + + # Observable in a vector + gs, = gradient(sol) do sol + sum(sum.(sol[[sys.w, sys.x]])) + end + du_ = [0.0, 1.0, 1.0, 2.0] + du = [du_ for _ in sol.u] + @test du == gs +end + +# DAE + +function create_model(; C₁ = 3e-5, C₂ = 1e-6) + @variables t + @named resistor1 = MSL.Electrical.Resistor(R = 5.0) + @named resistor2 = MSL.Electrical.Resistor(R = 2.0) + @named capacitor1 = MSL.Electrical.Capacitor(C = C₁) + @named capacitor2 = MSL.Electrical.Capacitor(C = C₂) + @named source = MSL.Electrical.Voltage() + @named input_signal = MSL.Blocks.Sine(frequency = 100.0) + @named ground = MSL.Electrical.Ground() + @named ampermeter = MSL.Electrical.CurrentSensor() + + eqs = [connect(input_signal.output, source.V) + connect(source.p, capacitor1.n, capacitor2.n) + connect(source.n, resistor1.p, resistor2.p, ground.g) + connect(resistor1.n, capacitor1.p, ampermeter.n) + connect(resistor2.n, capacitor2.p, ampermeter.p)] + + @named circuit_model = ODESystem(eqs, t, + systems = [ + resistor1, resistor2, capacitor1, capacitor2, + source, input_signal, ground, ampermeter + ]) +end + +@testset "DAE Observable function AD" begin + model = create_model() + sys = structural_simplify(model) + + prob = ODEProblem(sys, [], (0.0, 1.0)) + sol = solve(prob, Rodas4()) + + gs, = gradient(sol) do sol + sum(sol[sys.ampermeter.i]) + end + du_ = [0.2, 1.0] + du = [du_ for _ in sol.u] + @test gs == du +end + +# @testset "Adjoints with DAE" begin +# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables +# new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p, new_tunables) +# new_prob = remake(prob, p = new_p) +# sol = solve(new_prob, Rodas4()) +# @show size(sol) +# # mean(abs.(sol[sys.ampermeter.i] .- gt)) +# sum(sol[sys.ampermeter.i]) +# end +# +# @test isnothing(gs_mtkp) +# @test length(gs_p_new) == length(p_new) +# end diff --git a/test/runtests.jl b/test/runtests.jl index 5e0bed4c7..f5b3c2d22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -110,6 +110,9 @@ end @time @safetestset "Partial Functions" begin include("downstream/partial_functions.jl") end + @time @safetestset "Autodiff Observable Functions" begin + include("downstream/observables_autodiff.jl") + end end if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface")