Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ReverseDiff weak dep #247

Merged
merged 12 commits into from
Feb 7, 2023
12 changes: 11 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Expand All @@ -30,12 +31,16 @@ FillArrays = "0.11, 0.12, 0.13"
GPUArraysCore = "0.1"
IteratorInterfaceExtensions = "1"
RecipesBase = "0.7, 0.8, 1.0"
Requires = "1.0"
StaticArraysCore = "1.1"
SymbolicIndexingInterface = "0.1, 0.2"
Tables = "1"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
RecursiveArrayToolsTrackerExt = "Tracker"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -44,11 +49,16 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"]
test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"]

[weakdeps]
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
18 changes: 18 additions & 0 deletions ext/RecursiveArrayToolsTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module RecursiveArrayToolsTrackerExt

import RecursiveArrayTools
isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker)

function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {
T <:
Tracker.TrackedArray,
T2 <:
Tracker.TrackedArray,
N}
@inbounds for i in eachindex(a)
b[i] = copy(a[i])
end
end

end
7 changes: 7 additions & 0 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray},
T(xs), ȳ -> (NoTangent(), ȳ)
end

import Requires
@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end
end
end

export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
AllObserved, vecarr_to_vectors, tuples

Expand Down
8 changes: 4 additions & 4 deletions src/tabletraits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray)
N = length(A.u[1])
names = [
:timestamp,
(A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) :
(!(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? (states(A.sc)[i] for i in 1:N) :
(Symbol("value", i) for i in 1:N))...,
]
types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...]
else
names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value]
names = [:timestamp, !(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? states(A.sc)[1] : :value]
types = Type[eltype(A.t), VT]
end
return AbstractDiffEqArrayRows(names, types, A.t, A.u)
Expand All @@ -31,8 +31,8 @@ struct AbstractDiffEqArrayRows{T, U}
u::U
end
function AbstractDiffEqArrayRows(names, types, t, u)
AbstractDiffEqArrayRows(names, types,
Dict(nm => i for (i, nm) in enumerate(names)), t, u)
AbstractDiffEqArrayRows(Symbol.(names), types,
Dict(Symbol(nm) => i for (i, nm) in enumerate(names)), t, u)
end

Base.length(x::AbstractDiffEqArrayRows) = length(x.u)
Expand Down
4 changes: 3 additions & 1 deletion test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
[deps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[compat]
ModelingToolkit = "8.33"
OrdinaryDiffEq = "6.31"
OrdinaryDiffEq = "6.31"
Tracker = "0.2"
7 changes: 7 additions & 0 deletions test/downstream/TrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using RecursiveArrayTools, Tracker, Test

x = [5.0]
a = [Tracker.TrackedArray(x)]
b = [Tracker.TrackedArray(copy([5.2]))]
RecursiveArrayTools.recursivecopy!(a,b)
@test a[1][1] == 5.2
29 changes: 16 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using Pkg
using RecursiveArrayTools
using Test
using Aqua
using SafeTestsets

Aqua.test_all(RecursiveArrayTools, ambiguities = false)
@test_broken isempty(Test.detect_ambiguities(RecursiveArrayTools))
const GROUP = get(ENV, "GROUP", "All")
Expand All @@ -21,26 +23,27 @@ end

@time begin
if GROUP == "Core" || GROUP == "All"
@time @testset "Utils Tests" begin include("utils_test.jl") end
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
@time @testset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end
@time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
@time @testset "Table traits" begin include("tabletraits.jl") end
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
@time @testset "Linear Algebra Tests" begin include("linalg.jl") end
@time @testset "Upstream Tests" begin include("upstream.jl") end
@time @testset "Adjoint Tests" begin include("adjoints.jl") end
@time @safetestset "Utils Tests" begin include("utils_test.jl") end
@time @safetestset "Partitions Tests" begin include("partitions_test.jl") end
@time @safetestset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
@time @safetestset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end
@time @safetestset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
@time @safetestset "Table traits" begin include("tabletraits.jl") end
@time @safetestset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
@time @safetestset "Linear Algebra Tests" begin include("linalg.jl") end
@time @safetestset "Upstream Tests" begin include("upstream.jl") end
@time @safetestset "Adjoint Tests" begin include("adjoints.jl") end
end

if !is_APPVEYOR && GROUP == "Downstream"
activate_downstream_env()
@time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
@time @testset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end
@time @safetestset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
@time @safetestset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end
@time @safetestset "TrackerExt" begin include("downstream/TrackerExt.jl") end
end

if !is_APPVEYOR && GROUP == "GPU"
activate_gpu_env()
@time @testset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end
@time @safetestset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end
end
end