diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index f632a90a..35969990 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -3,4 +3,5 @@ whitespace_in_kwargs = false
 format_docstrings = true
 separate_kwargs_with_semicolon = true
 format_markdown = true
-annotate_untyped_fields_with_any = false
\ No newline at end of file
+annotate_untyped_fields_with_any = false
+join_lines_based_on_source = false
diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index f1f11a31..eb6978be 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -54,5 +54,7 @@ steps:
       timeout_in_minutes: 240
 
 env:
+  RETESTITEMS_NWORKERS: 4
+  RETESTITEMS_NWORKER_THREADS: 2
   SECRET_CODECOV_TOKEN: "fbSN+ZbScLIWr1FOpAu1Z8PYWFobqbLGFayOgZE1ebhE8LIH/PILGXUMcdm9gkXVSwgdETDD0s33k14lBkJ90O4dV9w6k79F/pEgzVHV8baMoXZG03BPMxztlcoRXrKtRtAp+MwoATc3Ldb9H5vqgAnVNn5rhn4Rp0Z6LOVRC43hbhKBBKYh/N4gqpIQlcW4dBXmELhlnMFnUILjwGRVgEt/zh8H+vmf0qiIulNIQ/rfGISROHqFML0QDL4icloiqX08J76ZP/gZCeg6rJ0gl3ok3IspNPz51rlbvijqsPNyIHWi29OrAtWX3qKHfrAOoGIrE1d5Oy4wx4XaN/YBhg==;U2FsdGVkX188gcRjkUNMEC2Z5fEFfhsYY4WJbhhINOuCUgqq9XNHVDbJhzFUFVQ+UiuPHFg7CW/gn+3IkSVyOA=="
   SECRET_DOCUMENTER_KEY: "jzyAET5IdazYwPAEZAmYmnBALb2dC1GPizCDCdt8xpjIi4ce6QbGGJMKo00ZNzJ/A7ii4bhqysVPXniifFwIGl7x+GSCeavwcSr15pfxJSqPuQYLKxESzIo+SM+l2uJWUz8KYMJ1tSt/Z3Up3qQfLeQFtR+f43b9QrLfhgZGAAdxpwu5VHdI3Xm/gZo5d8xEJ1xs4gqVP0e2A5EFr/j/exaWJL9+AvgO+Gko8NaJGG5B89zP1W2NBlpjttbwzj2naBhDx8A43Qe4eXm+BZd9CIZImiEJnnqoGxLkAyLDksbA68getUHW5z3nGyhWTrg5yfRqq0uyZZGTIOFz6dJrRg==;U2FsdGVkX19QOxLLkdNoQf7Rid3mcSR/renIHQ+/X3o0WxTmU8KDDxzfKuWPeK1fxMon8y45HCJv3HlMuzyfvPWrOmUXccfHK272D8vHu1kk/qZZw8nPd7iYBU9+VAIxwfmI3Av2gC+8tUlOcuUTEVMtMbi/MiLHp+phLYcELKzzrxL8VdrLzna81M+8xVLu7zzNuyK0cUPWLxRHcZc/fewK5Nh7EQ2x8u1b6e5zR0/AcqjCzMayD1RiE7QhRVGdF5GJYnAxc1eoyCwIjXTRfFo0a0Q2h6DEz9FEat/ZCekIuWyVrUkGbpsRqXUTrSH0An7FRRqRlZ9lStRaQY4Z3XBkoIh94vQlXwwLUH20jC7yRTV73CeYmhfigQckHL0JsjjIENz04Ac346fCV6WNQtEak0m3pN/BucoiwRA8l+WU4AK1r84cwGSphKk4SnWRAqeZVuFHck7NkcmHDEkO4C7WTP400oui/5NDMtVZbtnZfLxVzQqijxXj7IflWqF1vKqGmW5aPFMVNeAqwNGu3xM4oIIeHRu0u+k2S5dp1wqRVlMxYXdPtcoFzE0CNsMQdWgsvPd2eet38YRc8ftXNjKzoUSRRCbjGbVr0iJXeNmPg3jfZoVdILHjCN/hcz4nY+61P11OlJAdfE/6HzEr4VoOS4CN+s/brjWycmAKZo2+1e4fSV1xBH7t1spOlESLvsBhZNtj9/zUKgWgMct5hnF4anQcPAeRpz/MBrkwX1gW3WOvCxaqVlRfgGSy6boPgRd3p/ZXN4Xnfeg9RFqKZn21d2gcrc3/1+PTUEkOIv+C9BGszo9IaUziW/Tz2mVP386kX86SF4fF4y3PofcUT2FLTm8Q9ZJBnslOsRP8bq3rIjDiQR3Iz3uGctkGZPs+GOtCR5OrhnnS6BXxkGwt/n9PJsnbXt0Z4tuXihC1B8KfP7mzDvZr3q9X/DGKyZ+oMHdDI+f2+lRwx42nJnsu+nZW9lyhdIwWla9F1rIoVz59HbUrmUhsVmFQYfjy7Nl18g8Wh5r9CkFL/vr6Zpy5lj1J/vhe1501X2FIkKOnLAM73GwtAa4GkbHyu5rNcij6YoozPrJWT4KRNFWGVAqNZ1atG8WwmziwIl2KfBn8jiuP/8o6rXQkmrAzBr6jVnto5FTWnIexEmnbELs20XDck8pO5WQxU1IR9YhKMbrDGbn0jWzVoRmCWpaJgV1AkWu09a++DxIec4+Zt+3SZLj/H57XsBchWHmkFz4NVTBeSans26VmdDd3LxprT8qeH6cioceakmu6yegsKQnJGLmSNyUkHqBqmsCcvyTUyaQUBTFkjLmDeZB3Ifu2kD7AFdx5n58wdJTMZxYviybOCgCV4qe95v5XfIqthp5mF/0F1Wt9ZcEreFSM2Paj5GrQ+M25cZ+kqOSlMet51Q+QBCfQyDF8jdu3j1hVniwpgMI1gqyb2alRfyNx52elTqRn9hPqpFptGH4uJXi8H72YPe4fYkFS7wwELeRIv+nKkNYNLPQAyQFvZ/qB/PRI1YoFBbpi0Vi6iE9xLRq7QVvhJde2EgNbvQk8uakwV630Tht2OuwVdJu/PIbXsQ5i+EuknIlPRdQdhbEIkpuBHFAzxBqA2K92gJ4bbcOjGtDHc0pt0RtvIVoyyJMkYVzr1yBeMWEmsL8qYJ5yzuAFGqpCTmJzXE0ETZLDDJtxwSKj5M2vG59wPNfo9DF+LgJLF+94VydYGNOHy9KuY2Oo3ejV7iFXUtsEV3Id9EkNGq8+t5KIAGk6lnDcM1TTOAc5W8fGGNhYzlqgWK1n3nwPJLykqY7VFHPZjF8Il/8E1IubnPCIyOCTJwKqQlBB5td/bt7YIDEFmkpl7OvUwyc2uYkFmrxGv81OtopsYZOJ+WnwSkqqZ3p2MyqNj3xp92p8itz5tM3tzjrkdfPXsx1QJGY+rkZhCsSf6DSG18AFqI4+Q8uWUwqO5/TJb2z/F2LT88+wJfGPtwGeR+98XgvwjsMWIA/TZfwTrTQsZX2YOIf0bg3yjlEbFM16xAFAA2oItBuvbC6d6NIit4Dukn2WamnOceoTyO6mdHYRh5SBOryr3AWnBJZsUPL3HsC+Xiibgixuwjjalj+HOrAzDlQc8L0Z77dZJhpST0x/gwCleSA3lOKs7MA8ASolCaPVL2pPJXkb97mBxZx8k1n6abhK1w3QVJuYvp7CyGhavsYEqcR+vYx/T0tN4MVOjfRhimqhNihz0VDfY97YS5XavZV07jycqoAlufmH5VSwNbiy8/NY6Q2djc46ISbqvKr6Pf0TZBuJti2gIpm02Btm4rMCawVPpEPieU3GI77nhQp6orq0Zjl5f4XfOKnfcxiqdgip4SVFTglHBTWTiRFnGTF0I3VX4V+RRmqJwwKPN8cxDsNd6wSpylhDUAMfxEvvb+0vAt1yGNUC52OB4bSOXOyZIAU8+08xl7mYGIVUnoWHaR6Y0aHdnywJUuzQ2q3dotfnI1j72MzlHsTK6Lro3YiolDNJpTqLtxmSzkWctw/PfijnoEXtmDnZKptZ7t0v7oTAkdE3kk0RrnFTnMAkyCOREcFcyxglROCoDHsZx3Q+MkWLG/tPMVpuMRhy9gJ1WZTpeExNgs5KgwtrS1HJg7KunWXguFH/zDgODTdKclgfvsVe/SCtlpbO6z5fZji0j1y9LRBVLyTN/LzeR40OBX3r0abk4SGyslAdZMgg2WJdSLVAJ+MtxxbnlKDXDNmu5YehpWdTvm/wIYwTKw+1A48plKburw8fBEofVy9Ubmc8E4z6hQRX2cwcNN9N/60aCwlpM7wVbYfBo4Hw9H/6EawjbRRN9UwmgsfmYUuqCTSi8fNi2dR36bqaoHHURgyqW7DiR7BYgVnOZ+B/2GM8uO2rYgSOhVJf+OK+2HsNly0MW5v3/ft6W7PEsab8IweYWmPLVvJNfHW5CDP6KotdDgm/DcD5owgQ12D95BGWawR5gQxpyjX9uIlxORPq6h0Z79j8gFFsYIfddIdxsJZS9r59FtZe2JL7nK3Dum07tXDGlBCUD4mwv+LNxOJLa8DM6YoEd4Nh8qosfQNJu505Vh/r+PgegnFvG9LRkwQnk8fgPTNKThB067s82YuVg0mv4O9q4Hlm13wTWuvlMr4k1ShBrNyy08YaFCu2hmZm7RizU1rU5MpieiwpQ6cGx+sHBTszB+c89045n2TM4VUedi+vjEq2KuhmIl9ID0EHvWDy5iwOQV5nJ7Rk/Alky2GRZ1CpnJPN29q5lrs6fhvfPquolJTyBTNgVjQ7f0z1zuUQFdhWQX5BFyq/bT46qc+X6dSOvlFenioPDe5MYpA3SZCi2lmVQqHnTOcRZp2HtYpoRUzNB6cT7P1wkRTvAR5PQyuAknkKB+T6HvRb2H8EBLfk+imlyu7mb2iAJNORpZ0Rz+no/5A3wd6qHfTq27h/CDZ91YqGamylrLBdFqyefFYfSbFa1BKikiavpZnYh19hHNl9v0Q4Tkb7ogy7biw8icOvMPWCrxd50zoY1EUe2maNKtkyytJeEOV+Yj6VaUa88M+7WSKaK8QbEB+fBKmwvGkknRKs8lYRoABChwqDZ7M+98pL057QbquxseKX6alzV+IOHlO7I8csRHnF6OpVQG6wdzJZhEwg/0n1K2qTExF9Z3WzUoiQ+NVC3gRQ9Au+x3fpkuLu85lOVjelX3JtdVt1T3623sqxmcEr/TCZT/+X1QyflAkJyw1EMq4sat7wjYK3ugyPPPjo/v2h4TuaoWc0X/+qJPV/o2Vu489loIQ+N59ABZOLldpbkPM7VJIOnnfg+/GMvaEor2YCYElDGXx4BdRmSfOyzFF2Wqz5iTxMbdCo3iZbPQqbFTdMOX7Hy3nT8vUOhCLo+Dkgb7B01nPnm7crmC+TOgi4iDLp4nCqx5OSiG5gd/m54gZHe6Cymwj+DbW303KcvpGBrk0xr2sGUkQiu4vxNz+uW09EyMNCp5cg2AUWG4w6ykTHXUeDDQG232C5K7/tTt8Z09Kp9v71PkwH7hmZUrjAERGvF29zucdTVRmzr++JMH82Sk0chAi5UFs/lbVcN/birI7OVl6okyO3+bKWMCuhje1huOgeZzSk5xKFrgJ1v69TyD1mOa5wYx6IskbWSrFW/sqrhURqpSlfdWVCZiaOHLb/UIgQ0s1xlHyZ0/YOBQFz1VCgKH217ALijV3FOr+q00761SKNFc/IhZLNtVhHhE7lutAjVqyme7RHKd4fjFOD6oREyDYXHULmDGPRTmlFSxwE4+n3N9AInajQLH82CGWO1nV3u7qSY5vSbuzQIxCr8OKQfW8AzTdNjUoEtU+ojprLZ4V4r1dr01eLqXSVJ12Yq9Sm/Ivu1SZkHJl9oIxNjWSbRRMoYIVH3yVv1HyGGajcmKuzIfORuBZm"
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 43cb34a4..c7df47bf 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -43,6 +43,8 @@ jobs:
         env:
           GROUP: "CPU"
           JULIA_NUM_THREADS: 12
+          RETESTITEMS_NWORKERS: 4
+          RETESTITEMS_NWORKER_THREADS: 2
       - uses: julia-actions/julia-processcoverage@v1
         with:
          directories: src,ext
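The `RETESTITEMS_*` variables added in both CI configurations set the ReTestItems.jl worker defaults. A minimal local equivalent, assuming ReTestItems' documented `nworkers`/`nworker_threads` keywords (not part of this patch):

```julia
using ReTestItems

# Same effect as RETESTITEMS_NWORKERS=4 and RETESTITEMS_NWORKER_THREADS=2:
# distribute the test items over 4 worker processes with 2 threads each.
ReTestItems.runtests("test"; nworkers=4, nworker_threads=2)
```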
diff --git a/test/LocalPreferences.toml b/LocalPreferences.toml
similarity index 100%
rename from test/LocalPreferences.toml
rename to LocalPreferences.toml
diff --git a/Project.toml b/Project.toml
index 6baa53da..0f57e746 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,49 +1,80 @@
 name = "DeepEquilibriumNetworks"
 uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
 authors = ["Avik Pal "]
-version = "2.0.3"
+version = "2.1.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
-TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
 
 [weakdeps]
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
 SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extensions]
-DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
-DeepEquilibriumNetworksZygoteExt = "Zygote"
+DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
+DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"]
 
 [compat]
-ADTypes = "0.2.5"
+ADTypes = "0.2.5, 1"
+Aqua = "0.8.7"
 ChainRulesCore = "1"
+CommonSolve = "0.2.4"
 ConcreteStructs = "0.2"
 ConstructionBase = "1"
 DiffEqBase = "6.119"
+ExplicitImports = "1.4.1"
 FastClosures = "0.3"
-LinearAlgebra = "1"
+ForwardDiff = "0.10.36"
+Functors = "0.4.10"
 LinearSolve = "2.21.2"
-Lux = "0.5.11"
+Lux = "0.5.38"
+LuxCUDA = "0.3.2"
+LuxCore = "0.1.14"
+LuxTestUtils = "0.1.15"
+NLsolve = "4.5.1"
+NonlinearSolve = "3.10.0"
+OrdinaryDiffEq = "6.74.1"
 PrecompileTools = "1"
-Random = "1"
+Random = "1.10"
+ReTestItems = "1.23.1"
 SciMLBase = "2"
 SciMLSensitivity = "7.43"
-Statistics = "1"
+StableRNGs = "1.0.2"
+Statistics = "1.10"
 SteadyStateDiffEq = "2"
-TruncatedStacktraces = "1.1"
-Zygote = "0.6.67"
-julia = "1.9"
+Test = "1.10"
+Zygote = "0.6.69"
+julia = "1.10"
+
+[extras]
+Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
+NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
+NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
+OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
+ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
+SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[targets]
+test = ["Aqua", "ExplicitImports", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
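With `ForwardDiff`, `LinearSolve`, `SciMLSensitivity`, and `Zygote` as weak dependencies, the renamed extensions only activate once their triggers are loaded. A sketch for verifying that on Julia ≥ 1.10 — the `Base.get_extension` check is illustrative, not part of this patch:

```julia
using DeepEquilibriumNetworks
using LinearSolve, SciMLSensitivity  # triggers for the sensitivity extension

ext = Base.get_extension(
    DeepEquilibriumNetworks, :DeepEquilibriumNetworksSciMLSensitivityExt)
ext === nothing || @info "extension loaded"
```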
"e6f89c97-d47a-5376-807f-9c37f3926c36" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" @@ -11,6 +11,7 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -21,7 +22,6 @@ DeepEquilibriumNetworks = "2" Documenter = "1" DocumenterCitations = "1" LinearSolve = "2" -LoggingExtras = "1" Lux = "0.5" LuxCUDA = "0.3" MLDataUtils = "0.5" diff --git a/docs/make.jl b/docs/make.jl index 62440167..3117b436 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -7,10 +7,15 @@ bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style=:authoryear) include("pages.jl") -makedocs(; sitename="Deep Equilibrium Networks", authors="Avik Pal et al.", - modules=[DeepEquilibriumNetworks], clean=true, doctest=true, linkcheck=true, +makedocs(; sitename="Deep Equilibrium Networks", + authors="Avik Pal et al.", + modules=[DeepEquilibriumNetworks], + clean=true, + doctest=true, + linkcheck=true, format=Documenter.HTML(; assets=["assets/favicon.ico"], canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"), - plugins=[bib], pages) + plugins=[bib], + pages) deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true) diff --git a/docs/pages.jl b/docs/pages.jl index ac42f48d..5a82ffc3 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,9 +1,3 @@ -pages = [ - "Home" => "index.md", - "Tutorials" => [ - "tutorials/basic_mnist_deq.md", - "tutorials/reduced_dim_deq.md" - ], - "API References" => "api.md", - "References" => "references.md" -] +pages = ["Home" => "index.md", + "Tutorials" => ["tutorials/basic_mnist_deq.md", "tutorials/reduced_dim_deq.md"], + "API References" => "api.md", "References" => "references.md"] diff --git a/docs/src/index.md b/docs/src/index.md index 0cb693d7..2cb1befb 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -26,8 +26,7 @@ Random.seed!(rng, seed) model = Chain(Dense(2 => 2), DeepEquilibriumNetwork( - Parallel(+, Dense(2 => 2; use_bias=false), - Dense(2 => 2; use_bias=false)), + Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)), NewtonRaphson())) gdev = gpu_device() diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md index 9e1085a8..3684f4a7 100644 --- a/docs/src/tutorials/basic_mnist_deq.md +++ b/docs/src/tutorials/basic_mnist_deq.md @@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack ```@example basic_mnist_deq using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq, - Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras + Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf using MLDatasets: MNIST using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview @@ -20,18 +20,6 @@ const cdev = cpu_device() const gdev = gpu_device() ``` -SciMLBase introduced a warning instead of depwarn which pollutes the output. 
We can suppress -it with the following logger - -```@example basic_mnist_deq -function remove_syms_warning(log_args) - return log_args.message != - "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead." -end - -filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger()) -``` - We can now construct our dataloader. ```@example basic_mnist_deq @@ -66,8 +54,7 @@ function construct_model(solver; model_type::Symbol=:deq) # The input layer of the DEQ deq_model = Chain( - Parallel(+, - Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()), + Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()), Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())), Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())) @@ -79,11 +66,11 @@ function construct_model(solver; model_type::Symbol=:deq) init = missing end - deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false, - linsolve_kwargs=(; maxiters=10)) + deq = DeepEquilibriumNetwork( + deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10)) - classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), - Dense(64, 10)) + classifier = Chain( + GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10)) model = Chain(; down, deq, classifier) @@ -95,12 +82,12 @@ function construct_model(solver; model_type::Symbol=:deq) x = randn(rng, Float32, 28, 28, 1, 128) y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev - model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st) - @info "warming up forward pass" + model_ = StatefulLuxLayer(model, ps, st) + @printf "[%s] warming up forward pass\n" string(now()) logitcrossentropy(model_, x, ps, y) - @info "warming up backward pass" + @printf "[%s] warming up backward pass\n" string(now()) Zygote.gradient(logitcrossentropy, model_, x, ps, y) - @info "warmup complete" + @printf "[%s] warmup complete\n" string(now()) return model, ps, st end @@ -122,7 +109,7 @@ classify(x) = argmax.(eachcol(x)) function accuracy(model, data, ps, st) total_correct, total = 0, 0 st = Lux.testmode(st) - model = Lux.Experimental.StatefulLuxLayer(model, ps, st) + model = StatefulLuxLayer(model, ps, st) for (x, y) in data target_class = classify(cdev(y)) predicted_class = classify(cdev(model(x))) @@ -132,51 +119,48 @@ function accuracy(model, data, ps, st) return total_correct / total end -function train_model(solver, model_type; data_train=zip(x_train, y_train), - data_test=zip(x_test, y_test)) +function train_model( + solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test)) model, ps, st = construct_model(solver; model_type) - model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st) + model_st = StatefulLuxLayer(model, nothing, st) - @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))" + @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver)) opt_st = Optimisers.setup(Adam(0.001), ps) acc = accuracy(model, data_test, ps, st) * 100 - @info "Starting Accuracy: $(acc)" + @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc - @info "Pretrain with unrolling to a depth of 5" + @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now()) st = Lux.update_state(st, :fixed_depth, Val(5)) - model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st) + model_st = StatefulLuxLayer(model, ps, st) for (i, (x, y)) in enumerate(data_train) res = 
Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - if i % 50 == 1 - @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)" - end + i % 50 == 1 && + @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 - @info "Pretraining complete. Accuracy: $(acc)" + @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc st = Lux.update_state(st, :fixed_depth, Val(0)) - model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st) + model_st = StatefulLuxLayer(model, ps, st) for epoch in 1:3 for (i, (x, y)) in enumerate(data_train) res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) Optimisers.update!(opt_st, ps, res.grad[3]) - if i % 50 == 1 - @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)" - end + i % 50 == 1 && + @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val end acc = accuracy(model, data_test, ps, model_st.st) * 100 - @info "Epoch: [$(epoch)/3] Accuracy: $(acc)" + @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc end - @info "Training complete." - println() + @printf "[%s] Training complete.\n" string(now()) return model, ps, st end @@ -188,9 +172,7 @@ and end up using solvers like `Broyden`, but we can simply slap in any of the fa from NonlinearSolve.jl. Here we will use Newton-Krylov Method: ```@example basic_mnist_deq -with_logger(filtered_logger) do - train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq) -end +train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq); nothing # hide ``` @@ -198,9 +180,7 @@ We can also train a continuous DEQ by passing in an ODE solver. Here we will use which tend to be quite fast for continuous Neural Network problems. ```@example basic_mnist_deq -with_logger(filtered_logger) do - train_model(VCAB3(), :deq) -end +train_model(VCAB3(), :deq); nothing # hide ``` diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md index 0b00b9e1..9f72ac69 100644 --- a/docs/src/tutorials/reduced_dim_deq.md +++ b/docs/src/tutorials/reduced_dim_deq.md @@ -6,7 +6,7 @@ same MNIST example as before, but this time we will use a reduced state size. ```@example reduced_dim_mnist using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq, - Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras + Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf using MLDatasets: MNIST using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview @@ -16,13 +16,6 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true const cdev = cpu_device() const gdev = gpu_device() -function remove_syms_warning(log_args) - return log_args.message != - "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead." 
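The tutorial edits above replace `Lux.Experimental.StatefulLuxLayer` with the now-exported `StatefulLuxLayer`. A minimal sketch of the wrapper, assuming the Lux 0.5.38 API pinned by this PR:

```julia
using Lux, Random

rng = Random.default_rng()
layer = Dense(2 => 2)
ps, st = Lux.setup(rng, layer)

# StatefulLuxLayer carries ps/st internally, so the layer becomes callable like
# a plain function -- this is what lets `logitcrossentropy(model_, x, ps, y)`
# and `Zygote.withgradient(...)` above treat the model as a closure.
smodel = StatefulLuxLayer(layer, ps, st)
y = smodel(randn(rng, Float32, 2, 3))
```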
diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md
index 0b00b9e1..9f72ac69 100644
--- a/docs/src/tutorials/reduced_dim_deq.md
+++ b/docs/src/tutorials/reduced_dim_deq.md
@@ -6,7 +6,7 @@ same MNIST example as before, but this time we will use a reduced state size.
 
 ```@example reduced_dim_mnist
 using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
-      Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
+      Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
 using MLDatasets: MNIST
 using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
 
@@ -16,13 +16,6 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true
 const cdev = cpu_device()
 const gdev = gpu_device()
 
-function remove_syms_warning(log_args)
-    return log_args.message !=
-           "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
-end
-
-filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
-
 function onehot(labels_raw)
     return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
 end
@@ -53,8 +46,7 @@ function construct_model(solver; model_type::Symbol=:regdeq)
     down = Chain(FlattenLayer(), Dense(784 => 512, gelu))
 
     # The input layer of the DEQ
-    deq_model = Chain(Parallel(+,
-            Dense(128 => 64, tanh), # Reduced dim of `128`
+    deq_model = Chain(Parallel(+, Dense(128 => 64, tanh), # Reduced dim of `128`
             Dense(512 => 64, tanh)), # Original dim of `512`
         Dense(64 => 64, tanh), Dense(64 => 128)) # Return the reduced dim of `128`
 
@@ -65,12 +57,12 @@ function construct_model(solver; model_type::Symbol=:regdeq)
     else
         # This should preferably done via `ChainRulesCore.@ignore_derivatives`. But here
         # we are only using Zygote so this is fine.
-        init = WrappedFunction(x -> Zygote.@ignore(fill!(similar(x, 128, size(x, 2)),
-            false)))
+        init = WrappedFunction(x -> Zygote.@ignore(fill!(
+            similar(x, 128, size(x, 2)), false)))
     end
 
-    deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
-        linsolve_kwargs=(; maxiters=10))
+    deq = DeepEquilibriumNetwork(
+        deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10))
 
     classifier = Chain(Dense(128 => 128, gelu), Dense(128, 10))
 
@@ -84,12 +76,12 @@ function construct_model(solver; model_type::Symbol=:regdeq)
     x = randn(rng, Float32, 28, 28, 1, 128)
     y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
 
-    model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
-    @info "warming up forward pass"
+    model_ = StatefulLuxLayer(model, ps, st)
+    @printf "[%s] warming up forward pass\n" string(now())
     logitcrossentropy(model_, x, ps, y)
-    @info "warming up backward pass"
+    @printf "[%s] warming up backward pass\n" string(now())
     Zygote.gradient(logitcrossentropy, model_, x, ps, y)
-    @info "warmup complete"
+    @printf "[%s] warmup complete\n" string(now())
 
     return model, ps, st
 end
@@ -111,7 +103,7 @@ classify(x) = argmax.(eachcol(x))
 function accuracy(model, data, ps, st)
     total_correct, total = 0, 0
     st = Lux.testmode(st)
-    model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    model = StatefulLuxLayer(model, ps, st)
     for (x, y) in data
         target_class = classify(cdev(y))
         predicted_class = classify(cdev(model(x)))
@@ -121,51 +113,48 @@ function accuracy(model, data, ps, st)
     return total_correct / total
 end
 
-function train_model(solver, model_type; data_train=zip(x_train, y_train),
-        data_test=zip(x_test, y_test))
+function train_model(
+        solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
     model, ps, st = construct_model(solver; model_type)
-    model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
+    model_st = StatefulLuxLayer(model, nothing, st)
 
-    @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
+    @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
 
     opt_st = Optimisers.setup(Adam(0.001), ps)
 
     acc = accuracy(model, data_test, ps, st) * 100
-    @info "Starting Accuracy: $(acc)"
+    @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc
 
-    @info "Pretrain with unrolling to a depth of 5"
+    @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
    st = Lux.update_state(st, :fixed_depth, Val(5))
-    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    model_st = StatefulLuxLayer(model, ps, st)
 
    for (i, (x, y)) in enumerate(data_train)
        res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
        Optimisers.update!(opt_st, ps, res.grad[3])
-        if i % 50 == 1
-            @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
-        end
+        i % 50 == 1 &&
+            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
    end
 
    acc = accuracy(model, data_test, ps, model_st.st) * 100
-    @info "Pretraining complete. Accuracy: $(acc)"
+    @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc
 
    st = Lux.update_state(st, :fixed_depth, Val(0))
-    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    model_st = StatefulLuxLayer(model, ps, st)
 
    for epoch in 1:3
        for (i, (x, y)) in enumerate(data_train)
            res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
            Optimisers.update!(opt_st, ps, res.grad[3])
-            if i % 50 == 1
-                @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
-            end
+            i % 50 == 1 &&
+                @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
        end
 
        acc = accuracy(model, data_test, ps, model_st.st) * 100
-        @info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
+        @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
    end
 
-    @info "Training complete."
-    println()
+    @printf "[%s] Training complete.\n" string(now())
 
    return model, ps, st
 end
@@ -175,15 +164,11 @@ Now we can train our model. We can't use `:regdeq` here currently, but we will s
 in the future.
 
 ```@example reduced_dim_mnist
-with_logger(filtered_logger) do
-    train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
-end
+train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
 nothing # hide
 ```
 
 ```@example reduced_dim_mnist
-with_logger(filtered_logger) do
-    train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
-end
+train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
 nothing # hide
 ```
diff --git a/ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl b/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl
similarity index 61%
rename from ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl
rename to ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl
index 21cc34ca..fdc36591 100644
--- a/ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl
+++ b/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl
@@ -1,11 +1,13 @@
-module DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt
+module DeepEquilibriumNetworksSciMLSensitivityExt
 
 # Linear Solve is a dependency of SciMLSensitivity, so we only need to load SciMLSensitivity
 # to load this extension
-using LinearSolve, SciMLBase, SciMLSensitivity
-import DeepEquilibriumNetworks: __default_sensealg
+using LinearSolve: SimpleGMRES
+using SciMLBase: SteadyStateProblem, ODEProblem
+using SciMLSensitivity: SteadyStateAdjoint, GaussAdjoint, ZygoteVJP
+using DeepEquilibriumNetworks: DEQs
 
-@inline function __default_sensealg(prob::SteadyStateProblem)
+@inline function DEQs.__default_sensealg(prob::SteadyStateProblem)
     # We want to avoid the cost for cache construction for linsolve = nothing
     # For small problems we should use concrete jacobian but we assume users want to solve
     # large problems with this package so we default to GMRES and avoid runtime dispatches
@@ -13,6 +15,6 @@ import DeepEquilibriumNetworks: __default_sensealg
     linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)
     return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP())
 end
-@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())
+@inline DEQs.__default_sensealg(prob::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())
 
 end
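For reference, the `SteadyStateProblem` default above assembles a `SteadyStateAdjoint` over a GMRES linear solve. Roughly what it evaluates to — the exact `SimpleGMRES` configuration is elided from this hunk, so the constructor arguments here are an assumption:

```julia
using LinearSolve: SimpleGMRES
using SciMLSensitivity: SteadyStateAdjoint, ZygoteVJP

# Assumed configuration: the real code picks the SimpleGMRES setup on lines
# not shown in this hunk.
sensealg = SteadyStateAdjoint(;
    linsolve=SimpleGMRES(),
    linsolve_kwargs=(; maxiters=10, abstol=1e-3, reltol=1e-3),
    autojacvec=ZygoteVJP())
```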
diff --git a/ext/DeepEquilibriumNetworksZygoteExt.jl b/ext/DeepEquilibriumNetworksZygoteExt.jl
index 56fd849c..a04697e0 100644
--- a/ext/DeepEquilibriumNetworksZygoteExt.jl
+++ b/ext/DeepEquilibriumNetworksZygoteExt.jl
@@ -1,12 +1,57 @@
 module DeepEquilibriumNetworksZygoteExt
 
-using ADTypes, Statistics, Zygote
-import DeepEquilibriumNetworks: __gaussian_like, __estimate_jacobian_trace
+using ADTypes: AutoZygote
+using ChainRulesCore: ChainRulesCore
+using DeepEquilibriumNetworks: DEQs
+using FastClosures: @closure
+using ForwardDiff: ForwardDiff # This is a dependency of Zygote
+using Lux: Lux, StatefulLuxLayer
+using Statistics: mean
+using Zygote: Zygote
 
-function __estimate_jacobian_trace(::AutoZygote, model, ps, z, x, rng)
-    res, back = Zygote.pullback(u -> model((u, x), ps), z)
-    vjp_z = only(back(__gaussian_like(rng, res)))
-    return mean(abs2, vjp_z)
+const CRC = ChainRulesCore
+
+@inline __tupleify(x) = @closure(u->(u, x))
+
+## One day we will overload DI's APIs for Lux Layers and we can remove this
+## Main challenge with overloading Zygote.pullback is that we need to return the correct
+## tangent for the pullback to compute the correct gradient, which is quite hard. But
+## wrapping the overall vjp is not that hard.
+@inline function __compute_vector_jacobian_product(model::StatefulLuxLayer, ps, z, x, rng)
+    res, back = Zygote.pullback(model ∘ __tupleify(x), z)
+    return only(back(DEQs.__gaussian_like(rng, res)))
+end
+
+function CRC.rrule(
+        ::typeof(__compute_vector_jacobian_product), model::StatefulLuxLayer, ps, z, x, rng)
+    res, back = Zygote.pullback(model ∘ __tupleify(x), z)
+    ε = DEQs.__gaussian_like(rng, res)
+    y = only(back(ε))
+    ∇internal_gradient_capture = Δ -> begin
+        (Δ isa CRC.NoTangent || Δ isa CRC.ZeroTangent) &&
+            return ntuple(Returns(CRC.NoTangent()), 6)
+
+        Δ_ = reshape(CRC.unthunk(Δ), size(z))
+
+        Tag = typeof(ForwardDiff.Tag(model, eltype(z)))
+        partials = ForwardDiff.Partials{1, eltype(z)}.(tuple.(Δ_))
+        z_dual = ForwardDiff.Dual{Tag, eltype(z), 1}.(z, partials)
+
+        _, pb_f = Zygote.pullback((x1, x2, p) -> model((x1, x2), p), z_dual, x, ps)
+        ∂z_duals, ∂x_duals, ∂ps_duals = pb_f(ε)
+
+        ∂z = Lux.__partials(Tag, ∂z_duals, 1)
+        ∂x = Lux.__partials(Tag, ∂x_duals, 1)
+        ∂ps = Lux.__partials(Tag, ∂ps_duals, 1)
+
+        return CRC.NoTangent(), CRC.NoTangent(), ∂ps, ∂z, ∂x, CRC.NoTangent()
+    end
+    return y, ∇internal_gradient_capture
+end
+
+## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
+function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng)
+    return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng))
 end
 
 end
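The estimator above pushes a single Gaussian probe `ε` through the pullback: for a Jacobian `J` at `z`, `E[‖εᵀJ‖²] = ‖J‖²_F`, so `mean(abs2, vjp)` is a one-sample (scaled) estimate of the squared Frobenius norm used as the regularization penalty. A standalone sketch of the same idea, with illustrative names rather than the package API:

```julia
using Statistics, Zygote

# One-sample Hutchinson-style estimate of ‖J_f(z)‖_F² (up to the 1/length
# scaling introduced by `mean`).
function jacobian_frobenius_estimate(f, z)
    res, back = Zygote.pullback(f, z)
    ε = randn(eltype(res), size(res))   # Gaussian probe, like __gaussian_like
    vjp = only(back(ε))                 # εᵀJ via reverse mode
    return mean(abs2, vjp)
end

jacobian_frobenius_estimate(x -> tanh.(2 .* x), randn(Float32, 4))
```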
diff --git a/src/DeepEquilibriumNetworks.jl b/src/DeepEquilibriumNetworks.jl
index c7fedef5..abaccfbb 100644
--- a/src/DeepEquilibriumNetworks.jl
+++ b/src/DeepEquilibriumNetworks.jl
@@ -3,19 +3,25 @@ module DeepEquilibriumNetworks
 import PrecompileTools: @recompile_invalidations
 
 @recompile_invalidations begin
-    using ADTypes, DiffEqBase, FastClosures, LinearAlgebra, Lux, Random, SciMLBase,
-        Statistics, SteadyStateDiffEq
-
-    import ChainRulesCore as CRC
-    import ConcreteStructs: @concrete
-    import ConstructionBase: constructorof
-    import Lux: AbstractExplicitLayer, AbstractExplicitContainerLayer
-    import SciMLBase: AbstractNonlinearAlgorithm,
-        AbstractODEAlgorithm, _unwrap_val, NonlinearSolution
-    import TruncatedStacktraces: @truncate_stacktrace
+    using ADTypes: AutoFiniteDiff
+    using ChainRulesCore: ChainRulesCore
+    using CommonSolve: solve
+    using ConcreteStructs: @concrete
+    using ConstructionBase: ConstructionBase
+    using DiffEqBase: DiffEqBase, AbsNormTerminationMode
+    using FastClosures: @closure
+    using Lux: Lux, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer,
+               StatefulLuxLayer, WrappedFunction
+    using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer
+    using Random: Random, AbstractRNG, randn!
+    using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractODEAlgorithm,
+                     NonlinearSolution, ODESolution, ODEFunction, ODEProblem,
+                     SteadyStateProblem, _unwrap_val
+    using SteadyStateDiffEq: DynamicSS, SSRootfind
 end
 
 # Useful Constants
+const CRC = ChainRulesCore
 const DEQs = DeepEquilibriumNetworks
 
 include("layers.jl")
diff --git a/src/layers.jl b/src/layers.jl
index 935a3fa2..995f94db 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -22,13 +22,13 @@ struct DeepEquilibriumSolution # This is intentionally left untyped to allow up
     original
 end
 
-function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star, u0, residual, jacobian_loss,
-        nfe, original)
+function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star,
+        u0, residual, jacobian_loss, nfe, original)
     sol = DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original)
     ∇DeepEquilibriumSolution(::CRC.NoTangent) = ntuple(_ -> CRC.NoTangent(), 7)
     function ∇DeepEquilibriumSolution(∂sol)
-        return (CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual, ∂sol.jacobian_loss,
-            ∂sol.nfe, CRC.NoTangent())
+        return (CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual,
+            ∂sol.jacobian_loss, ∂sol.nfe, CRC.NoTangent())
     end
     return sol, ∇DeepEquilibriumSolution
 end
@@ -39,10 +39,14 @@ end
 
 function Base.show(io::IO, sol::DeepEquilibriumSolution)
     println(io, "DeepEquilibriumSolution")
-    println(io, " * Initial Guess: ", sol.u0)
-    println(io, " * Steady State: ", sol.z_star)
-    println(io, " * Residual: ", sol.residual)
-    println(io, " * Jacobian Loss: ", sol.jacobian_loss)
+    println(io, " * Initial Guess: ",
+        sprint(print, sol.u0; context=(:compact => true, :limit => true)))
+    println(io, " * Steady State: ",
+        sprint(print, sol.z_star; context=(:compact => true, :limit => true)))
+    println(io, " * Residual: ",
+        sprint(print, sol.residual; context=(:compact => true, :limit => true)))
+    println(io, " * Jacobian Loss: ",
+        sprint(print, sol.jacobian_loss; context=(:compact => true, :limit => true)))
     print(io, " * NFE: ", sol.nfe)
 end
 
@@ -56,11 +60,9 @@ end
     kwargs
 end
 
-@truncate_stacktrace DeepEquilibriumNetwork 3 2
-
 const DEQ = DeepEquilibriumNetwork
 
-constructorof(::Type{<:DEQ{pType}}) where {pType} = DEQ{pType}
+ConstructionBase.constructorof(::Type{<:DEQ{pType}}) where {pType} = DEQ{pType}
 
 function Lux.initialstates(rng::AbstractRNG, deq::DEQ)
     rng = Lux.replicate(rng)
@@ -76,17 +78,17 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true})
     z, st = __get_initial_condition(deq, x, ps, st)
 
     repeated_model = RepeatedLayer(deq.model; repeats=st.fixed_depth)
 
-    zˢᵗᵃʳ, st_ = repeated_model((z, x), ps.model, st.model)
-    model = Lux.Experimental.StatefulLuxLayer(deq.model, nothing, st_)
-    resid = CRC.ignore_derivatives(zˢᵗᵃʳ .- model((zˢᵗᵃʳ, x), ps.model))
+    z_star, st_ = repeated_model((z, x), ps.model, st.model)
+    model = StatefulLuxLayer(deq.model, ps.model, st_)
+    resid = CRC.ignore_derivatives(z_star .- model((z_star, x)))
 
     rng = Lux.replicate(st.rng)
-    jac_loss = __estimate_jacobian_trace(__getproperty(deq, Val(:jacobian_regularization)),
-        model, ps.model, zˢᵗᵃʳ, x, rng)
+    jac_loss = __estimate_jacobian_trace(
+        __getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)
 
-    solution = DeepEquilibriumSolution(zˢᵗᵃʳ, z, resid, zero(eltype(x)),
-        _unwrap_val(st.fixed_depth), jac_loss)
-    res = __split_and_reshape(zˢᵗᵃʳ, __getproperty(deq.model, Val(:split_idxs)),
+    solution = DeepEquilibriumSolution(
+        z_star, z, resid, zero(eltype(x)), _unwrap_val(st.fixed_depth), jac_loss)
+    res = __split_and_reshape(z_star, __getproperty(deq.model, Val(:split_idxs)),
         __getproperty(deq.model, Val(:scales)))
 
     return res, (; st..., model=model.st, solution, rng)
@@ -95,7 +97,7 @@ end
 function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
     z, st = __get_initial_condition(deq, x, ps, st)
 
-    model = Lux.Experimental.StatefulLuxLayer(deq.model, nothing, st.model)
+    model = StatefulLuxLayer(deq.model, ps.model, st.model)
 
     dudt = @closure (u, p, t) -> begin
         # The type-assert is needed because of an upstream Lux issue with type stability of
@@ -106,17 +108,18 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
     prob = __construct_prob(pType, ODEFunction{false}(dudt), z, (; ps=ps.model, x))
     alg = __normalize_alg(deq)
-    sol = solve(prob, alg; sensealg=__default_sensealg(prob), abstol=1e-3, reltol=1e-3,
-        termination_condition=AbsNormTerminationMode(), maxiters=32, deq.kwargs...)
-    zˢᵗᵃʳ = __get_steady_state(sol)
+    termination_condition = AbsNormTerminationMode(Base.Fix1(maximum, abs))
+    sol = solve(prob, alg; sensealg=__default_sensealg(prob), abstol=1e-3,
+        reltol=1e-3, termination_condition, maxiters=32, deq.kwargs...)
+    z_star = __get_steady_state(sol)
 
     rng = Lux.replicate(st.rng)
-    jac_loss = __estimate_jacobian_trace(__getproperty(deq, Val(:jacobian_regularization)),
-        model, ps.model, zˢᵗᵃʳ, x, rng)
+    jac_loss = __estimate_jacobian_trace(
+        __getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)
 
-    solution = DeepEquilibriumSolution(zˢᵗᵃʳ, z, __getproperty(sol, Val(:resid)), jac_loss,
-        __get_nfe(sol), sol)
-    res = __split_and_reshape(zˢᵗᵃʳ, __getproperty(deq.model, Val(:split_idxs)),
+    solution = DeepEquilibriumSolution(
+        z_star, z, __getproperty(sol, Val(:resid)), jac_loss, __get_nfe(sol), sol)
+    res = __split_and_reshape(z_star, __getproperty(deq.model, Val(:split_idxs)),
         __getproperty(deq.model, Val(:scales)))
 
     return res, (; st..., model=model.st, solution, rng)
@@ -153,8 +156,7 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
 julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
 
 julia> model = DeepEquilibriumNetwork(
-           Parallel(+, Dense(2, 2; use_bias=false),
-               Dense(2, 2; use_bias=false)),
+           Parallel(+, Dense(2, 2; use_bias=false), Dense(2, 2; use_bias=false)),
            VCABM3(); verbose=false)
 DeepEquilibriumNetwork(
     model = Parallel(
@@ -178,8 +180,8 @@ julia> model(ones(Float32, 2, 1), ps, st);
 See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref),
 [`MultiScaleSkipDeepEquilibriumNetwork`](@ref).
 """
-function DeepEquilibriumNetwork(model, solver; init=missing,
-        jacobian_regularization=nothing,
+function DeepEquilibriumNetwork(
+        model, solver; init=missing, jacobian_regularization=nothing,
         problem_type::Type{pType}=SteadyStateProblem{false}, kwargs...) where {pType}
     model isa AbstractExplicitLayer || (model = Lux.transform(model))
 
@@ -190,8 +192,8 @@ function DeepEquilibriumNetwork(model, solver; init=missing,
     elseif !(init isa AbstractExplicitLayer)
         init = Lux.transform(init)
     end
-    return DeepEquilibriumNetwork{pType}(init, model, solver, jacobian_regularization,
-        kwargs)
+    return DeepEquilibriumNetwork{pType}(
+        init, model, solver, jacobian_regularization, kwargs)
 end
 
"""
@@ -236,10 +238,8 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref).
 julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve
 
 julia> main_layers = (
-           Parallel(+, Dense(4 => 4, tanh; use_bias=false),
-               Dense(4 => 4, tanh; use_bias=false)),
-           Dense(3 => 3, tanh), Dense(2 => 2, tanh),
-           Dense(1 => 1, tanh))
+           Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)),
+           Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh))
 (Parallel(), Dense(3 => 3, tanh_fast), Dense(2 => 2, tanh_fast), Dense(1 => 1, tanh_fast))
 
 julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
@@ -252,8 +252,8 @@ julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Den
  Dense(2 => 4, tanh_fast)  Dense(2 => 1, tanh_fast)
  Dense(1 => 4, tanh_fast)  NoOpLayer()
 
-julia> model = MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, nothing,
-           NewtonRaphson(), ((4,), (3,), (2,), (1,)))
+julia> model = MultiScaleDeepEquilibriumNetwork(
+           main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)))
 DeepEquilibriumNetwork(
     model = MultiScaleInputLayer{scales = 4}(
         model = Chain(
@@ -315,9 +315,7 @@ julia> model(x, ps, st);
 ```
 """
 function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
-        post_fuse_layer::Union{Nothing, Tuple}, solver, scales;
-        jacobian_regularization=nothing, kwargs...)
-    @assert jacobian_regularization===nothing "Jacobian Regularization is not supported yet for MultiScale Models."
+        post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...)
     l1 = Parallel(nothing, main_layers...)
     l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
 
@@ -327,8 +325,8 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Ma
     if post_fuse_layer === nothing
         model = MultiScaleInputLayer(Chain(l1, l2), split_idxs, scales)
     else
-        model = MultiScaleInputLayer(Chain(l1, l2, Parallel(nothing, post_fuse_layer...)),
-            split_idxs, scales)
+        model = MultiScaleInputLayer(
+            Chain(l1, l2, Parallel(nothing, post_fuse_layer...)), split_idxs, scales)
     end
 
     return DeepEquilibriumNetwork(model, solver; kwargs...)
@@ -347,14 +345,14 @@ If `init` is not passed, it creates a MultiScale Regularized Deep Equilibrium Ne
 function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
         post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver, scales; kwargs...)
     init = Chain(Parallel(nothing, init...), __flatten_vcat)
-    return MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, post_fuse_layer,
-        solver, scales; init, kwargs...)
+    return MultiScaleDeepEquilibriumNetwork(
+        main_layers, mapping_layers, post_fuse_layer, solver, scales; init, kwargs...)
 end
 
 function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
         post_fuse_layer::Union{Nothing, Tuple}, args...; kwargs...)
-    return MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, post_fuse_layer,
-        args...; init=nothing, kwargs...)
+    return MultiScaleDeepEquilibriumNetwork(
+        main_layers, mapping_layers, post_fuse_layer, args...; init=nothing, kwargs...)
 end
 
 """
@@ -364,13 +362,13 @@ Same arguments as [`MultiScaleDeepEquilibriumNetwork`](@ref) but sets `problem_t
 `ODEProblem{false}`.
 """
 function MultiScaleNeuralODE(args...; kwargs...)
-    return MultiScaleDeepEquilibriumNetwork(args...; kwargs...,
-        problem_type=ODEProblem{false})
+    return MultiScaleDeepEquilibriumNetwork(
+        args...; kwargs..., problem_type=ODEProblem{false})
 end
 
 ## Generate Initial Condition
-@inline function __get_initial_condition(deq::DEQ{pType, NoOpLayer}, x, ps,
-        st) where {pType}
+@inline function __get_initial_condition(
+        deq::DEQ{pType, NoOpLayer}, x, ps, st) where {pType}
     zₓ = __zeros_init(__getproperty(deq.model, Val(:scales)), x)
     z, st_ = deq.model((zₓ, x), ps.model, st.model)
     return z, (; st..., model=st_)
@@ -389,11 +387,11 @@ end
     scales
 end
 
-constructorof(::Type{<:MultiScaleInputLayer{N}}) where {N} = MultiScaleInputLayer{N}
+function ConstructionBase.constructorof(::Type{<:MultiScaleInputLayer{N}}) where {N}
+    return MultiScaleInputLayer{N}
+end
 
 Lux.display_name(::MultiScaleInputLayer{N}) where {N} = "MultiScaleInputLayer{scales = $N}"
 
-@truncate_stacktrace MultiScaleInputLayer 1 2
-
 function MultiScaleInputLayer(model, split_idxs, scales::Val{S}) where {S}
     return MultiScaleInputLayer{length(S)}(model, split_idxs, scales)
 end
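`constructorof` is now overloaded by qualified name because `ConstructionBase` is no longer `import`ed. The overload itself is needed because `pType` (and `N` above) are free type parameters that cannot be recovered from field values. A minimal illustration of the pattern, using a hypothetical struct:

```julia
using ConstructionBase

struct Wrapper{P}  # hypothetical stand-in for DEQ{pType}
    model
end
ConstructionBase.constructorof(::Type{<:Wrapper{P}}) where {P} = Wrapper{P}

w = Wrapper{:steady_state}(:dense)
# setproperties-style reconstruction now preserves the type parameter:
w2 = ConstructionBase.setproperties(w, (; model=:conv))  # still a Wrapper{:steady_state}
```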
diff --git a/src/utils.jl b/src/utils.jl
index 8de5a5d0..647636dc 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,8 +1,8 @@
-@generated function __split_and_reshape(x::AbstractMatrix, ::Val{idxs},
-        ::Val{shapes}) where {idxs, shapes}
+@generated function __split_and_reshape(
+        x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {idxs, shapes}
     dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)]
     varnames = map(_ -> gensym("x_view"), dims)
-    calls = [:($(varnames[i]) = x[$(dims[i]), :]) for i in 1:length(dims)]
+    calls = [:($(varnames[i]) = x[$(dims[i]), :]) for i in eachindex(dims)]
     return quote
         $(calls...)
         return tuple($(varnames...))
@@ -28,7 +28,7 @@ end
 function CRC.rrule(::typeof(__flatten_vcat), x)
     y = __flatten_vcat(x)
     project_x = CRC.ProjectTo(x)
-    function ∇__flatten_vcat(∂y)
+    ∇__flatten_vcat = @closure ∂y -> begin
         ∂y isa CRC.NoTangent && return (CRC.NoTangent(), CRC.NoTangent())
         return CRC.NoTangent(), project_x(__split_and_reshape(∂y, x))
     end
@@ -52,7 +52,8 @@ end
 @inline __get_nfe(sol::ODESolution) = __get_nfe(sol.stats)
 @inline function __get_nfe(sol::NonlinearSolution)
     return ifelse(sol.stats === nothing,
-        ifelse(sol.original === nothing, -1, __get_nfe(sol.original)), __get_nfe(sol.stats))
+        ifelse(sol.original === nothing, -1, __get_nfe(sol.original)),
+        __get_nfe(sol.stats))
 end
 @inline __get_nfe(stats) = -1
 @inline __get_nfe(stats::Union{SciMLBase.NLStats, SciMLBase.DEStats}) = stats.nf
@@ -86,8 +87,8 @@ CRC.@non_differentiable __zeros_init(::Any, ::Any)
 ## Don't rely on SciMLSensitivity's choice
 @inline __default_sensealg(prob) = nothing
 
-@inline function __gaussian_like(rng::AbstractRNG, x)
-    y = similar(x)
+@inline function __gaussian_like(rng::AbstractRNG, x::AbstractArray)
+    y = similar(x)::typeof(x)
     randn!(rng, y)
     return y
 end
@@ -95,8 +96,9 @@ end
 CRC.@non_differentiable __gaussian_like(::Any...)
 
 # Jacobian Stabilization
-function __estimate_jacobian_trace(::AutoFiniteDiff, model, ps, z, x, rng)
-    __f = u -> model((u, x), ps)
+## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
+function __estimate_jacobian_trace(ad::AutoFiniteDiff, model, z, x, rng)
+    __f = @closure u -> model((u, x))
     res = zero(eltype(x))
     ϵ = cbrt(eps(typeof(res)))
     ϵ⁻¹ = inv(ϵ)
@@ -117,4 +119,4 @@ function __estimate_jacobian_trace(::AutoFiniteDiff, model, ps, z, x, rng)
     return res
 end
 
-__estimate_jacobian_trace(::Nothing, model, ps, z, x, rng) = zero(eltype(x))
+__estimate_jacobian_trace(::Nothing, model, z, x, rng) = zero(eltype(x))
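`__split_and_reshape` unrolls, at compile time, the slicing of a stacked `(sum(scales) × batch)` matrix back into per-scale blocks. The same computation written out dynamically for one concrete case (plain Base Julia, for illustration only):

```julia
# idxs are cumulative offsets for the scales (4,), (3,), (2,), (1,).
x = reshape(collect(1.0:20.0), 10, 2)
idxs = (0, 4, 7, 9, 10)
shapes = ((4,), (3,), (2,), (1,))

blocks = map(1:(length(idxs) - 1)) do i
    reshape(x[(idxs[i] + 1):idxs[i + 1], :], shapes[i]..., :)
end

size.(blocks)  # [(4, 2), (3, 2), (2, 2), (1, 2)]
```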
diff --git a/test/Project.toml b/test/Project.toml
deleted file mode 100644
index 04cd6405..00000000
--- a/test/Project.toml
+++ /dev/null
@@ -1,26 +0,0 @@
-[deps]
-ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
-Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
-ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
-Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
-LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
-LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
-NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
-NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
-OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
-SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
-SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
-StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-
-[compat]
-Aqua = "0.8"
diff --git a/test/layers.jl b/test/layers.jl
deleted file mode 100644
index 24dcf798..00000000
--- a/test/layers.jl
+++ /dev/null
@@ -1,180 +0,0 @@
-using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq,
-    SciMLSensitivity, SciMLBase, Test
-
-include("test_utils.jl")
-
-function loss_function(model, x, ps, st)
-    y, st = model(x, ps, st)
-    l1 = y isa Tuple ? sum(Base.Fix1(sum, abs2), y) : sum(abs2, y)
-    l2 = st.solution.jacobian_loss
-    l3 = sum(abs2, st.solution.z_star .- st.solution.u0)
-    return l1 + l2 + l3
-end
-
-@testset "DeepEquilibriumNetwork: $(mode)" for (mode, aType, dev, ongpu) in MODES
-    rng = __get_prng(0)
-
-    base_models = [
-        Parallel(+, __get_dense_layer(2 => 2), __get_dense_layer(2 => 2)),
-        Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1))
-    ]
-    init_models = [__get_dense_layer(2 => 2), __get_conv_layer((1, 1), 1 => 1)]
-    x_sizes = [(2, 14), (3, 3, 1, 3)]
-
-    model_type = (:deq, :skipdeq, :skipregdeq)
-    solvers = (VCAB3(), Tsit5(),
-        NewtonRaphson(; autodiff=AutoForwardDiff(; chunksize=12)),
-        SimpleLimitedMemoryBroyden())
-    jacobian_regularizations = Any[nothing, AutoZygote()]
-    !ongpu && push!(jacobian_regularizations, AutoFiniteDiff())
-
-    @testset "Solver: $(__nameof(solver))" for solver in solvers,
-        mtype in model_type, jacobian_regularization in jacobian_regularizations
-
-        @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(
-            base_models,
-            init_models, x_sizes)
-            model = if mtype === :deq
-                DeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
-            elseif mtype === :skipdeq
-                SkipDeepEquilibriumNetwork(base_model, init_model, solver;
-                    jacobian_regularization)
-            elseif mtype === :skipregdeq
-                SkipDeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
-            end
-
-            ps, st = Lux.setup(rng, model) |> dev
-            @test st.solution == DeepEquilibriumSolution()
-
-            x = randn(rng, Float32, x_size...) |> dev
-            z, st = model(x, ps, st)
-
-            opt_broken = solver isa SimpleLimitedMemoryBroyden
-            @jet model(x, ps, st) opt_broken=opt_broken
-
-            @test all(isfinite, z)
-            @test size(z) == size(x)
-            @test st.solution isa DeepEquilibriumSolution
-            @test maximum(abs, st.solution.residual) ≤ 1e-3
-
-            _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
-
-            @test __is_finite_gradient(gs_x)
-            @test __is_finite_gradient(gs_ps)
-
-            ps, st = Lux.setup(rng, model) |> dev
-            st = Lux.update_state(st, :fixed_depth, Val(10))
-            @test st.solution == DeepEquilibriumSolution()
-
-            z, st = model(x, ps, st)
-            @jet model(x, ps, st)
-
-            @test all(isfinite, z)
-            @test size(z) == size(x)
-            @test st.solution isa DeepEquilibriumSolution
-            @test st.solution.nfe == 10
-
-            _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
-
-            @test __is_finite_gradient(gs_x)
-            @test __is_finite_gradient(gs_ps)
-        end
-    end
-end
-
-@testset "MultiScaleDeepEquilibriumNetwork: $(mode)" for (mode, aType, dev, ongpu) in MODES
-    rng = __get_prng(0)
-
-    main_layers = [
-        (Parallel(+, __get_dense_layer(4 => 4), __get_dense_layer(4 => 4)),
-            __get_dense_layer(3 => 3), __get_dense_layer(2 => 2),
-            __get_dense_layer(1 => 1))
-    ]
-
-    mapping_layers = [
-        [NoOpLayer() __get_dense_layer(4 => 3) __get_dense_layer(4 => 2) __get_dense_layer(4 => 1);
-         __get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1);
-         __get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1);
-         __get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()]
-    ]
-
-    init_layers = [
-        (__get_dense_layer(4 => 4), __get_dense_layer(4 => 3), __get_dense_layer(4 => 2),
-            __get_dense_layer(4 => 1))
-    ]
-
-    x_sizes = [(4, 3)]
-    scales = [((4,), (3,), (2,), (1,))]
-
-    model_type = (:deq, :skipdeq, :skipregdeq, :node)
-    solvers = (VCAB3(), Tsit5(),
-        NewtonRaphson(; autodiff=AutoForwardDiff(; chunksize=12)),
-        SimpleLimitedMemoryBroyden())
-    jacobian_regularizations = (nothing,)
-
-    for mtype in model_type, jacobian_regularization in jacobian_regularizations
-        @testset "Solver: $(__nameof(solver))" for solver in solvers
-            @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(
-                main_layers,
-                mapping_layers, init_layers, x_sizes, scales)
-                model = if mtype === :deq
-                    MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
-                        solver, scale; jacobian_regularization)
-                elseif mtype === :skipdeq
-                    MultiScaleSkipDeepEquilibriumNetwork(
-                        main_layer, mapping_layer, nothing,
-                        init_layer, solver, scale; jacobian_regularization)
-                elseif mtype === :skipregdeq
-                    MultiScaleSkipDeepEquilibriumNetwork(
-                        main_layer, mapping_layer, nothing,
-                        solver, scale; jacobian_regularization)
-                elseif mtype === :node
-                    solver isa SciMLBase.AbstractODEAlgorithm || continue
-                    MultiScaleNeuralODE(main_layer, mapping_layer, nothing, solver, scale;
-                        jacobian_regularization)
-                end
-
-                ps, st = Lux.setup(rng, model) |> dev
-                @test st.solution == DeepEquilibriumSolution()
-
-                x = randn(rng, Float32, x_size...) |> dev
-                z, st = model(x, ps, st)
-                z_ = DEQs.__flatten_vcat(z)
-
-                opt_broken = solver isa SimpleLimitedMemoryBroyden
-                @jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch
-
-                @test all(isfinite, z_)
-                @test size(z_) == (sum(prod, scale), size(x, ndims(x)))
-                @test st.solution isa DeepEquilibriumSolution
-                if st.solution.residual !== nothing
-                    @test maximum(abs, st.solution.residual) ≤ 1e-3
-                end
-
-                _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
-
-                @test __is_finite_gradient(gs_x)
-                @test __is_finite_gradient(gs_ps)
-
-                ps, st = Lux.setup(rng, model) |> dev
-                st = Lux.update_state(st, :fixed_depth, Val(10))
-                @test st.solution == DeepEquilibriumSolution()
-
-                z, st = model(x, ps, st)
-                z_ = DEQs.__flatten_vcat(z)
-                opt_broken = jacobian_regularization isa AutoZygote
-                @jet model(x, ps, st) opt_broken=opt_broken
-
-                @test all(isfinite, z_)
-                @test size(z_) == (sum(prod, scale), size(x, ndims(x)))
-                @test st.solution isa DeepEquilibriumSolution
-                @test st.solution.nfe == 10
-
-                _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
-
-                @test __is_finite_gradient(gs_x)
-                @test __is_finite_gradient(gs_ps)
-            end
-        end
-    end
-end
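The monolithic `@testset` file deleted above is replaced by the `@testitem`-based file that follows, which ReTestItems can schedule across workers and run selectively. For instance (the `name` keyword is assumed from the ReTestItems API):

```julia
using ReTestItems

# Run only the multiscale tests defined in the file below.
ReTestItems.runtests("test"; name="Multiscale DEQ")
```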
|> dev + z, st = model(x, ps, st) + + opt_broken = solver isa SimpleLimitedMemoryBroyden + @jet model(x, ps, st) opt_broken=opt_broken + + @test all(isfinite, z) + @test size(z) == size(x) + @test st.solution isa DeepEquilibriumSolution + @test maximum(abs, st.solution.residual) ≤ 1e-3 + + _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) + + @test __is_finite_gradient(gs_x) + @test __is_finite_gradient(gs_ps) + + ps, st = Lux.setup(rng, model) |> dev + st = Lux.update_state(st, :fixed_depth, Val(10)) + @test st.solution == DeepEquilibriumSolution() + + z, st = model(x, ps, st) + @jet model(x, ps, st) + + @test all(isfinite, z) + @test size(z) == size(x) + @test st.solution isa DeepEquilibriumSolution + @test st.solution.nfe == 10 + + _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st) + + @test __is_finite_gradient(gs_x) + @test __is_finite_gradient(gs_ps) + end + end + end +end + +@testitem "Multiscale DEQ" setup=[SharedTestSetup, LayersTestSetup] begin + using ADTypes, Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity, Zygote + + rng = __get_prng(0) + + main_layers = [(Parallel(+, __get_dense_layer(4 => 4), __get_dense_layer(4 => 4)), + __get_dense_layer(3 => 3), __get_dense_layer(2 => 2), __get_dense_layer(1 => 1))] + + mapping_layers = [[NoOpLayer() __get_dense_layer(4 => 3) __get_dense_layer(4 => 2) __get_dense_layer(4 => 1); + __get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1); + __get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1); + __get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()]] + + init_layers = [(__get_dense_layer(4 => 4), __get_dense_layer(4 => 3), + __get_dense_layer(4 => 2), __get_dense_layer(4 => 1))] + + x_sizes = [(4, 3)] + scales = [((4,), (3,), (2,), (1,))] + + model_type = (:deq, :skipdeq, :skipregdeq, :node) + jacobian_regularizations = (nothing,) + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + @testset "Solver: $(__nameof(solver))" for solver in SOLVERS, + mtype in model_type, + jacobian_regularization in jacobian_regularizations + + @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip( + main_layers, mapping_layers, init_layers, x_sizes, scales) + model = if mtype === :deq + MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing, + solver, scale; jacobian_regularization) + elseif mtype === :skipdeq + MultiScaleSkipDeepEquilibriumNetwork( + main_layer, mapping_layer, nothing, init_layer, + solver, scale; jacobian_regularization) + elseif mtype === :skipregdeq + MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing, + solver, scale; jacobian_regularization) + elseif mtype === :node + solver isa SciMLBase.AbstractODEAlgorithm || continue + MultiScaleNeuralODE(main_layer, mapping_layer, nothing, + solver, scale; jacobian_regularization) + end + + ps, st = Lux.setup(rng, model) |> dev + @test st.solution == DeepEquilibriumSolution() + + x = randn(rng, Float32, x_size...) 
+                z, st = model(x, ps, st)
+                z_ = DEQs.__flatten_vcat(z)
+
+                opt_broken = solver isa SimpleLimitedMemoryBroyden
+                @jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch
+
+                @test all(isfinite, z_)
+                @test size(z_) == (sum(prod, scale), size(x, ndims(x)))
+                @test st.solution isa DeepEquilibriumSolution
+                if st.solution.residual !== nothing
+                    @test maximum(abs, st.solution.residual) ≤ 1e-3
+                end
+
+                _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+                @test __is_finite_gradient(gs_x)
+                @test __is_finite_gradient(gs_ps)
+
+                ps, st = Lux.setup(rng, model) |> dev
+                st = Lux.update_state(st, :fixed_depth, Val(10))
+                @test st.solution == DeepEquilibriumSolution()
+
+                z, st = model(x, ps, st)
+                z_ = DEQs.__flatten_vcat(z)
+                opt_broken = jacobian_regularization isa AutoZygote
+                @jet model(x, ps, st) opt_broken=opt_broken
+
+                @test all(isfinite, z_)
+                @test size(z_) == (sum(prod, scale), size(x, ndims(x)))
+                @test st.solution isa DeepEquilibriumSolution
+                @test st.solution.nfe == 10
+
+                _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+                @test __is_finite_gradient(gs_x)
+                @test __is_finite_gradient(gs_ps)
+            end
+        end
+    end
+end
diff --git a/test/qa.jl b/test/qa.jl
deleted file mode 100644
index 94eb43f6..00000000
--- a/test/qa.jl
+++ /dev/null
@@ -1,7 +0,0 @@
-using DeepEquilibriumNetworks, Aqua, Test
-import ChainRulesCore as CRC
-
-@testset "Aqua" begin
-    Aqua.test_all(DeepEquilibriumNetworks; ambiguities=false)
-    Aqua.test_ambiguities(DeepEquilibriumNetworks; recursive=false)
-end
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
new file mode 100644
index 00000000..2dd1d11e
--- /dev/null
+++ b/test/qa_tests.jl
@@ -0,0 +1,17 @@
+@testitem "Aqua" begin
+    using Aqua
+
+    Aqua.test_all(DeepEquilibriumNetworks; ambiguities=false)
+    Aqua.test_ambiguities(DeepEquilibriumNetworks; recursive=false)
+end
+
+@testitem "ExplicitImports" begin
+    import SciMLSensitivity, Zygote
+
+    using ExplicitImports
+
+    # Skip our own packages
+    @test check_no_implicit_imports(DeepEquilibriumNetworks) === nothing
+    ## AbstractRNG seems to be a spurious detection in LuxFluxExt
+    @test check_no_stale_explicit_imports(DeepEquilibriumNetworks) === nothing
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 045828fa..8ba7978a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,7 +1,3 @@
-using SafeTestsets, Test, TestSetExtensions
+using ReTestItems
 
-@testset ExtendedTestSet "Deep Equilibrium Networks" begin
-    @safetestset "Quality Assurance" include("qa.jl")
-    @safetestset "Utilities" include("utils.jl")
-    @safetestset "Layers" include("layers.jl")
-end
+ReTestItems.runtests(@__DIR__)
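# A minimal usage sketch for the new entry point (not part of the commit;
# `name` is a keyword of the ReTestItems API, and the paths assume a local
# checkout): ReTestItems discovers the `*_tests.jl` files added in this diff,
# and a single item can be run by name from the repo root, e.g.
#
#     using ReTestItems
#     runtests("test"; name="DEQ")  # run only the "DEQ" @testitem
#
# `@testsetup` modules such as SharedTestSetup are attached per item through
# the `setup=[...]` keyword visible in the new test files.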
diff --git a/test/test_utils.jl b/test/shared_testsetup.jl
similarity index 73%
rename from test/test_utils.jl
rename to test/shared_testsetup.jl
index b9268716..b22de31b 100644
--- a/test/test_utils.jl
+++ b/test/shared_testsetup.jl
@@ -1,3 +1,5 @@
+@testsetup module SharedTestSetup
+
 using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote
 import LuxTestUtils: @jet
 using LuxCUDA
@@ -35,15 +37,19 @@ const GROUP = get(ENV, "GROUP", "All")
 cpu_testing() = GROUP == "All" || GROUP == "CPU"
 cuda_testing() = LuxCUDA.functional() && (GROUP == "All" || GROUP == "CUDA")
 
-if !@isdefined(MODES)
-    const MODES = begin
-        cpu_mode = ("CPU", Array, LuxCPUDevice(), false)
-        cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true)
+const MODES = begin
+    cpu_mode = ("CPU", Array, LuxCPUDevice(), false)
+    cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true)
 
-        modes = []
-        cpu_testing() && push!(modes, cpu_mode)
-        cuda_testing() && push!(modes, cuda_mode)
+    modes = []
+    cpu_testing() && push!(modes, cpu_mode)
+    cuda_testing() && push!(modes, cuda_mode)
+
+    modes
+end
+
+export Lux, LuxCore, LuxLib
+export MODES, __get_dense_layer, __get_conv_layer, __is_finite_gradient, __get_prng,
+    __nameof, @jet
 
-        modes
-    end
 end
diff --git a/test/utils.jl b/test/utils.jl
deleted file mode 100644
index 2c0057a3..00000000
--- a/test/utils.jl
+++ /dev/null
@@ -1,38 +0,0 @@
-using DeepEquilibriumNetworks, LinearAlgebra, SciMLBase, Test
-
-include("test_utils.jl")
-
-@testset "split_and_reshape: $mode" for (mode, aType, dev, ongpu) in MODES
-    x1 = ones(Float32, 4, 4) |> aType
-    x2 = fill(0.5f0, 2, 4) |> aType
-    x3 = zeros(Float32, 1, 4) |> aType
-
-    x = vcat(x1, x2, x3)
-    split_idxs = Val(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1))))
-    shapes = Val((size(x1, 1), size(x2, 1), size(x3, 1)))
-    x_split = DEQs.__split_and_reshape(x, split_idxs, shapes)
-
-    @test x1 == x_split[1]
-    @test x2 == x_split[2]
-    @test x3 == x_split[3]
-
-    @jet DEQs.__split_and_reshape(x, split_idxs, shapes)
-end
-
-@testset "unrolled_mode check" begin
-    @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(10)))
-    @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(0)))
-    @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(10))))
-    @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(0))))
-end
-
-@testset "get unrolled_mode" begin
-    @test DEQs.__get_unrolled_depth(Val(10)) == 10
-    @test DEQs.__get_unrolled_depth((; fixed_depth=Val(10))) == 10
-end
-
-@testset "deep equilibrium solution" begin
-    sol = @test_nowarn DeepEquilibriumSolution(randn(10), randn(10), randn(10), 0.4, 10,
-        nothing)
-    @test_nowarn println(sol)
-end
diff --git a/test/utils_tests.jl b/test/utils_tests.jl
new file mode 100644
index 00000000..2d114a79
--- /dev/null
+++ b/test/utils_tests.jl
@@ -0,0 +1,38 @@
+@testitem "split_and_reshape" setup=[SharedTestSetup] begin
+    for (mode, aType, dev, ongpu) in MODES
+        x1 = ones(Float32, 4, 4) |> aType
+        x2 = fill(0.5f0, 2, 4) |> aType
+        x3 = zeros(Float32, 1, 4) |> aType
+
+        x = vcat(x1, x2, x3)
+        split_idxs = Val(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1))))
+        shapes = Val((size(x1, 1), size(x2, 1), size(x3, 1)))
+        x_split = DEQs.__split_and_reshape(x, split_idxs, shapes)
+
+        @test x1 == x_split[1]
+        @test x2 == x_split[2]
+        @test x3 == x_split[3]
+
+        @jet DEQs.__split_and_reshape(x, split_idxs, shapes)
+    end
+end
+
+@testitem "unrolled_mode check" setup=[SharedTestSetup] begin
+    using SciMLBase
+
+    @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(10)))
+    @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(0)))
+    @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(10))))
+    @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(0))))
+end
+
+@testitem "get unrolled_mode" setup=[SharedTestSetup] begin
+    @test DEQs.__get_unrolled_depth(Val(10)) == 10
+    @test DEQs.__get_unrolled_depth((; fixed_depth=Val(10))) == 10
+end
+
+@testitem "deep equilibrium solution" setup=[SharedTestSetup] begin
+    sol = @test_nowarn DeepEquilibriumSolution(
+        randn(10), randn(10), randn(10), 0.4, 10, nothing)
+    @test_nowarn println(sol)
+end
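# For reference (a note inferred from the calls in these tests, not from the
# package source): the six positional arguments to DeepEquilibriumSolution
# appear to correspond to (z_star, u0, residual, jacobian_loss, nfe, original),
# matching the `st.solution` fields read throughout the layer tests above; the
# last field's name is an assumption.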