From 8281939de2d911c16f4d9814d72863a607336f97 Mon Sep 17 00:00:00 2001 From: Arno Strouwen Date: Sat, 24 Feb 2024 02:07:58 +0100 Subject: [PATCH] reapply formatter --- README.md | 6 ++++-- docs/pages.jl | 4 ++-- docs/src/index.md | 6 ++++-- docs/src/tutorials/basic_mnist_deq.md | 5 +++-- docs/src/tutorials/reduced_dim_deq.md | 2 +- src/DeepEquilibriumNetworks.jl | 8 ++++---- src/layers.jl | 18 ++++++++++------- test/layers.jl | 28 +++++++++++++++------------ 8 files changed, 45 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 4af5cc48..3587db8c 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,10 @@ rng = Random.default_rng() Random.seed!(rng, seed) model = Chain(Dense(2 => 2), - DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false), - Dense(2 => 2; use_bias=false)), NewtonRaphson())) + DeepEquilibriumNetwork( + Parallel(+, Dense(2 => 2; use_bias=false), + Dense(2 => 2; use_bias=false)), + NewtonRaphson())) gdev = gpu_device() cdev = cpu_device() diff --git a/docs/pages.jl b/docs/pages.jl index 195f04e3..ac42f48d 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -2,8 +2,8 @@ pages = [ "Home" => "index.md", "Tutorials" => [ "tutorials/basic_mnist_deq.md", - "tutorials/reduced_dim_deq.md", + "tutorials/reduced_dim_deq.md" ], "API References" => "api.md", - "References" => "references.md", + "References" => "references.md" ] diff --git a/docs/src/index.md b/docs/src/index.md index ad042508..0cb693d7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -25,8 +25,10 @@ rng = Random.default_rng() Random.seed!(rng, seed) model = Chain(Dense(2 => 2), - DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false), - Dense(2 => 2; use_bias=false)), NewtonRaphson())) + DeepEquilibriumNetwork( + Parallel(+, Dense(2 => 2; use_bias=false), + Dense(2 => 2; use_bias=false)), + NewtonRaphson())) gdev = gpu_device() cdev = cpu_device() diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md index 994cb21f..9e1085a8 100644 --- a/docs/src/tutorials/basic_mnist_deq.md +++ b/docs/src/tutorials/basic_mnist_deq.md @@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack ```@example basic_mnist_deq using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq, - Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras + Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras using MLDatasets: MNIST using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview @@ -65,7 +65,8 @@ function construct_model(solver; model_type::Symbol=:deq) Conv((4, 4), 64 => 64; stride=2, pad=1)) # The input layer of the DEQ - deq_model = Chain(Parallel(+, + deq_model = Chain( + Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()), Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())), Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())) diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md index 20242047..0b00b9e1 100644 --- a/docs/src/tutorials/reduced_dim_deq.md +++ b/docs/src/tutorials/reduced_dim_deq.md @@ -6,7 +6,7 @@ same MNIST example as before, but this time we will use a reduced state size. ```@example reduced_dim_mnist using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq, - Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras + Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras using MLDatasets: MNIST using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview diff --git a/src/DeepEquilibriumNetworks.jl b/src/DeepEquilibriumNetworks.jl index 8a1a2eaa..c7fedef5 100644 --- a/src/DeepEquilibriumNetworks.jl +++ b/src/DeepEquilibriumNetworks.jl @@ -4,14 +4,14 @@ import PrecompileTools: @recompile_invalidations @recompile_invalidations begin using ADTypes, DiffEqBase, FastClosures, LinearAlgebra, Lux, Random, SciMLBase, - Statistics, SteadyStateDiffEq + Statistics, SteadyStateDiffEq import ChainRulesCore as CRC import ConcreteStructs: @concrete import ConstructionBase: constructorof import Lux: AbstractExplicitLayer, AbstractExplicitContainerLayer import SciMLBase: AbstractNonlinearAlgorithm, - AbstractODEAlgorithm, _unwrap_val, NonlinearSolution + AbstractODEAlgorithm, _unwrap_val, NonlinearSolution import TruncatedStacktraces: @truncate_stacktrace end @@ -23,7 +23,7 @@ include("utils.jl") # Exports export DEQs, DeepEquilibriumSolution, DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork, - MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork, - MultiScaleNeuralODE + MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork, + MultiScaleNeuralODE end diff --git a/src/layers.jl b/src/layers.jl index 730f2f42..935a3fa2 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -152,8 +152,10 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing] ```julia julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq -julia> model = DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; use_bias=false), - Dense(2, 2; use_bias=false)), VCABM3(); verbose=false) +julia> model = DeepEquilibriumNetwork( + Parallel(+, Dense(2, 2; use_bias=false), + Dense(2, 2; use_bias=false)), + VCABM3(); verbose=false) DeepEquilibriumNetwork( model = Parallel( + @@ -233,15 +235,17 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref). ```julia julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve -julia> main_layers = (Parallel(+, Dense(4 => 4, tanh; use_bias=false), - Dense(4 => 4, tanh; use_bias=false)), Dense(3 => 3, tanh), Dense(2 => 2, tanh), +julia> main_layers = ( + Parallel(+, Dense(4 => 4, tanh; use_bias=false), + Dense(4 => 4, tanh; use_bias=false)), + Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh)) (Parallel(), Dense(3 => 3, tanh_fast), Dense(2 => 2, tanh_fast), Dense(1 => 1, tanh_fast)) julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh); - Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh); - Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh); - Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()] + Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh); + Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh); + Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()] 4×4 Matrix{LuxCore.AbstractExplicitLayer}: NoOpLayer() … Dense(4 => 1, tanh_fast) Dense(3 => 4, tanh_fast) Dense(3 => 1, tanh_fast) diff --git a/test/layers.jl b/test/layers.jl index 7d8f788c..24dcf798 100644 --- a/test/layers.jl +++ b/test/layers.jl @@ -1,5 +1,5 @@ using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq, - SciMLSensitivity, SciMLBase, Test + SciMLSensitivity, SciMLBase, Test include("test_utils.jl") @@ -16,7 +16,7 @@ end base_models = [ Parallel(+, __get_dense_layer(2 => 2), __get_dense_layer(2 => 2)), - Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1)), + Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1)) ] init_models = [__get_dense_layer(2 => 2), __get_conv_layer((1, 1), 1 => 1)] x_sizes = [(2, 14), (3, 3, 1, 3)] @@ -31,7 +31,8 @@ end @testset "Solver: $(__nameof(solver))" for solver in solvers, mtype in model_type, jacobian_regularization in jacobian_regularizations - @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(base_models, + @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip( + base_models, init_models, x_sizes) model = if mtype === :deq DeepEquilibriumNetwork(base_model, solver; jacobian_regularization) @@ -86,20 +87,20 @@ end main_layers = [ (Parallel(+, __get_dense_layer(4 => 4), __get_dense_layer(4 => 4)), - __get_dense_layer(3 => 3), __get_dense_layer(2 => 2), - __get_dense_layer(1 => 1)), + __get_dense_layer(3 => 3), __get_dense_layer(2 => 2), + __get_dense_layer(1 => 1)) ] mapping_layers = [ [NoOpLayer() __get_dense_layer(4 => 3) __get_dense_layer(4 => 2) __get_dense_layer(4 => 1); - __get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1); - __get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1); - __get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()], + __get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1); + __get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1); + __get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()] ] init_layers = [ (__get_dense_layer(4 => 4), __get_dense_layer(4 => 3), __get_dense_layer(4 => 2), - __get_dense_layer(4 => 1)), + __get_dense_layer(4 => 1)) ] x_sizes = [(4, 3)] @@ -113,16 +114,19 @@ end for mtype in model_type, jacobian_regularization in jacobian_regularizations @testset "Solver: $(__nameof(solver))" for solver in solvers - @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(main_layers, + @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip( + main_layers, mapping_layers, init_layers, x_sizes, scales) model = if mtype === :deq MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing, solver, scale; jacobian_regularization) elseif mtype === :skipdeq - MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing, + MultiScaleSkipDeepEquilibriumNetwork( + main_layer, mapping_layer, nothing, init_layer, solver, scale; jacobian_regularization) elseif mtype === :skipregdeq - MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing, + MultiScaleSkipDeepEquilibriumNetwork( + main_layer, mapping_layer, nothing, solver, scale; jacobian_regularization) elseif mtype === :node solver isa SciMLBase.AbstractODEAlgorithm || continue