Skip to content

Commit

Permalink
fixup! fix: fix promote_u0 when p isa TrackedArray and u0 is not
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 3, 2024
1 parent 75c02c3 commit aa37f83
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion ext/DiffEqBaseReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using DiffEqBase
import DiffEqBase: value
import ReverseDiff
import DiffEqBase.ArrayInterface
import DiffEqBase.ForwardDiff

function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}}
DiffEqBase.anyeltypedual(V, Val{counter})
Expand Down Expand Up @@ -36,7 +37,8 @@ function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
u0
end
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ArrayInterface.aos_to_soa(eltype(p).(u0))
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual} = ReverseDiff.track(T.(u0))
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)

# Support adaptive with non-tracked time
Expand Down

0 comments on commit aa37f83

Please sign in to comment.