diff --git a/examples/OptimizationIntegration/Project.toml b/examples/OptimizationIntegration/Project.toml
index 9a4fdfec0..11691aa17 100644
--- a/examples/OptimizationIntegration/Project.toml
+++ b/examples/OptimizationIntegration/Project.toml
@@ -1,7 +1,6 @@
 [deps]
 CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
-IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -16,14 +15,13 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 [compat]
 CairoMakie = "0.12.10"
 ComponentArrays = "0.15.17"
-IterTools = "1.10"
 Lux = "1"
 LuxCUDA = "0.3.3"
 MLUtils = "0.4.4"
-Optimization = "3.28.0"
-OptimizationOptimJL = "0.3.2"
-OptimizationOptimisers = "0.2.1"
-OrdinaryDiffEqTsit5 = "1.1.0"
+Optimization = "4"
+OptimizationOptimJL = "0.4"
+OptimizationOptimisers = "0.3.1"
+OrdinaryDiffEqTsit5 = "1.1"
 Printf = "1.10"
 Random = "1.10"
 SciMLSensitivity = "7.67.0"
diff --git a/examples/OptimizationIntegration/main.jl b/examples/OptimizationIntegration/main.jl
index b2ba1c1bc..7c617348a 100644
--- a/examples/OptimizationIntegration/main.jl
+++ b/examples/OptimizationIntegration/main.jl
@@ -17,7 +17,7 @@
 # ## Imports packages

 using Lux, Optimization, OptimizationOptimisers, OptimizationOptimJL, OrdinaryDiffEqTsit5,
-    SciMLSensitivity, Random, MLUtils, IterTools, CairoMakie, ComponentArrays, Printf
+    SciMLSensitivity, Random, MLUtils, CairoMakie, ComponentArrays, Printf
 using LuxCUDA

 const gdev = gpu_device()
@@ -98,7 +98,7 @@ function train_model(dataloader)

     smodel = StatefulLuxLayer{true}(model, nothing, st)

-    function loss_adjoint(θ, u_batch, t_batch)
+    function loss_adjoint(θ, (u_batch, t_batch))
         t_batch = t_batch.t
         u0 = u_batch[:, 1]
         dudt(u, p, t) = smodel(u, p)
@@ -110,26 +110,21 @@ function train_model(dataloader)
     ## Define the Optimization Function that takes in the optimization state (our parameters)
     ## and optimization parameters (nothing in our case) and data from the dataloader and
     ## returns the loss.
-    opt_func = OptimizationFunction(
-        (θ, _, u_batch, t_batch) -> loss_adjoint(θ, u_batch, t_batch),
-        Optimization.AutoZygote())
-    opt_prob = OptimizationProblem(opt_func, ps_ca)
+    opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
+    opt_prob = OptimizationProblem(opt_func, ps_ca, dataloader)

-    nepcohs = 25
-    res_adam = solve(
-        opt_prob, Optimisers.Adam(0.001), ncycle(dataloader, nepcohs); callback)
+    epochs = 25
+    res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, maxiters=epochs)

     ## Let's finetune a bit with L-BFGS
     opt_prob = remake(opt_prob; u0=res_adam.u)
-    res_lbfgs = solve(opt_prob, LBFGS(), ncycle(dataloader, nepcohs); callback)
+    res_lbfgs = solve(opt_prob, LBFGS(); callback, maxiters=epochs)

     ## Now that we have a good fit, let's train it on the entire dataset without
     ## Minibatching. We need to do this since ODE solves can lead to accumulated errors if
     ## the model was trained on individual parts (without a data-shooting approach).
-    opt_func = OptimizationFunction(
-        (θ, _) -> loss_adjoint(θ, gdev(ode_data), TimeWrapper(t)),
-        Optimization.AutoZygote())
-    opt_prob = OptimizationProblem(opt_func, res_lbfgs.u)
+    opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
+    opt_prob = OptimizationProblem(opt_func, res_lbfgs.u, (gdev(ode_data), TimeWrapper(t)))

     res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback)
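
For context, and not part of the patch itself: below is a minimal sketch of the Optimization.jl v4 minibatching pattern this diff adopts, assuming a toy least-squares problem on random data. `dummy_loss`, the toy arrays, and the hyperparameters are illustrative placeholders, not anything taken from the example.

# Hedged sketch of the assumed Optimization.jl v4 data-handling API, mirroring the diff.
using Optimization, OptimizationOptimisers, MLUtils, Zygote

x = rand(Float32, 2, 128)                      # toy features
y = rand(Float32, 1, 128)                      # toy targets
dataloader = DataLoader((x, y); batchsize=32)  # yields (x_batch, y_batch) tuples

# v4 loss signature: (parameters, data_batch) -> loss. The batch is whatever the
# iterator yields, so it can be destructured in the argument list, as loss_adjoint does above.
dummy_loss(θ, (x_batch, y_batch)) = sum(abs2, θ' * x_batch .- y_batch) / size(x_batch, 2)

θ0 = randn(Float32, 2)

opt_func = OptimizationFunction(dummy_loss, Optimization.AutoZygote())
# The data iterator is passed as the third positional argument of OptimizationProblem,
# so IterTools.ncycle is no longer needed to repeat it across epochs.
opt_prob = OptimizationProblem(opt_func, θ0, dataloader)

# With a data iterator attached, maxiters plays the role of the epoch count,
# mirroring `maxiters=epochs` in the updated example.
res = solve(opt_prob, Optimisers.Adam(0.01); maxiters=10)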