Skip to content

Commit

Permalink
Use ReTestItems for parallel testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 24, 2024
1 parent 5e9b8e3 commit 00092b0
Show file tree
Hide file tree
Showing 25 changed files with 423 additions and 392 deletions.
3 changes: 2 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ whitespace_in_kwargs = false
format_docstrings = true
separate_kwargs_with_semicolon = true
format_markdown = true
annotate_untyped_fields_with_any = false
annotate_untyped_fields_with_any = false
join_lines_based_on_source = false
2 changes: 2 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,7 @@ steps:
timeout_in_minutes: 240

env:
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
SECRET_CODECOV_TOKEN: "fbSN+ZbScLIWr1FOpAu1Z8PYWFobqbLGFayOgZE1ebhE8LIH/PILGXUMcdm9gkXVSwgdETDD0s33k14lBkJ90O4dV9w6k79F/pEgzVHV8baMoXZG03BPMxztlcoRXrKtRtAp+MwoATc3Ldb9H5vqgAnVNn5rhn4Rp0Z6LOVRC43hbhKBBKYh/N4gqpIQlcW4dBXmELhlnMFnUILjwGRVgEt/zh8H+vmf0qiIulNIQ/rfGISROHqFML0QDL4icloiqX08J76ZP/gZCeg6rJ0gl3ok3IspNPz51rlbvijqsPNyIHWi29OrAtWX3qKHfrAOoGIrE1d5Oy4wx4XaN/YBhg==;U2FsdGVkX188gcRjkUNMEC2Z5fEFfhsYY4WJbhhINOuCUgqq9XNHVDbJhzFUFVQ+UiuPHFg7CW/gn+3IkSVyOA=="
SECRET_DOCUMENTER_KEY: "jzyAET5IdazYwPAEZAmYmnBALb2dC1GPizCDCdt8xpjIi4ce6QbGGJMKo00ZNzJ/A7ii4bhqysVPXniifFwIGl7x+GSCeavwcSr15pfxJSqPuQYLKxESzIo+SM+l2uJWUz8KYMJ1tSt/Z3Up3qQfLeQFtR+f43b9QrLfhgZGAAdxpwu5VHdI3Xm/gZo5d8xEJ1xs4gqVP0e2A5EFr/j/exaWJL9+AvgO+Gko8NaJGG5B89zP1W2NBlpjttbwzj2naBhDx8A43Qe4eXm+BZd9CIZImiEJnnqoGxLkAyLDksbA68getUHW5z3nGyhWTrg5yfRqq0uyZZGTIOFz6dJrRg==;U2FsdGVkX19QOxLLkdNoQf7Rid3mcSR/renIHQ+/X3o0WxTmU8KDDxzfKuWPeK1fxMon8y45HCJv3HlMuzyfvPWrOmUXccfHK272D8vHu1kk/qZZw8nPd7iYBU9+VAIxwfmI3Av2gC+8tUlOcuUTEVMtMbi/MiLHp+phLYcELKzzrxL8VdrLzna81M+8xVLu7zzNuyK0cUPWLxRHcZc/fewK5Nh7EQ2x8u1b6e5zR0/AcqjCzMayD1RiE7QhRVGdF5GJYnAxc1eoyCwIjXTRfFo0a0Q2h6DEz9FEat/ZCekIuWyVrUkGbpsRqXUTrSH0An7FRRqRlZ9lStRaQY4Z3XBkoIh94vQlXwwLUH20jC7yRTV73CeYmhfigQckHL0JsjjIENz04Ac346fCV6WNQtEak0m3pN/BucoiwRA8l+WU4AK1r84cwGSphKk4SnWRAqeZVuFHck7NkcmHDEkO4C7WTP400oui/5NDMtVZbtnZfLxVzQqijxXj7IflWqF1vKqGmW5aPFMVNeAqwNGu3xM4oIIeHRu0u+k2S5dp1wqRVlMxYXdPtcoFzE0CNsMQdWgsvPd2eet38YRc8ftXNjKzoUSRRCbjGbVr0iJXeNmPg3jfZoVdILHjCN/hcz4nY+61P11OlJAdfE/6HzEr4VoOS4CN+s/brjWycmAKZo2+1e4fSV1xBH7t1spOlESLvsBhZNtj9/zUKgWgMct5hnF4anQcPAeRpz/MBrkwX1gW3WOvCxaqVlRfgGSy6boPgRd3p/ZXN4Xnfeg9RFqKZn21d2gcrc3/1+PTUEkOIv+C9BGszo9IaUziW/Tz2mVP386kX86SF4fF4y3PofcUT2FLTm8Q9ZJBnslOsRP8bq3rIjDiQR3Iz3uGctkGZPs+GOtCR5OrhnnS6BXxkGwt/n9PJsnbXt0Z4tuXihC1B8KfP7mzDvZr3q9X/DGKyZ+oMHdDI+f2+lRwx42nJnsu+nZW9lyhdIwWla9F1rIoVz59HbUrmUhsVmFQYfjy7Nl18g8Wh5r9CkFL/vr6Zpy5lj1J/vhe1501X2FIkKOnLAM73GwtAa4GkbHyu5rNcij6YoozPrJWT4KRNFWGVAqNZ1atG8WwmziwIl2KfBn8jiuP/8o6rXQkmrAzBr6jVnto5FTWnIexEmnbELs20XDck8pO5WQxU1IR9YhKMbrDGbn0jWzVoRmCWpaJgV1AkWu09a++DxIec4+Zt+3SZLj/H57XsBchWHmkFz4NVTBeSans26VmdDd3LxprT8qeH6cioceakmu6yegsKQnJGLmSNyUkHqBqmsCcvyTUyaQUBTFkjLmDeZB3Ifu2kD7AFdx5n58wdJTMZxYviybOCgCV4qe95v5XfIqthp5mF/0F1Wt9ZcEreFSM2Paj5GrQ+M25cZ+kqOSlMet51Q+QBCfQyDF8jdu3j1hVniwpgMI1gqyb2alRfyNx52elTqRn9hPqpFptGH4uJXi8H72YPe4fYkFS7wwELeRIv+nKkNYNLPQAyQFvZ/qB/PRI1YoFBbpi0Vi6iE9xLRq7QVvhJde2EgNbvQk8uakwV630Tht2OuwVdJu/PIbXsQ5i+EuknIlPRdQdhbEIkpuBHFAzxBqA2K92gJ4bbcOjGtDHc0pt0RtvIVoyyJMkYVzr1yBeMWEmsL8qYJ5yzuAFGqpCTmJzXE0ETZLDDJtxwSKj5M2vG59wPNfo9DF+LgJLF+94VydYGNOHy9KuY2Oo3ejV7iFXUtsEV3Id9EkNGq8+t5KIAGk6lnDcM1TTOAc5W8fGGNhYzlqgWK1n3nwPJLykqY7VFHPZjF8Il/8E1IubnPCIyOCTJwKqQlBB5td/bt7YIDEFmkpl7OvUwyc2uYkFmrxGv81OtopsYZOJ+WnwSkqqZ3p2MyqNj3xp92p8itz5tM3tzjrkdfPXsx1QJGY+rkZhCsSf6DSG18AFqI4+Q8uWUwqO5/TJb2z/F2LT88+wJfGPtwGeR+98XgvwjsMWIA/TZfwTrTQsZX2YOIf0bg3yjlEbFM16xAFAA2oItBuvbC6d6NIit4Dukn2WamnOceoTyO6mdHYRh5SBOryr3AWnBJZsUPL3HsC+Xiibgixuwjjalj+HOrAzDlQc8L0Z77dZJhpST0x/gwCleSA3lOKs7MA8ASolCaPVL2pPJXkb97mBxZx8k1n6abhK1w3QVJuYvp7CyGhavsYEqcR+vYx/T0tN4MVOjfRhimqhNihz0VDfY97YS5XavZV07jycqoAlufmH5VSwNbiy8/NY6Q2djc46ISbqvKr6Pf0TZBuJti2gIpm02Btm4rMCawVPpEPieU3GI77nhQp6orq0Zjl5f4XfOKnfcxiqdgip4SVFTglHBTWTiRFnGTF0I3VX4V+RRmqJwwKPN8cxDsNd6wSpylhDUAMfxEvvb+0vAt1yGNUC52OB4bSOXOyZIAU8+08xl7mYGIVUnoWHaR6Y0aHdnywJUuzQ2q3dotfnI1j72MzlHsTK6Lro3YiolDNJpTqLtxmSzkWctw/PfijnoEXtmDnZKptZ7t0v7oTAkdE3kk0RrnFTnMAkyCOREcFcyxglROCoDHsZx3Q+MkWLG/tPMVpuMRhy9gJ1WZTpeExNgs5KgwtrS1HJg7KunWXguFH/zDgODTdKclgfvsVe/SCtlpbO6z5fZji0j1y9LRBVLyTN/LzeR40OBX3r0abk4SGyslAdZMgg2WJdSLVAJ+MtxxbnlKDXDNmu5YehpWdTvm/wIYwTKw+1A48plKburw8fBEofVy9Ubmc8E4z6hQRX2cwcNN9N/60aCwlpM7wVbYfBo4Hw9H/6EawjbRRN9UwmgsfmYUuqCTSi8fNi2dR36bqaoHHURgyqW7DiR7BYgVnOZ+B/2GM8uO2rYgSOhVJf+OK+2HsNly0MW5v3/ft6W7PEsab8IweYWmPLVvJNfHW5CDP6KotdDgm/DcD5owgQ12D95BGWawR5gQxpyjX9uIlxORPq6h0Z79j8gFFsYIfddIdxsJZS9r59FtZe2JL7nK3Dum07tXDGlBCUD4mwv+LNxOJLa8DM6YoEd4Nh8qosfQNJu505Vh/r+PgegnFvG9LRkwQnk8fgPTNKThB067s82YuVg0mv4O9q4Hlm13wTWuvlMr4k1ShBrNyy08YaFCu2hmZm7RizU1rU5MpieiwpQ6cGx+sHBTszB+c89045n2TM4VUedi+vjEq2KuhmIl9ID0EHvWDy5iwOQV5nJ7Rk/Alky2GRZ1CpnJPN29q5lrs6fhvfPquolJTyBTNgVjQ7f0z1zuUQFdhWQX5BFyq/bT46qc+X6dSOvlFenioPDe5MYpA3SZCi2lmVQqHnTOcRZp2HtYpoRUzNB6cT7P1wkRTvAR5PQyuAknkKB+T6HvRb2H8EBLfk+imlyu7mb2iAJNORpZ0Rz+no/5A3wd6qHfTq27h/CDZ91YqGamylrLBdFqyefFYfSbFa1BKikiavpZnYh19hHNl9v0Q4Tkb7ogy7biw8icOvMPWCrxd50zoY1EUe2maNKtkyytJeEOV+Yj6VaUa88M+7WSKaK8QbEB+fBKmwvGkknRKs8lYRoABChwqDZ7M+98pL057QbquxseKX6alzV+IOHlO7I8csRHnF6OpVQG6wdzJZhEwg/0n1K2qTExF9Z3WzUoiQ+NVC3gRQ9Au+x3fpkuLu85lOVjelX3JtdVt1T3623sqxmcEr/TCZT/+X1QyflAkJyw1EMq4sat7wjYK3ugyPPPjo/v2h4TuaoWc0X/+qJPV/o2Vu489loIQ+N59ABZOLldpbkPM7VJIOnnfg+/GMvaEor2YCYElDGXx4BdRmSfOyzFF2Wqz5iTxMbdCo3iZbPQqbFTdMOX7Hy3nT8vUOhCLo+Dkgb7B01nPnm7crmC+TOgi4iDLp4nCqx5OSiG5gd/m54gZHe6Cymwj+DbW303KcvpGBrk0xr2sGUkQiu4vxNz+uW09EyMNCp5cg2AUWG4w6ykTHXUeDDQG232C5K7/tTt8Z09Kp9v71PkwH7hmZUrjAERGvF29zucdTVRmzr++JMH82Sk0chAi5UFs/lbVcN/birI7OVl6okyO3+bKWMCuhje1huOgeZzSk5xKFrgJ1v69TyD1mOa5wYx6IskbWSrFW/sqrhURqpSlfdWVCZiaOHLb/UIgQ0s1xlHyZ0/YOBQFz1VCgKH217ALijV3FOr+q00761SKNFc/IhZLNtVhHhE7lutAjVqyme7RHKd4fjFOD6oREyDYXHULmDGPRTmlFSxwE4+n3N9AInajQLH82CGWO1nV3u7qSY5vSbuzQIxCr8OKQfW8AzTdNjUoEtU+ojprLZ4V4r1dr01eLqXSVJ12Yq9Sm/Ivu1SZkHJl9oIxNjWSbRRMoYIVH3yVv1HyGGajcmKuzIfORuBZm"
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ jobs:
env:
GROUP: "CPU"
JULIA_NUM_THREADS: 12
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
Expand Down
File renamed without changes.
53 changes: 41 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,49 +1,78 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <avikpal@mit.edu>"]
version = "2.0.3"
version = "2.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
DeepEquilibriumNetworksZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.5"
ADTypes = "0.2.5, 1"
Aqua = "0.8.7"
ChainRulesCore = "1"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2"
ConstructionBase = "1"
DiffEqBase = "6.119"
ExplicitImports = "1.4.1"
FastClosures = "0.3"
LinearAlgebra = "1"
Functors = "0.4.10"
LinearSolve = "2.21.2"
Lux = "0.5.11"
Lux = "0.5.37"
LuxCUDA = "0.3.2"
LuxCore = "0.1.14"
LuxTestUtils = "0.1.15"
NLsolve = "4.5.1"
NonlinearSolve = "3.10.0"
OrdinaryDiffEq = "6.74.1"
PrecompileTools = "1"
Random = "1"
Random = "1.10"
ReTestItems = "1.23.1"
SciMLBase = "2"
SciMLSensitivity = "7.43"
Statistics = "1"
StableRNGs = "1.0.2"
Statistics = "1.10"
SteadyStateDiffEq = "2"
TruncatedStacktraces = "1.1"
Zygote = "0.6.67"
julia = "1.9"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ Random.seed!(rng, seed)

