From aa37f83b8f78fc7fa2dcaec70129c2a2c86d7cc0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 17:45:40 +0530 Subject: [PATCH] fixup! fix: fix promote_u0 when `p isa TrackedArray` and `u0` is not --- ext/DiffEqBaseReverseDiffExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index a47777681..9272bc1ec 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -4,6 +4,7 @@ using DiffEqBase import DiffEqBase: value import ReverseDiff import DiffEqBase.ArrayInterface +import DiffEqBase.ForwardDiff function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}} DiffEqBase.anyeltypedual(V, Val{counter}) @@ -36,7 +37,8 @@ function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) u0 end -DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ArrayInterface.aos_to_soa(eltype(p).(u0)) +DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) +DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual} = ReverseDiff.track(T.(u0)) DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) # Support adaptive with non-tracked time