From 8021ff1f02e4ef2767dcdee3a07ec721dc4bb9b8 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 13:22:35 -0500 Subject: [PATCH 01/12] Add ReverseDiff weak dep with backwards compatibility. Comes from SciMLSensitivity.jl back compat split. --- Project.toml | 6 ++++++ ext/RecursiveArrayToolsReverseDiffExt.jl | 15 +++++++++++++++ src/RecursiveArrayTools.jl | 7 +++++++ 3 files changed, 28 insertions(+) create mode 100644 ext/RecursiveArrayToolsReverseDiffExt.jl diff --git a/Project.toml b/Project.toml index e182604f..b906d24a 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,12 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" +[weakdeps] +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + +[extensions] +RecursiveArrayToolsReverseDiffExt = "ReverseDiff" + [compat] Adapt = "3" ArrayInterfaceCore = "0.1.1" diff --git a/ext/RecursiveArrayToolsReverseDiffExt.jl b/ext/RecursiveArrayToolsReverseDiffExt.jl new file mode 100644 index 00000000..51d447de --- /dev/null +++ b/ext/RecursiveArrayToolsReverseDiffExt.jl @@ -0,0 +1,15 @@ +module RecursiveArrayToolsReverseDiffExt + +function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N}, + a::AbstractArray{T2, N}) where { + T <: + Tracker.TrackedArray, + T2 <: + Tracker.TrackedArray, + N} + @inbounds for i in eachindex(a) + b[i] = copy(a[i]) + end +end + +end \ No newline at end of file diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index bcb269f3..ec0127be 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -42,6 +42,13 @@ function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, T(xs), ȳ -> (NoTangent(), ȳ) end +import Requires +function __init__() + @static if !isdefined(Base, :get_extension) + Requires.@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("../ext/RecursiveArrayToolsReverseDiffExt.jl") end + end +end + export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray, AllObserved, vecarr_to_vectors, tuples From fb2cacb3743dcd68dff5e2e1cf9e549a87baf122 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 13:24:10 -0500 Subject: [PATCH 02/12] add missing dep --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index b906d24a..03e838b7 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" From 03878382f19d0fa210ce4cbcf98959a74e6f97d8 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 13:29:21 -0500 Subject: [PATCH 03/12] Tracker oops --- Project.toml | 4 ++-- ...ReverseDiffExt.jl => RecursiveArrayToolsTrackerExt.jl} | 5 ++++- src/RecursiveArrayTools.jl | 8 ++++---- 3 files changed, 10 insertions(+), 7 deletions(-) rename ext/{RecursiveArrayToolsReverseDiffExt.jl => RecursiveArrayToolsTrackerExt.jl} (89%) diff --git a/Project.toml b/Project.toml index 03e838b7..cfb0d9eb 100644 --- a/Project.toml +++ b/Project.toml @@ -22,10 +22,10 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] -RecursiveArrayToolsReverseDiffExt = "ReverseDiff" +RecursiveArrayToolsTrackerExt = "Tracker" [compat] Adapt = "3" diff --git a/ext/RecursiveArrayToolsReverseDiffExt.jl b/ext/RecursiveArrayToolsTrackerExt.jl similarity index 89% rename from ext/RecursiveArrayToolsReverseDiffExt.jl rename to ext/RecursiveArrayToolsTrackerExt.jl index 51d447de..25202d25 100644 --- a/ext/RecursiveArrayToolsReverseDiffExt.jl +++ b/ext/RecursiveArrayToolsTrackerExt.jl @@ -1,4 +1,7 @@ -module RecursiveArrayToolsReverseDiffExt +module RecursiveArrayToolsTrackerExt + +import RecursiveArrayTools +import Tracker function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T2, N}) where { diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index ec0127be..40e6737c 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -42,10 +42,10 @@ function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, T(xs), ȳ -> (NoTangent(), ȳ) end -import Requires -function __init__() - @static if !isdefined(Base, :get_extension) - Requires.@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("../ext/RecursiveArrayToolsReverseDiffExt.jl") end +@static if !isdefined(Base, :get_extension) + import Requires + function __init__() + Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end end end From 7daf6d24cbbe2f181f963a78c89c160b35c537eb Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 13:34:54 -0500 Subject: [PATCH 04/12] add a downstream test --- test/downstream/Project.toml | 4 +++- test/downstream/TrackerExt.jl | 7 +++++++ test/runtests.jl | 1 + 3 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 test/downstream/TrackerExt.jl diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 21059020..89471dae 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,7 +1,9 @@ [deps] ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] ModelingToolkit = "8.33" -OrdinaryDiffEq = "6.31" \ No newline at end of file +OrdinaryDiffEq = "6.31" +Tracker = "0.2" \ No newline at end of file diff --git a/test/downstream/TrackerExt.jl b/test/downstream/TrackerExt.jl new file mode 100644 index 00000000..275f2336 --- /dev/null +++ b/test/downstream/TrackerExt.jl @@ -0,0 +1,7 @@ +using RecursiveArrayTools, Tracker, Test + +x = [5.0] +a = [Tracker.TrackedArray(x)] +b = [Tracker.TrackedArray(copy([5.2]))] +RecursiveArrayTools.recursivecopy!(a,b) +@test a[1][1] == 5.2 \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 5d99e528..61686e7e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,6 +37,7 @@ end activate_downstream_env() @time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end @time @testset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end + @time @testset "TrackerExt" begin include("downstream/TrackerExt.jl") end end if !is_APPVEYOR && GROUP == "GPU" From 1a33a6474e1d918dcb8d9941c573efdf5755a7ee Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 13:40:53 -0500 Subject: [PATCH 05/12] add compat entry --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index cfb0d9eb..a1ac1a27 100644 --- a/Project.toml +++ b/Project.toml @@ -37,6 +37,7 @@ FillArrays = "0.11, 0.12, 0.13" GPUArraysCore = "0.1" IteratorInterfaceExtensions = "1" RecipesBase = "0.7, 0.8, 1.0" +Requires = "1.0" StaticArraysCore = "1.1" SymbolicIndexingInterface = "0.1, 0.2" Tables = "1" From dcb894a14105d9eb08e63234835ee85702215b75 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 16:40:59 -0500 Subject: [PATCH 06/12] fix weakdep --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index a1ac1a27..9990de6b 100644 --- a/Project.toml +++ b/Project.toml @@ -55,6 +55,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" From f294f7ad36bb49edf23a8baedfb69b62d3fd41fc Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 18:03:05 -0500 Subject: [PATCH 07/12] fix requires import --- src/RecursiveArrayTools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index 40e6737c..c93ca011 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -42,8 +42,8 @@ function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, T(xs), ȳ -> (NoTangent(), ȳ) end +import Requires @static if !isdefined(Base, :get_extension) - import Requires function __init__() Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end end From 5155156aef18e2547d0e43f1d65138b4adc77c75 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 18:57:18 -0500 Subject: [PATCH 08/12] appease the pkg gods --- Project.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 9990de6b..52e49d67 100644 --- a/Project.toml +++ b/Project.toml @@ -21,12 +21,6 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" -[weakdeps] -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[extensions] -RecursiveArrayToolsTrackerExt = "Tracker" - [compat] Adapt = "3" ArrayInterfaceCore = "0.1.1" @@ -44,6 +38,9 @@ Tables = "1" ZygoteRules = "0.2" julia = "1.6" +[extensions] +RecursiveArrayToolsTrackerExt = "Tracker" + [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -61,3 +58,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"] + +[weakdeps] +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" \ No newline at end of file From 3176683b884ce2f327c22b440f9658638220ccdf Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 20:40:45 -0500 Subject: [PATCH 09/12] load correctly --- ext/RecursiveArrayToolsTrackerExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsTrackerExt.jl b/ext/RecursiveArrayToolsTrackerExt.jl index 25202d25..43c7f4a7 100644 --- a/ext/RecursiveArrayToolsTrackerExt.jl +++ b/ext/RecursiveArrayToolsTrackerExt.jl @@ -1,7 +1,7 @@ module RecursiveArrayToolsTrackerExt import RecursiveArrayTools -import Tracker +isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker) function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T2, N}) where { From 9f9e88c3974f1fef3e5015412ce76ff2bdf6503f Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 6 Feb 2023 21:26:11 -0500 Subject: [PATCH 10/12] fix tables interface --- src/tabletraits.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tabletraits.jl b/src/tabletraits.jl index df4ccef7..c4952573 100644 --- a/src/tabletraits.jl +++ b/src/tabletraits.jl @@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray) N = length(A.u[1]) names = [ :timestamp, - (A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) : + (A.sc !== nothing ? (states(A.sc)[i] for i in 1:N) : (Symbol("value", i) for i in 1:N))..., ] types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...] else - names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value] + names = [:timestamp, A.sc !== nothing ? states(A.sc)[1] : :value] types = Type[eltype(A.t), VT] end return AbstractDiffEqArrayRows(names, types, A.t, A.u) @@ -31,8 +31,8 @@ struct AbstractDiffEqArrayRows{T, U} u::U end function AbstractDiffEqArrayRows(names, types, t, u) - AbstractDiffEqArrayRows(names, types, - Dict(nm => i for (i, nm) in enumerate(names)), t, u) + AbstractDiffEqArrayRows(Symbol.(names), types, + Dict(Symbol(nm) => i for (i, nm) in enumerate(names)), t, u) end Base.length(x::AbstractDiffEqArrayRows) = length(x.u) From d93a32797606af328274d885e32476d8846795f4 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 7 Feb 2023 05:51:14 -0500 Subject: [PATCH 11/12] switch to safetestset --- Project.toml | 3 ++- test/runtests.jl | 30 ++++++++++++++++-------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 52e49d67..802bdf69 100644 --- a/Project.toml +++ b/Project.toml @@ -49,6 +49,7 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -57,7 +58,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"] +test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"] [weakdeps] Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 61686e7e..591ad998 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,8 @@ using Pkg using RecursiveArrayTools using Test using Aqua +using SafeTestsets + Aqua.test_all(RecursiveArrayTools, ambiguities = false) @test_broken isempty(Test.detect_ambiguities(RecursiveArrayTools)) const GROUP = get(ENV, "GROUP", "All") @@ -21,27 +23,27 @@ end @time begin if GROUP == "Core" || GROUP == "All" - @time @testset "Utils Tests" begin include("utils_test.jl") end - @time @testset "Partitions Tests" begin include("partitions_test.jl") end - @time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end - @time @testset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end - @time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end - @time @testset "Table traits" begin include("tabletraits.jl") end - @time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end - @time @testset "Linear Algebra Tests" begin include("linalg.jl") end - @time @testset "Upstream Tests" begin include("upstream.jl") end - @time @testset "Adjoint Tests" begin include("adjoints.jl") end + @time @safetestset "Utils Tests" begin include("utils_test.jl") end + @time @safetestset "Partitions Tests" begin include("partitions_test.jl") end + @time @safetestset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end + @time @safetestset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end + @time @safetestset "VecOfArr Interface Tests" begin include("interface_tests.jl") end + @time @safetestset "Table traits" begin include("tabletraits.jl") end + @time @safetestset "StaticArrays Tests" begin include("copy_static_array_test.jl") end + @time @safetestset "Linear Algebra Tests" begin include("linalg.jl") end + @time @safetestset "Upstream Tests" begin include("upstream.jl") end + @time @safetestset "Adjoint Tests" begin include("adjoints.jl") end end if !is_APPVEYOR && GROUP == "Downstream" activate_downstream_env() - @time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end - @time @testset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end - @time @testset "TrackerExt" begin include("downstream/TrackerExt.jl") end + @time @safetestset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end + @time @safetestset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end + @time @safetestset "TrackerExt" begin include("downstream/TrackerExt.jl") end end if !is_APPVEYOR && GROUP == "GPU" activate_gpu_env() - @time @testset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end + @time @safetestset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end end end From 2746e50ac26b0522cb28b330354bc10cf3b681b8 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 7 Feb 2023 06:07:59 -0500 Subject: [PATCH 12/12] properly check null symbolcache --- src/tabletraits.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tabletraits.jl b/src/tabletraits.jl index c4952573..f8799aa5 100644 --- a/src/tabletraits.jl +++ b/src/tabletraits.jl @@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray) N = length(A.u[1]) names = [ :timestamp, - (A.sc !== nothing ? (states(A.sc)[i] for i in 1:N) : + (!(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? (states(A.sc)[i] for i in 1:N) : (Symbol("value", i) for i in 1:N))..., ] types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...] else - names = [:timestamp, A.sc !== nothing ? states(A.sc)[1] : :value] + names = [:timestamp, !(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? states(A.sc)[1] : :value] types = Type[eltype(A.t), VT] end return AbstractDiffEqArrayRows(names, types, A.t, A.u)