diff --git a/Project.toml b/Project.toml index d21d941c..7b480d9a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ClimaTimeSteppers" uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79" authors = ["Climate Modeling Alliance"] -version = "0.7.8" +version = "0.7.9" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index 302d1d08..a54cfb68 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -45,11 +45,13 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unco return IMEXARKCache(U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) end -function step_u!(integrator, cache::IMEXARKCache) +step_u!(integrator, cache::IMEXARKCache) = step_u!(integrator, cache, integrator.alg.name) + +function step_u!(integrator, cache::IMEXARKCache, name) (; u, p, t, dt, sol, alg) = integrator (; f) = sol.prob (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f - (; name, tableau, newtons_method) = alg + (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau (; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache s = length(b_exp) diff --git a/src/solvers/imex_ssprk.jl b/src/solvers/imex_ssprk.jl index 1b52bcd2..ec581dc4 100644 --- a/src/solvers/imex_ssprk.jl +++ b/src/solvers/imex_ssprk.jl @@ -52,11 +52,13 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{SSP} return IMEXSSPRKCache(U, U_exp, U_lim, T_lim, T_exp, T_imp, temp, β, γ, newtons_method_cache) end -function step_u!(integrator, cache::IMEXSSPRKCache) +step_u!(integrator, cache::IMEXSSPRKCache) = step_u!(integrator, cache, integrator.alg.name) + +function step_u!(integrator, cache::IMEXSSPRKCache, name) (; u, p, t, dt, sol, alg) = integrator (; f) = sol.prob (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f - (; name, tableau, newtons_method) = alg + (; tableau, newtons_method) = alg (; a_imp, b_imp, c_exp, c_imp) = tableau (; U, U_lim, U_exp, T_lim, T_exp, T_imp, temp, β, γ, newtons_method_cache) = cache s = length(b_imp)