From 556aa656573ff2a1b3ef556ced6dc76d47dce638 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 16:51:41 +0530 Subject: [PATCH] test: add tests for ReverseDiff dual detection and promotion --- Project.toml | 2 +- test/forwarddiff_dual_detection.jl | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 47372939d..76b9359fc 100644 --- a/Project.toml +++ b/Project.toml @@ -133,4 +133,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"] +test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"] diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index 8f269c478..1d2ae4d4d 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -1,4 +1,5 @@ using DiffEqBase, ForwardDiff, Test, InteractiveUtils +using ReverseDiff, SciMLStructures using Plots u0 = 2.0 @@ -348,3 +349,20 @@ foo = SciMLBase.build_solution( prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0]) DiffEqBase.anyeltypedual((; x = foo)) DiffEqBase.anyeltypedual((; x = foo, y = prob.f)) + +@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(3))) == Any +@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(3)))) == Any +@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) == eltype(ones(ForwardDiff.Dual, 3)) +@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) == eltype(ones(ForwardDiff.Dual, 3)) + +struct Foo{T} + tunables::T +end + +SciMLStructures.isscimlstructure(::Foo) = true +SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::Foo) = f.tunables, x -> Foo(x), true + +@test DiffEqBase.promote_u0(ones(3), Foo(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedArray +@test DiffEqBase.promote_u0(1.0, Foo(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedReal +@test DiffEqBase.promote_u0(ones(3), Foo(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedArray{<:ForwardDiff.Dual} +@test DiffEqBase.promote_u0(1.0, Foo(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedReal{<:ForwardDiff.Dual}