model = Chain(Dense(2 => 2),
DeepEquilibriumNetwork(
Parallel(+, Dense(2 => 2; use_bias=false),
Dense(2 => 2; use_bias=false)),
Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
NewtonRaphson()))

gdev = gpu_device()
Expand Down
11 changes: 8 additions & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@ bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style=:authoryear)

include("pages.jl")

makedocs(; sitename="Deep Equilibrium Networks", authors="Avik Pal et al.",
modules=[DeepEquilibriumNetworks], clean=true, doctest=true, linkcheck=true,
makedocs(; sitename="Deep Equilibrium Networks",
authors="Avik Pal et al.",
modules=[DeepEquilibriumNetworks],
clean=true,
doctest=true,
linkcheck=true,
format=Documenter.HTML(; assets=["assets/favicon.ico"],
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
plugins=[bib], pages)
plugins=[bib],
pages)

deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true)
12 changes: 3 additions & 9 deletions docs/pages.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
pages = [
"Home" => "index.md",
"Tutorials" => [
"tutorials/basic_mnist_deq.md",
"tutorials/reduced_dim_deq.md"
],
"API References" => "api.md",
"References" => "references.md"
]
pages = ["Home" => "index.md",
"Tutorials" => ["tutorials/basic_mnist_deq.md", "tutorials/reduced_dim_deq.md"],
"API References" => "api.md", "References" => "references.md"]
3 changes: 1 addition & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ Random.seed!(rng, seed)
model = Chain(Dense(2 => 2),
DeepEquilibriumNetwork(
Parallel(+, Dense(2 => 2; use_bias=false),
Dense(2 => 2; use_bias=false)),
Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
NewtonRaphson()))
gdev = gpu_device()
Expand Down
15 changes: 7 additions & 8 deletions docs/src/tutorials/basic_mnist_deq.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ function construct_model(solver; model_type::Symbol=:deq)
# The input layer of the DEQ
deq_model = Chain(
Parallel(+,
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))
Expand All @@ -79,11 +78,11 @@ function construct_model(solver; model_type::Symbol=:deq)
init = missing
end
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
linsolve_kwargs=(; maxiters=10))
deq = DeepEquilibriumNetwork(
deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10))
classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(),
Dense(64, 10))
classifier = Chain(
GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
model = Chain(; down, deq, classifier)
Expand Down Expand Up @@ -132,8 +131,8 @@ function accuracy(model, data, ps, st)
return total_correct / total
end
function train_model(solver, model_type; data_train=zip(x_train, y_train),
data_test=zip(x_test, y_test))
function train_model(
solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
model, ps, st = construct_model(solver; model_type)
model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
Expand Down
15 changes: 7 additions & 8 deletions docs/src/tutorials/reduced_dim_deq.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ function construct_model(solver; model_type::Symbol=:regdeq)
down = Chain(FlattenLayer(), Dense(784 => 512, gelu))
# The input layer of the DEQ
deq_model = Chain(Parallel(+,
Dense(128 => 64, tanh), # Reduced dim of `128`
deq_model = Chain(Parallel(+, Dense(128 => 64, tanh), # Reduced dim of `128`
Dense(512 => 64, tanh)), # Original dim of `512`
Dense(64 => 64, tanh), Dense(64 => 128)) # Return the reduced dim of `128`
Expand All @@ -65,12 +64,12 @@ function construct_model(solver; model_type::Symbol=:regdeq)
else
# This should preferably done via `ChainRulesCore.@ignore_derivatives`. But here
# we are only using Zygote so this is fine.
init = WrappedFunction(x -> Zygote.@ignore(fill!(similar(x, 128, size(x, 2)),
false)))
init = WrappedFunction(x -> Zygote.@ignore(fill!(
similar(x, 128, size(x, 2)), false)))
end
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
linsolve_kwargs=(; maxiters=10))
deq = DeepEquilibriumNetwork(
deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10))
classifier = Chain(Dense(128 => 128, gelu), Dense(128, 10))
Expand Down Expand Up @@ -121,8 +120,8 @@ function accuracy(model, data, ps, st)
return total_correct / total
end
function train_model(solver, model_type; data_train=zip(x_train, y_train),
data_test=zip(x_test, y_test))
function train_model(
solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
model, ps, st = construct_model(solver; model_type)
model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
module DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt
module DeepEquilibriumNetworksSciMLSensitivityExt

# Linear Solve is a dependency of SciMLSensitivity, so we only need to load SciMLSensitivity
# to load this extension
using LinearSolve, SciMLBase, SciMLSensitivity
import DeepEquilibriumNetworks: __default_sensealg
using LinearSolve: SimpleGMRES
using SciMLBase: SteadyStateProblem, ODEProblem
using SciMLSensitivity: SteadyStateAdjoint, GaussAdjoint, ZygoteVJP
using DeepEquilibriumNetworks: DEQs

@inline function __default_sensealg(prob::SteadyStateProblem)
@inline function DEQs.__default_sensealg(prob::SteadyStateProblem)
# We want to avoid the cost for cache construction for linsolve = nothing
# For small problems we should use concrete jacobian but we assume users want to solve
# large problems with this package so we default to GMRES and avoid runtime dispatches
linsolve = SimpleGMRES{true}(; blocksize=prod(size(prob.u0)[1:(end - 1)]))
linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)
return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP())
end
@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())
@inline DEQs.__default_sensealg(prob::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())

