fix: update Optimization tutorial
avik-pal committed Sep 24, 2024
1 parent 5cb86b3 commit 8b717dc
Showing 2 changed files with 13 additions and 20 deletions.
10 changes: 4 additions & 6 deletions examples/OptimizationIntegration/Project.toml
@@ -1,7 +1,6 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
-IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -16,14 +15,13 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
[compat]
CairoMakie = "0.12.10"
ComponentArrays = "0.15.17"
-IterTools = "1.10"
Lux = "1"
LuxCUDA = "0.3.3"
MLUtils = "0.4.4"
-Optimization = "3.28.0"
-OptimizationOptimJL = "0.3.2"
-OptimizationOptimisers = "0.2.1"
-OrdinaryDiffEqTsit5 = "1.1.0"
+Optimization = "4"
+OptimizationOptimJL = "0.4"
+OptimizationOptimisers = "0.3.1"
+OrdinaryDiffEqTsit5 = "1.1"
Printf = "1.10"
Random = "1.10"
SciMLSensitivity = "7.67.0"
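The compat bumps above track the breaking Optimization.jl v4 release together with the matching OptimizationOptimJL and OptimizationOptimisers releases, and IterTools is dropped because the tutorial no longer cycles the dataloader by hand. A minimal, hedged sketch of re-resolving a local checkout of this example against the new bounds (the path is an assumption, relative to the repository root):

```julia
# Sketch only: refresh the example environment after this commit.
using Pkg
Pkg.activate("examples/OptimizationIntegration")  # assumed path to this example
Pkg.resolve()                                     # re-resolve against the new [compat] bounds
Pkg.instantiate()                                 # install Optimization v4 and the matching packages
```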
23 changes: 9 additions & 14 deletions examples/OptimizationIntegration/main.jl
@@ -17,7 +17,7 @@
# ## Imports packages

using Lux, Optimization, OptimizationOptimisers, OptimizationOptimJL, OrdinaryDiffEqTsit5,
-    SciMLSensitivity, Random, MLUtils, IterTools, CairoMakie, ComponentArrays, Printf
+    SciMLSensitivity, Random, MLUtils, CairoMakie, ComponentArrays, Printf
using LuxCUDA

const gdev = gpu_device()
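IterTools disappears from the import list because it was only needed for `ncycle`, which repeated the dataloader for a fixed number of epochs. With Optimization.jl v4 the dataloader is attached to the `OptimizationProblem` itself and the epoch count is passed to `solve` via `maxiters`, as the hunks below show. Side by side (old line reproduced from this diff, with the `nepcohs` typo written as `epochs`):

```julia
# Optimization.jl v3 style (removed in this commit): cycle the dataloader explicitly.
res_adam = solve(opt_prob, Optimisers.Adam(0.001), ncycle(dataloader, epochs); callback)

# Optimization.jl v4 style (added in this commit): the dataloader lives in the problem.
res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, maxiters=epochs)
```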
@@ -98,7 +98,7 @@ function train_model(dataloader)

    smodel = StatefulLuxLayer{true}(model, nothing, st)

-    function loss_adjoint(θ, u_batch, t_batch)
+    function loss_adjoint(θ, (u_batch, t_batch))
        t_batch = t_batch.t
        u0 = u_batch[:, 1]
        dudt(u, p, t) = smodel(u, p)
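The new `loss_adjoint` signature reflects how this commit wires up Optimization.jl v4: the objective receives the parameters plus whatever the data iterator attached to the problem yields, so the `(u_batch, t_batch)` tuple is destructured directly in the argument list. A self-contained toy restatement of that pattern, assuming the same v4 behaviour this commit relies on (`xs`, `ys`, and `loss` are hypothetical names, not part of the tutorial):

```julia
using Optimization, OptimizationOptimisers, MLUtils, Zygote

xs = rand(Float32, 4, 128)                      # toy features
ys = rand(Float32, 1, 128)                      # toy targets
dataloader = DataLoader((xs, ys); batchsize=32)

# The second argument is the batch yielded by the dataloader, destructured in place.
loss(θ, (x, y)) = sum(abs2, θ' * x .- y)

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(Float32, 4), dataloader)
sol = solve(prob, Optimisers.Adam(0.01); maxiters=25)
```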
@@ -110,26 +110,21 @@ function train_model(dataloader)
    ## Define the Optimization Function that takes in the optimization state (our parameters)
    ## and optimization parameters (nothing in our case) and data from the dataloader and
    ## returns the loss.
-    opt_func = OptimizationFunction(
-        (θ, _, u_batch, t_batch) -> loss_adjoint(θ, u_batch, t_batch),
-        Optimization.AutoZygote())
-    opt_prob = OptimizationProblem(opt_func, ps_ca)
+    opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
+    opt_prob = OptimizationProblem(opt_func, ps_ca, dataloader)

-    nepcohs = 25
-    res_adam = solve(
-        opt_prob, Optimisers.Adam(0.001), ncycle(dataloader, nepcohs); callback)
+    epochs = 25
+    res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, maxiters=epochs)

    ## Let's finetune a bit with L-BFGS
    opt_prob = remake(opt_prob; u0=res_adam.u)
-    res_lbfgs = solve(opt_prob, LBFGS(), ncycle(dataloader, nepcohs); callback)
+    res_lbfgs = solve(opt_prob, LBFGS(); callback, maxiters=epochs)

    ## Now that we have a good fit, let's train it on the entire dataset without
    ## Minibatching. We need to do this since ODE solves can lead to accumulated errors if
    ## the model was trained on individual parts (without a data-shooting approach).
-    opt_func = OptimizationFunction(
-        (θ, _) -> loss_adjoint(θ, gdev(ode_data), TimeWrapper(t)),
-        Optimization.AutoZygote())
-    opt_prob = OptimizationProblem(opt_func, res_lbfgs.u)
+    opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
+    opt_prob = OptimizationProblem(opt_func, res_lbfgs.u, (gdev(ode_data), TimeWrapper(t)))

    res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback)
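The last stage reuses `loss_adjoint` unchanged for a full-batch fit: instead of the dataloader, the whole trajectory is passed as a single `(data, time)` tuple, and each stage is warm-started from the previous solution (via `remake(...; u0=...)` or by passing the previous `u` as the new initial point). A hedged, standalone sketch of that Adam-then-LBFGS warm-start pattern on a toy objective (all names here are hypothetical):

```julia
using Optimization, OptimizationOptimisers, OptimizationOptimJL, Zygote

rosenbrock(θ, _) = (1 - θ[1])^2 + 100 * (θ[2] - θ[1]^2)^2

optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2))

# Stage 1: a first-order optimiser to get near the optimum.
res_adam = solve(prob, Optimisers.Adam(0.01); maxiters=1000)

# Stage 2: warm-start a quasi-Newton fine-tune from the Adam result.
prob = remake(prob; u0=res_adam.u)
res_lbfgs = solve(prob, LBFGS(); maxiters=100)
```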
