Skip to content

Commit

Permalink
Non-MM ndae updates (#426)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas authored Dec 21, 2020
1 parent bf6c588 commit f2ccfb9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 31 deletions.
19 changes: 10 additions & 9 deletions src/neural_de.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,20 +357,21 @@ end

function (n::NeuralDAE)(x,du0=n.du0,p=n.p)
function f(du,u,p,t)
nn_out = n.re(p)(u)
nn_out = n.re(p)(vcat(u,du))
alg_out = n.constraints_model(u,p,t)
v_out = []
for (j,i) in enumerate(n.differential_vars)
if i
push!(v_out,nn_out[j])
iter_nn = 0
iter_consts = 0
map(n.differential_vars) do isdiff
if isdiff
iter_nn += 1
nn_out[iter_nn]
else
push!(v_out,alg_out[j])
iter_consts += 1
alg_out[iter_consts]
end
end
return v_out
end
dudt_(du,u,p,t) = f
prob = DAEProblem(dudt_,du0,x,n.tspan,p,differential_vars=n.differential_vars)
prob = DAEProblem{false}(f,du0,x,n.tspan,p,differential_vars=n.differential_vars)
solve(prob,n.args...;sensalg=TrackerAdjoint(),n.kwargs...)
end

Expand Down
41 changes: 19 additions & 22 deletions test/neural_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,29 @@ using Flux, DiffEqFlux, OrdinaryDiffEq, GalacticOptim

#A desired MWE for now, not a test yet.

function f(du,u,p,t)
y₁,y₂,y₃ = u
k₁,k₂,k₃ = p
du[1] = -k₁*y₁ + k₃*y₂*y₃
du[2] = k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
du[3] = y₁ + y₂ + y₃ - 1
nothing
function rober(du,u,p,t)
y₁,y₂,y₃ = u
k₁,k₂,k₃ = p
du[1] = -k₁*y₁ + k₃*y₂*y₃
du[2] = k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
du[3] = y₁ + y₂ + y₃ - 1
nothing
end

u₀ = [1.0, 0, 0]
M = [1. 0 0
0 1. 0
0 0 0]
tspan = (0.0,10.0)
p = [0.04,3e7,1e4]
func = ODEFunction(f,mass_matrix=M)
prob = ODEProblem(f,u₀,tspan,(0.04,3e7,1e4))
sol = solve(prob,Rodas5())

0 1. 0
0 0 0]
prob_mm = ODEProblem(ODEFunction(rober,mass_matrix=M),[1.0,0.0,0.0],(0.0,10.0),(0.04,3e7,1e4))
sol = solve(prob_mm,Rodas5(),reltol=1e-8,abstol=1e-8)

dudt2 = Chain(x -> x.^3,Dense(3,50,tanh),Dense(50,2))

dudt2 = Chain(x -> x.^3,Dense(6,50,tanh),Dense(50,2))

ndae = NeuralDAE(dudt2, (u,p,t) -> [u[1] + u[2] + u[3] - 1], tspan, M, Rodas5())
ndae = NeuralDAE(dudt2, (u,p,t) -> [u[1] + u[2] + u[3] - 1], tspan, M, DImplicitEuler(),
differential_vars = [true,true,false])
truedu0 = similar(u₀)
f(truedu0,u₀,p,0.0)

ndae(u₀)
ndae(u₀,truedu0,Float64.(ndae.p))

function predict_n_dae(p)
ndae(u₀,p)
Expand All @@ -39,8 +36,8 @@ function loss(p)
loss,pred
end

p = p .+ rand(3) .* p
p = p .+ rand(3) .* p

optfunc = GalacticOptim.OptimizationFunction((x, p) -> loss(x), GalacticOptim.AutoZygote())
optprob = GalacticOptim.OptimizationProblem(optfunc, p)
res = GalacticOptim.solve(optprob, BFGS(initial_stepnorm = 0.0001))
res = GalacticOptim.solve(optprob, BFGS(initial_stepnorm = 0.0001))

0 comments on commit f2ccfb9

Please sign in to comment.