Skip to content

Commit

Permalink
Merge pull request #689 from DhairyaLGandhi/dg/obsfn
Browse files Browse the repository at this point in the history
Feat: adjoints through observable functions
  • Loading branch information
ChrisRackauckas authored May 25, 2024
2 parents 5b172dc + f817b52 commit 3811745
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 18 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
group:
- Core
Expand Down Expand Up @@ -47,4 +48,4 @@ jobs:
with:
file: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
fail_ci_if_error: false
70 changes: 53 additions & 17 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ using SciMLBase
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
observed, parameter_values, state_values, current_time
using RecursiveArrayTools
import SciMLStructures

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -109,7 +111,18 @@ end
@adjoint function Base.getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
if is_observed(VA, sym)
f = observed(VA, sym)
p = parameter_values(VA)
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
u = state_values(VA)
t = current_time(VA)
y, back = Zygote.pullback(u, tunables) do u, tunables
f.(u, Ref(tunables), t)
end
gs = back(Δ)
(gs[1], nothing)
elseif i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
Expand All @@ -120,26 +133,49 @@ end
VA[sym], ODESolution_getindex_pullback
end

function obs_grads(VA, sym, obs_idx, Δ)
y, back = Zygote.pullback(VA) do sol
getindex.(Ref(sol), sym[obs_idx])
end
Δreduced = reduce(hcat, Δ)
Δobs = eachrow(Δreduced[obs_idx, :])
back(Δobs)
end

function obs_grads(VA, sym, ::Nothing, Δ)
Zygote.nt_nothing(VA)
end

function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T}
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
map(enumerate(us)) do (u_idx, u)
if u_idx in i
idx = findfirst(isequal(u_idx), i)
Δ[t_idx][idx]
else
zero(T)
end
end
end

Δ′
end

@adjoint function Base.getindex(
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
function ODESolution_getindex_pullback(Δ)
sym = sym isa Tuple ? collect(sym) : sym
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
map(enumerate(us)) do (u_idx, u)
if u_idx in i
idx = findfirst(isequal(u_idx), i)
Δ[t_idx][idx]
else
zero(T)
end
end
end
(Δ′, nothing)
end

obs_idx = findall(s -> is_observed(VA, s), sym)
not_obs_idx = setdiff(1:length(sym), obs_idx)

gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ)
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)

a = Zygote.accum(gs_obs[1], gs_not_obs)

(a, nothing)
end
VA[sym], ODESolution_getindex_pullback
end
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
Expand All @@ -22,6 +23,7 @@ BoundaryValueDiffEq = "5"
ForwardDiff = "0.10"
JumpProcesses = "9.10"
ModelingToolkit = "8.37, 9"
ModelingToolkitStandardLibrary = "2.7"
NonlinearSolve = "2, 3"
Optimization = "3"
OptimizationMOI = "0.4"
Expand Down
102 changes: 102 additions & 0 deletions test/downstream/observables_autodiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using ModelingToolkit, OrdinaryDiffEq
using Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
import SymbolicIndexingInterface as SII
import SciMLStructures as SS
using ModelingToolkitStandardLibrary
import ModelingToolkitStandardLibrary as MSL

@parameters σ ρ β
@variables x(t) y(t) z(t) w(t)

eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x *- z) - y,
D(z) ~ x * y - β * z,
w ~ x + y + z + 2 * β]

@mtkbuild sys = ODESystem(eqs, t)

u0 = [D(x) => 2.0,
x => 1.0,
y => 0.0,
z => 0.0]

p ==> 28.0,
ρ => 10.0,
β => 8 / 3]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob, Tsit5())

@testset "AutoDiff Observable Functions" begin
gs, = gradient(sol) do sol
sum(sol[sys.w])
end
du_ = [0.0, 1.0, 1.0, 1.0]
du = [du_ for _ in sol.u]
@test du == gs

# Observable in a vector
gs, = gradient(sol) do sol
sum(sum.(sol[[sys.w, sys.x]]))
end
du_ = [0.0, 1.0, 1.0, 2.0]
du = [du_ for _ in sol.u]
@test du == gs
end

# DAE

function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@variables t
@named resistor1 = MSL.Electrical.Resistor(R = 5.0)
@named resistor2 = MSL.Electrical.Resistor(R = 2.0)
@named capacitor1 = MSL.Electrical.Capacitor(C = C₁)
@named capacitor2 = MSL.Electrical.Capacitor(C = C₂)
@named source = MSL.Electrical.Voltage()
@named input_signal = MSL.Blocks.Sine(frequency = 100.0)
@named ground = MSL.Electrical.Ground()
@named ampermeter = MSL.Electrical.CurrentSensor()

eqs = [connect(input_signal.output, source.V)
connect(source.p, capacitor1.n, capacitor2.n)
connect(source.n, resistor1.p, resistor2.p, ground.g)
connect(resistor1.n, capacitor1.p, ampermeter.n)
connect(resistor2.n, capacitor2.p, ampermeter.p)]

@named circuit_model = ODESystem(eqs, t,
systems = [
resistor1, resistor2, capacitor1, capacitor2,
source, input_signal, ground, ampermeter
])
end

@testset "DAE Observable function AD" begin
model = create_model()
sys = structural_simplify(model)

prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Rodas4())

gs, = gradient(sol) do sol
sum(sol[sys.ampermeter.i])
end
du_ = [0.2, 1.0]
du = [du_ for _ in sol.u]
@test gs == du
end

# @testset "Adjoints with DAE" begin
# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables
# new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p, new_tunables)
# new_prob = remake(prob, p = new_p)
# sol = solve(new_prob, Rodas4())
# @show size(sol)
# # mean(abs.(sol[sys.ampermeter.i] .- gt))
# sum(sol[sys.ampermeter.i])
# end
#
# @test isnothing(gs_mtkp)
# @test length(gs_p_new) == length(p_new)
# end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ end
@time @safetestset "Partial Functions" begin
include("downstream/partial_functions.jl")
end
@time @safetestset "Autodiff Observable Functions" begin
include("downstream/observables_autodiff.jl")
end
end

if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface")
Expand Down

0 comments on commit 3811745

Please sign in to comment.