end
17 changes: 12 additions & 5 deletions ext/DeepEquilibriumNetworksZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
module DeepEquilibriumNetworksZygoteExt

using ADTypes, Statistics, Zygote
import DeepEquilibriumNetworks: __gaussian_like, __estimate_jacobian_trace
using ADTypes: AutoZygote
using FastClosures: @closure
using Statistics: mean
using Zygote: Zygote
using DeepEquilibriumNetworks: DEQs

function __estimate_jacobian_trace(::AutoZygote, model, ps, z, x, rng)
res, back = Zygote.pullback(u -> model((u, x), ps), z)
vjp_z = only(back(__gaussian_like(rng, res)))
@inline __tupleify(u) = @closure x -> (u, x)

## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
## FIXME: This will be broken in the new Lux release let's fix this
function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng)
res, back = Zygote.pullback(model __tupleify, z)
vjp_z = only(back(DEQs.__gaussian_like(rng, res)))
return mean(abs2, vjp_z)
end

Expand Down
26 changes: 16 additions & 10 deletions src/DeepEquilibriumNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@ module DeepEquilibriumNetworks
import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ADTypes, DiffEqBase, FastClosures, LinearAlgebra, Lux, Random, SciMLBase,
Statistics, SteadyStateDiffEq

import ChainRulesCore as CRC
import ConcreteStructs: @concrete
import ConstructionBase: constructorof
import Lux: AbstractExplicitLayer, AbstractExplicitContainerLayer
import SciMLBase: AbstractNonlinearAlgorithm,
AbstractODEAlgorithm, _unwrap_val, NonlinearSolution
import TruncatedStacktraces: @truncate_stacktrace
using ADTypes: AutoFiniteDiff
using ChainRulesCore: ChainRulesCore
using CommonSolve: solve
using ConcreteStructs: @concrete
using ConstructionBase: ConstructionBase
using DiffEqBase: DiffEqBase, AbsNormTerminationMode
using FastClosures: @closure
using Lux: Lux, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer,
StatefulLuxLayer, WrappedFunction
using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer
using Random: Random, AbstractRNG, randn!
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractODEAlgorithm,
NonlinearSolution, ODESolution, ODEFunction, ODEProblem,
SteadyStateProblem, _unwrap_val
using SteadyStateDiffEq: DynamicSS, SSRootfind
end

# Useful Constants
const CRC = ChainRulesCore
const DEQs = DeepEquilibriumNetworks

include("layers.jl")
Expand Down
Loading

0 comments on commit 00092b0

Please sign in to comment.