diff --git a/Project.toml b/Project.toml index 9cccc398b..e5aba9390 100644 --- a/Project.toml +++ b/Project.toml @@ -37,15 +37,21 @@ GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -DistributionsExt = "Distributions" -MeasurementsExt = "Measurements" -MPIExt = "MPI" -MonteCarloMeasurementsExt = "MonteCarloMeasurements" -GeneralizedGeneratedExt = "GeneralizedGenerated" -UnitfulExt = "Unitful" +DiffEqBaseZygoteExt = "Zygote" +DiffEqBaseReverseDiffExt = "ReverseDiff" +DiffEqBaseTrackerExt = "Tracker" +DiffEqBaseDistributionsExt = "Distributions" +DiffEqBaseMeasurementsExt = "Measurements" +DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements" +DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated" +DiffEqBaseUnitfulExt = "Unitful" +DiffEqBaseMPIExt = "MPI" [compat] ArrayInterfaceCore = "0.1.26" @@ -83,10 +89,13 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "SafeTestsets", "Statistics", "Test", "Distributions"] diff --git a/ext/DistributionsExt.jl b/ext/DiffEqBaseDistributionsExt.jl similarity index 84% rename from ext/DistributionsExt.jl rename to ext/DiffEqBaseDistributionsExt.jl index a35bef5d5..e84a5509a 100644 --- a/ext/DistributionsExt.jl +++ b/ext/DiffEqBaseDistributionsExt.jl @@ -1,4 +1,4 @@ -module DistributionsExt +module DiffEqBaseDistributionsExt using Distributions, DiffEqBase diff --git a/ext/GeneralizedGeneratedExt.jl b/ext/DiffEqBaseGeneralizedGeneratedExt.jl similarity index 86% rename from ext/GeneralizedGeneratedExt.jl rename to ext/DiffEqBaseGeneralizedGeneratedExt.jl index 6d507d7dc..30f5b2c99 100644 --- a/ext/GeneralizedGeneratedExt.jl +++ b/ext/DiffEqBaseGeneralizedGeneratedExt.jl @@ -1,4 +1,4 @@ -module GeneralizedGeneratedExt +module DiffEqBaseGeneralizedGeneratedExt using DiffEqBase isdefined(Base, :get_extension) ? (using GeneralizedGenerated) : diff --git a/ext/MPI.jl b/ext/DiffEqBaseMPIExt.jl similarity index 92% rename from ext/MPI.jl rename to ext/DiffEqBaseMPIExt.jl index ffb5737cf..55040e772 100644 --- a/ext/MPI.jl +++ b/ext/DiffEqBaseMPIExt.jl @@ -1,4 +1,4 @@ -module MPIExt +module DiffEqBaseMPIExt import DiffEqBase isdefined(Base, :get_extension) ? (import MPI) : (import ..MPI) diff --git a/ext/MeasurementsExt.jl b/ext/DiffEqBaseMeasurementsExt.jl similarity index 97% rename from ext/MeasurementsExt.jl rename to ext/DiffEqBaseMeasurementsExt.jl index ce35a23aa..43ceace3c 100644 --- a/ext/MeasurementsExt.jl +++ b/ext/DiffEqBaseMeasurementsExt.jl @@ -1,4 +1,4 @@ -module MeasurementsExt +module DiffEqBaseMeasurementsExt using DiffEqBase import DiffEqBase: value diff --git a/ext/MonteCarloMeasurementsExt.jl b/ext/DiffEqBaseMonteCarloMeasurementsExt.jl similarity index 98% rename from ext/MonteCarloMeasurementsExt.jl rename to ext/DiffEqBaseMonteCarloMeasurementsExt.jl index c5513ad52..9871b1209 100644 --- a/ext/MonteCarloMeasurementsExt.jl +++ b/ext/DiffEqBaseMonteCarloMeasurementsExt.jl @@ -1,4 +1,4 @@ -module MonteCarloMeasurementsExt +module DiffEqBaseMonteCarloMeasurementsExt using DiffEqBase import DiffEqBase: value diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl new file mode 100644 index 000000000..aaf3fc137 --- /dev/null +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -0,0 +1,146 @@ +module DiffEqBaseReverseDiffExt + +using DiffEqBase +import DiffEqBase: value +isdefined(Base, :get_extension) ? (import ReverseDiff) : (import ..ReverseDiff) + +DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value +DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value + +DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0 +function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::ReverseDiff.TrackedArray, t0) + u0 +end +function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, + p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) + u0 +end +function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) + u0 +end +DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) +DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) + +# Support adaptive with non-tracked time +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t) + sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t) + abs(DiffEqBase.value(u)) +end + +# Support TrackedReal time, don't drop tracking on the adaptivity there +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, + t::ReverseDiff.TrackedReal) + sqrt(sum(abs2, u) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N}, + t::ReverseDiff.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / + length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N}, + t::ReverseDiff.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / + length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, + t::ReverseDiff.TrackedReal) + abs(u) +end + +# `ReverseDiff.TrackedArray` +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, + p::ReverseDiff.TrackedArray, args...; kwargs...) + ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, p::ReverseDiff.TrackedArray, + args...; kwargs...) + ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, p, + args...; kwargs...) + ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +# `AbstractArray{<:ReverseDiff.TrackedReal}` +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, + u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::AbstractArray{<:ReverseDiff.TrackedReal}, args...; + kwargs...) + DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), reduce(vcat, p), args...; + kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, + p::AbstractArray{<:ReverseDiff.TrackedReal}, + args...; kwargs...) + DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, + u0::AbstractArray{<:ReverseDiff.TrackedReal}, p, + args...; kwargs...) + DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...) +end + +@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG, + proto::ReverseDiff.TrackedArray) + ReverseDiff.track(convert.(eltype(proto.value), randn(rng, size(proto)))) +end +@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, + rand_vec::Array{<:ReverseDiff.TrackedReal + }) + rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec)))) +end +@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, + rand_vec::AbstractArray{ + <:ReverseDiff.TrackedReal + }) + rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec)))) +end + +# Required becase ReverseDiff.@grad function DiffEqBase.solve_up is not supported! +import DiffEqBase: solve_up +ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...) + out = DiffEqBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0), + ReverseDiff.value(p), + SciMLBase.ReverseDiffOriginator(), args...; kwargs...) + function actual_adjoint(_args...) + original_adjoint = out[2](_args...) + if isempty(args) # alg is missing + tuple(original_adjoint[1:4]..., original_adjoint[6:end]...) + else + original_adjoint + end + end + Array(out[1]), actual_adjoint +end + +end diff --git a/ext/DiffEqBaseTrackerExt.jl b/ext/DiffEqBaseTrackerExt.jl new file mode 100644 index 000000000..eeeaf4696 --- /dev/null +++ b/ext/DiffEqBaseTrackerExt.jl @@ -0,0 +1,95 @@ +module DiffEqBaseTrackerExt + +using DiffEqBase +import DiffEqBase: value +isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker) + +DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T +DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N} +DiffEqBase.value(x::Tracker.TrackedReal) = x.data +DiffEqBase.value(x::Tracker.TrackedArray) = x.data + +DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0 +function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, + p::Tracker.TrackedArray, t0) + u0 +end +function DiffEqBase.promote_u0(u0::Tracker.TrackedArray, + p::AbstractArray{<:Tracker.TrackedReal}, t0) + u0 +end +function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, + p::AbstractArray{<:Tracker.TrackedReal}, t0) + u0 +end +DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0) +DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0) + +@inline DiffEqBase.fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y +@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x)) + +# Support adaptive with non-tracked time +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t) + sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t) = abs(DiffEqBase.value(u)) + +# Support TrackedReal time, don't drop tracking on the adaptivity there +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, + t::Tracker.TrackedReal) + sqrt(sum(abs2, u) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N}, + t::Tracker.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / + length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N}, + t::Tracker.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / + length(u)) +end +@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u) + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, + p::Tracker.TrackedArray, args...; kwargs...) + Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, p, args...; + kwargs...) + Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, + sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, p::Tracker.TrackedArray, args...; + kwargs...) + Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +Tracker.@grad function DiffEqBase.solve_up(prob, + sensealg::Union{Nothing, + DiffEqBase.AbstractOverloadingSensitivityAlgorithm + }, + u0, p, args...; + kwargs...) + DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p), + SciMLBase.TrackerOriginator(), args...; kwargs...) +end + +end diff --git a/ext/UnitfulExt.jl b/ext/DiffEqBaseUnitfulExt.jl similarity index 97% rename from ext/UnitfulExt.jl rename to ext/DiffEqBaseUnitfulExt.jl index 6cf69719e..ca87ab67d 100644 --- a/ext/UnitfulExt.jl +++ b/ext/DiffEqBaseUnitfulExt.jl @@ -1,4 +1,4 @@ -module UnitfulExt +module DiffEqBaseUnitfulExt using DiffEqBase import DiffEqBase: value diff --git a/ext/DiffEqBaseZygoteExt.jl b/ext/DiffEqBaseZygoteExt.jl new file mode 100644 index 000000000..7ab8fa72a --- /dev/null +++ b/ext/DiffEqBaseZygoteExt.jl @@ -0,0 +1,54 @@ +module DiffEqBaseZygoteExt + +using DiffEqBase +import DiffEqBase: value +isdefined(Base, :get_extension) ? (import Zygote) : (import ..Zygote) + +function ∇tmap(cx, f, args...) + ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...) + if isempty(ys_and_backs) + ys_and_backs, _ -> (NoTangent(), NoTangent()) + else + ys, backs = Zygote.unzip(ys_and_backs) + function ∇tmap_internal(Δ) + Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ) + Δf_and_args = Zygote.unzip(Δf_and_args_zipped) + Δf = reduce(Zygote.accum, Δf_and_args[1]) + (Δf, Δf_and_args[2:end]...) + end + ys, ∇tmap_internal + end +end + +function ∇responsible_map(cx, f, args...) + ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...), + args...) + if isempty(ys_and_backs) + ys_and_backs, _ -> (NoTangent(), NoTangent()) + else + ys, backs = Zygote.unzip(ys_and_backs) + ys, + function ∇responsible_map_internal(Δ) + # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful. + Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ), + Zygote._tryreverse(SciMLBase.responsible_map, + backs, Δ)...) + Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map, + Δf_and_args_zipped)) + Δf = reduce(Zygote.accum, Δf_and_args[1]) + (Δf, Δf_and_args[2:end]...) + end + end +end + +ZygoteRules.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...) + ∇tmap(__context__, f, args...) +end + +ZygoteRules.@adjoint function SciMLBase.responsible_map(f, + args::Union{AbstractArray, Tuple + }...) + ∇responsible_map(__context__, f, args...) +end + +end diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index e40d7b133..0814a56dc 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -155,6 +155,6 @@ export SensitivityADPassThrough export KeywordArgError, KeywordArgWarn, KeywordArgSilent if !isdefined(Base, :get_extension) - include("../ext/DistributionsExt.jl") + include("../ext/DiffEqBaseDistributionsExt.jl") end end # module diff --git a/src/init.jl b/src/init.jl index 447ceff31..83d693f03 100644 --- a/src/init.jl +++ b/src/init.jl @@ -14,16 +14,22 @@ function SciMLBase.tmap(args...) error("Zygote must be added to differentiate Zygote? If you see this error, report it.") end -function __init__() - @static if !isdefined(Base, :get_extension) - @require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin include("../ext/MeasurementsExt.jl") end +@static if !isdefined(Base, :get_extension) + function __init__() + @require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin include("../ext/DiffEqBaseMeasurementsExt.jl") end - @require MonteCarloMeasurements="0987c9cc-fe09-11e8-30f0-b96dd679fdca" begin include("../ext/MonteCarloMeasurementsExt.jl") end + @require MonteCarloMeasurements="0987c9cc-fe09-11e8-30f0-b96dd679fdca" begin include("../ext/DiffEqBaseMonteCarloMeasurementsExt.jl") end - @require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" begin include("../ext/UnitfulExt.jl") end + @require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" begin include("../ext/DiffEqBaseUnitfulExt.jl") end - @require GeneralizedGenerated="6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" begin include("../ext/GeneralizedGeneratedExt.jl") end + @require GeneralizedGenerated="6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" begin include("../ext/DiffEqBaseGeneralizedGeneratedExt.jl") end - @require MPI="da04e1cc-30fd-572f-bb4f-1f8673147195" begin include("../ext/MPI.jl") end + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/DiffEqBaseTrackerExt.jl") end + + @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("../ext/DiffEqBaseReverseDiffExt.jl") end + + @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/DiffEqBaseZygoteExt.jl") end + + @require MPI="da04e1cc-30fd-572f-bb4f-1f8673147195" begin include("../ext/DiffEqBaseMPIExt.jl") end end end diff --git a/test/downstream/solve_error_handling.jl b/test/downstream/solve_error_handling.jl index 865583e20..e9f0ee780 100644 --- a/test/downstream/solve_error_handling.jl +++ b/test/downstream/solve_error_handling.jl @@ -44,7 +44,7 @@ for u0 in ([0.0, 0.0], nothing) end # Allow empty mass matrix for empty u0 -fmm = ODEFunction((du, u, t)->nothing, mass_matrix = zeros(0, 0)) -prob = ODEProblem(fmm, nothing, (0., 1.)) +fmm = ODEFunction((du, u, t) -> nothing, mass_matrix = zeros(0, 0)) +prob = ODEProblem(fmm, nothing, (0.0, 1.0)) sol = solve(prob, Tsit5()) @test isa(sol, DiffEqBase.ODESolution)