Skip to content

Commit

Permalink
test: add tests for ReverseDiff dual detection and promotion
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 3, 2024
1 parent eb42bc2 commit 556aa65
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
18 changes: 18 additions & 0 deletions test/forwarddiff_dual_detection.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using DiffEqBase, ForwardDiff, Test, InteractiveUtils
using ReverseDiff, SciMLStructures
using Plots

u0 = 2.0
Expand Down Expand Up @@ -348,3 +349,20 @@ foo = SciMLBase.build_solution(
prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0])
DiffEqBase.anyeltypedual((; x = foo))
DiffEqBase.anyeltypedual((; x = foo, y = prob.f))

@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(3))) == Any
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(3)))) == Any
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) == eltype(ones(ForwardDiff.Dual, 3))
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) == eltype(ones(ForwardDiff.Dual, 3))

struct Foo{T}
tunables::T
end

SciMLStructures.isscimlstructure(::Foo) = true
SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::Foo) = f.tunables, x -> Foo(x), true

@test DiffEqBase.promote_u0(ones(3), Foo(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedArray
@test DiffEqBase.promote_u0(1.0, Foo(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedReal
@test DiffEqBase.promote_u0(ones(3), Foo(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedArray{<:ForwardDiff.Dual}
@test DiffEqBase.promote_u0(1.0, Foo(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedReal{<:ForwardDiff.Dual}

0 comments on commit 556aa65

Please sign in to comment.