From 92e8469b94f5c2f4a38f6d5c43fd987c6a4accc7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 22 Sep 2024 13:03:46 -0400 Subject: [PATCH] fix: update GravitationalWaveform tutorial --- examples/GravitationalWaveForm/Project.toml | 8 ++++---- examples/GravitationalWaveForm/main.jl | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/GravitationalWaveForm/Project.toml b/examples/GravitationalWaveForm/Project.toml index 67e420f41..004c9f661 100644 --- a/examples/GravitationalWaveForm/Project.toml +++ b/examples/GravitationalWaveForm/Project.toml @@ -7,7 +7,7 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" -OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" @@ -18,7 +18,7 @@ ComponentArrays = "0.15" LineSearches = "7" Literate = "2" Lux = "1" -Optimization = "3" -OptimizationOptimJL = "0.3" -OrdinaryDiffEq = "6" +Optimization = "4" +OptimizationOptimJL = "0.4" +OrdinaryDiffEqLowOrderRK = "1" SciMLSensitivity = "7.57" diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index 56bbb2301..bbec37bc4 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -7,8 +7,8 @@ # ## Package Imports -using Lux, ComponentArrays, LineSearches, OrdinaryDiffEq, Optimization, OptimizationOptimJL, - Printf, Random, SciMLSensitivity +using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization, + OptimizationOptimJL, Printf, Random, SciMLSensitivity using CairoMakie # ## Define some Utility Functions @@ -221,16 +221,16 @@ end # We will deviate from the standard Neural Network initialization and use # `WeightInitializers.jl`, -const nn = Chain(Base.Fix1(broadcast, cos), - Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)), - Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)), - Dense(32 => 2; init_weight=truncated_normal(; std=1e-4))) -ps, st = Lux.setup(Xoshiro(), nn) +const nn = Chain(Base.Fix1(fast_activation, cos), + Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32), + Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32), + Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32)) +ps, st = Lux.setup(Random.default_rng(), nn) # Similar to most DL frameworks, Lux defaults to using `Float32`, however, in this case we # need Float64 -const params = ComponentArray{Float64}(ps) +const params = ComponentArray(ps |> f64) const nn_model = StatefulLuxLayer{true}(nn, nothing, st) @@ -293,7 +293,7 @@ const mseloss = MSELoss() function loss(θ) pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)) pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params)) - return mseloss(waveform, pred_waveform), pred_waveform + return mseloss(pred_waveform, waveform), pred_waveform end # Warmup the loss function