reapply formatter
ArnoStrouwen committed Feb 24, 2024
1 parent 7e42d0e commit 8281939
Showing 8 changed files with 45 additions and 32 deletions.
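This commit only reapplies the project's automated formatter, so every hunk below is a whitespace, line-break, or trailing-comma change. For reference, a formatting pass like this is typically rerun from the repository root with JuliaFormatter; the sketch below assumes the SciML style and a `.JuliaFormatter.toml` at the root, neither of which is shown in this commit.

```julia
# Hypothetical re-run of the formatter from the repository root (not part of the commit).
using JuliaFormatter

# If the repository ships a .JuliaFormatter.toml (assumed here), this picks it up:
format(".")

# Equivalent explicit form, selecting the SciML style directly:
format(".", SciMLStyle())
```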
6 changes: 4 additions & 2 deletions README.md
@@ -33,8 +33,10 @@ rng = Random.default_rng()
Random.seed!(rng, seed)

model = Chain(Dense(2 => 2),
-    DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false),
-        Dense(2 => 2; use_bias=false)), NewtonRaphson()))
+    DeepEquilibriumNetwork(
+        Parallel(+, Dense(2 => 2; use_bias=false),
+            Dense(2 => 2; use_bias=false)),
+        NewtonRaphson()))

gdev = gpu_device()
cdev = cpu_device()
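For context, the reformatted README snippet above only defines the model; a minimal end-to-end sketch of how it could be exercised with the usual Lux workflow follows. The seed, input shape, and CPU-only forward pass are illustrative assumptions, not part of this commit.

```julia
using DeepEquilibriumNetworks, Lux, NonlinearSolve, Random

rng = Random.default_rng()
Random.seed!(rng, 0)  # illustrative seed, not taken from the README

# Model definition exactly as formatted by this commit
model = Chain(Dense(2 => 2),
    DeepEquilibriumNetwork(
        Parallel(+, Dense(2 => 2; use_bias=false),
            Dense(2 => 2; use_bias=false)),
        NewtonRaphson()))

ps, st = Lux.setup(rng, model)   # initialise parameters and states
x = randn(rng, Float32, 2, 3)    # 2 features, batch of 3 (assumed shape)
y, st_ = model(x, ps, st)        # forward pass returns output and updated state
```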
4 changes: 2 additions & 2 deletions docs/pages.jl
@@ -2,8 +2,8 @@ pages = [
"Home" => "index.md",
"Tutorials" => [
"tutorials/basic_mnist_deq.md",
"tutorials/reduced_dim_deq.md",
"tutorials/reduced_dim_deq.md"
],
"API References" => "api.md",
"References" => "references.md",
"References" => "references.md"
]
6 changes: 4 additions & 2 deletions docs/src/index.md
@@ -25,8 +25,10 @@ rng = Random.default_rng()
Random.seed!(rng, seed)
model = Chain(Dense(2 => 2),
-    DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false),
-        Dense(2 => 2; use_bias=false)), NewtonRaphson()))
+    DeepEquilibriumNetwork(
+        Parallel(+, Dense(2 => 2; use_bias=false),
+            Dense(2 => 2; use_bias=false)),
+        NewtonRaphson()))
gdev = gpu_device()
cdev = cpu_device()
5 changes: 3 additions & 2 deletions docs/src/tutorials/basic_mnist_deq.md
@@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack

```@example basic_mnist_deq
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
-    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
+      Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
using MLDatasets: MNIST
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
@@ -65,7 +65,8 @@ function construct_model(solver; model_type::Symbol=:deq)
        Conv((4, 4), 64 => 64; stride=2, pad=1))
    # The input layer of the DEQ
-    deq_model = Chain(Parallel(+,
+    deq_model = Chain(
+        Parallel(+,
            Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
            Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
        Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))
2 changes: 1 addition & 1 deletion docs/src/tutorials/reduced_dim_deq.md
@@ -6,7 +6,7 @@ same MNIST example as before, but this time we will use a reduced state size.

```@example reduced_dim_mnist
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
-    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
+      Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
using MLDatasets: MNIST
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
8 changes: 4 additions & 4 deletions src/DeepEquilibriumNetworks.jl
@@ -4,14 +4,14 @@ import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
    using ADTypes, DiffEqBase, FastClosures, LinearAlgebra, Lux, Random, SciMLBase,
-        Statistics, SteadyStateDiffEq
+          Statistics, SteadyStateDiffEq

    import ChainRulesCore as CRC
    import ConcreteStructs: @concrete
    import ConstructionBase: constructorof
    import Lux: AbstractExplicitLayer, AbstractExplicitContainerLayer
    import SciMLBase: AbstractNonlinearAlgorithm,
-        AbstractODEAlgorithm, _unwrap_val, NonlinearSolution
+                      AbstractODEAlgorithm, _unwrap_val, NonlinearSolution
    import TruncatedStacktraces: @truncate_stacktrace
end

@@ -23,7 +23,7 @@ include("utils.jl")

# Exports
export DEQs, DeepEquilibriumSolution, DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork,
-    MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,
-    MultiScaleNeuralODE
+       MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,
+       MultiScaleNeuralODE

end
18 changes: 11 additions & 7 deletions src/layers.jl
@@ -152,8 +152,10 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
```julia
julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
-julia> model = DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; use_bias=false),
-           Dense(2, 2; use_bias=false)), VCABM3(); verbose=false)
+julia> model = DeepEquilibriumNetwork(
+           Parallel(+, Dense(2, 2; use_bias=false),
+               Dense(2, 2; use_bias=false)),
+           VCABM3(); verbose=false)
DeepEquilibriumNetwork(
model = Parallel(
+
@@ -233,15 +235,17 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref).
```julia
julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve
-julia> main_layers = (Parallel(+, Dense(4 => 4, tanh; use_bias=false),
-           Dense(4 => 4, tanh; use_bias=false)), Dense(3 => 3, tanh), Dense(2 => 2, tanh),
+julia> main_layers = (
+           Parallel(+, Dense(4 => 4, tanh; use_bias=false),
+               Dense(4 => 4, tanh; use_bias=false)),
+           Dense(3 => 3, tanh), Dense(2 => 2, tanh),
           Dense(1 => 1, tanh))
(Parallel(), Dense(3 => 3, tanh_fast), Dense(2 => 2, tanh_fast), Dense(1 => 1, tanh_fast))
julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
-    Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
-    Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
-    Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()]
+           Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
+           Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
+           Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()]
4×4 Matrix{LuxCore.AbstractExplicitLayer}:
NoOpLayer() … Dense(4 => 1, tanh_fast)
Dense(3 => 4, tanh_fast) Dense(3 => 1, tanh_fast)
28 changes: 16 additions & 12 deletions test/layers.jl
@@ -1,5 +1,5 @@
using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq,
-    SciMLSensitivity, SciMLBase, Test
+      SciMLSensitivity, SciMLBase, Test

include("test_utils.jl")

@@ -16,7 +16,7 @@ end

base_models = [
    Parallel(+, __get_dense_layer(2 => 2), __get_dense_layer(2 => 2)),
-    Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1)),
+    Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1))
]
init_models = [__get_dense_layer(2 => 2), __get_conv_layer((1, 1), 1 => 1)]
x_sizes = [(2, 14), (3, 3, 1, 3)]
@@ -31,7 +31,8 @@ end
@testset "Solver: $(__nameof(solver))" for solver in solvers,
mtype in model_type, jacobian_regularization in jacobian_regularizations

@testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(base_models,
@testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(
base_models,
init_models, x_sizes)
model = if mtype === :deq
DeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
@@ -86,20 +87,20 @@ end

main_layers = [
    (Parallel(+, __get_dense_layer(4 => 4), __get_dense_layer(4 => 4)),
-        __get_dense_layer(3 => 3), __get_dense_layer(2 => 2),
-        __get_dense_layer(1 => 1)),
+        __get_dense_layer(3 => 3), __get_dense_layer(2 => 2),
+        __get_dense_layer(1 => 1))
]

mapping_layers = [
    [NoOpLayer() __get_dense_layer(4 => 3) __get_dense_layer(4 => 2) __get_dense_layer(4 => 1);
-        __get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1);
-        __get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1);
-        __get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()],
+     __get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1);
+     __get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1);
+     __get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()]
]

init_layers = [
    (__get_dense_layer(4 => 4), __get_dense_layer(4 => 3), __get_dense_layer(4 => 2),
-        __get_dense_layer(4 => 1)),
+        __get_dense_layer(4 => 1))
]

x_sizes = [(4, 3)]
@@ -113,16 +114,19 @@

for mtype in model_type, jacobian_regularization in jacobian_regularizations
    @testset "Solver: $(__nameof(solver))" for solver in solvers
-        @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(main_layers,
+        @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(
+            main_layers,
            mapping_layers, init_layers, x_sizes, scales)
            model = if mtype === :deq
                MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
                    solver, scale; jacobian_regularization)
            elseif mtype === :skipdeq
-                MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
+                MultiScaleSkipDeepEquilibriumNetwork(
+                    main_layer, mapping_layer, nothing,
                    init_layer, solver, scale; jacobian_regularization)
            elseif mtype === :skipregdeq
-                MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
+                MultiScaleSkipDeepEquilibriumNetwork(
+                    main_layer, mapping_layer, nothing,
                    solver, scale; jacobian_regularization)
            elseif mtype === :node
                solver isa SciMLBase.AbstractODEAlgorithm || continue
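Since the commit is formatting-only, the restructured test sets above should behave identically; a quick local check is to run the package test suite. A minimal sketch using the standard Pkg workflow follows — it assumes a local checkout and that test/runtests.jl includes test/layers.jl.

```julia
# Hypothetical local test run from the repository root (not part of the commit).
using Pkg
Pkg.activate(".")    # activate the package environment at the repo root
Pkg.instantiate()    # install the package dependencies
Pkg.test()           # runs test/runtests.jl (assumed to include test/layers.jl)
```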
