diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index a47777681..9272bc1ec 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -4,6 +4,7 @@ using DiffEqBase import DiffEqBase: value import ReverseDiff import DiffEqBase.ArrayInterface +import DiffEqBase.ForwardDiff function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}} DiffEqBase.anyeltypedual(V, Val{counter}) @@ -36,7 +37,8 @@ function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) u0 end -DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ArrayInterface.aos_to_soa(eltype(p).(u0)) +DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) +DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual} = ReverseDiff.track(T.(u0)) DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) # Support adaptive with non-tracked time