Skip to content

Commit

Permalink
Merge pull request #871 from SciML/weakdep
Browse files Browse the repository at this point in the history
Change weak dep naming scheme and add the AD weak dep overloads
  • Loading branch information
ChrisRackauckas authored Feb 7, 2023
2 parents 7e84b9c + 9ea338b commit 8704130
Show file tree
Hide file tree
Showing 13 changed files with 332 additions and 22 deletions.
21 changes: 15 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,21 @@ GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DistributionsExt = "Distributions"
MeasurementsExt = "Measurements"
MPIExt = "MPI"
MonteCarloMeasurementsExt = "MonteCarloMeasurements"
GeneralizedGeneratedExt = "GeneralizedGenerated"
UnitfulExt = "Unitful"
DiffEqBaseZygoteExt = "Zygote"
DiffEqBaseReverseDiffExt = "ReverseDiff"
DiffEqBaseTrackerExt = "Tracker"
DiffEqBaseDistributionsExt = "Distributions"
DiffEqBaseMeasurementsExt = "Measurements"
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
DiffEqBaseUnitfulExt = "Unitful"
DiffEqBaseMPIExt = "MPI"

[compat]
ArrayInterfaceCore = "0.1.26"
Expand Down Expand Up @@ -83,10 +89,13 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "SafeTestsets", "Statistics", "Test", "Distributions"]
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module DistributionsExt
module DiffEqBaseDistributionsExt

using Distributions, DiffEqBase

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module GeneralizedGeneratedExt
module DiffEqBaseGeneralizedGeneratedExt

using DiffEqBase
isdefined(Base, :get_extension) ? (using GeneralizedGenerated) :
Expand Down
2 changes: 1 addition & 1 deletion ext/MPI.jl → ext/DiffEqBaseMPIExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module MPIExt
module DiffEqBaseMPIExt

import DiffEqBase
isdefined(Base, :get_extension) ? (import MPI) : (import ..MPI)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module MeasurementsExt
module DiffEqBaseMeasurementsExt

using DiffEqBase
import DiffEqBase: value
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module MonteCarloMeasurementsExt
module DiffEqBaseMonteCarloMeasurementsExt

using DiffEqBase
import DiffEqBase: value
Expand Down
146 changes: 146 additions & 0 deletions ext/DiffEqBaseReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
module DiffEqBaseReverseDiffExt

using DiffEqBase
import DiffEqBase: value
isdefined(Base, :get_extension) ? (import ReverseDiff) : (import ..ReverseDiff)

DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value

DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::ReverseDiff.TrackedArray, t0)
u0
end
function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray,
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
u0
end
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
u0
end
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)

# Support adaptive with non-tracked time
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t)
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t)
abs(DiffEqBase.value(u))
end

# Support TrackedReal time, don't drop tracking on the adaptivity there
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray,
t::ReverseDiff.TrackedReal)
sqrt(sum(abs2, u) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
t::ReverseDiff.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
t::ReverseDiff.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal,
t::ReverseDiff.TrackedReal)
abs(u)
end

# `ReverseDiff.TrackedArray`
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray,
p::ReverseDiff.TrackedArray, args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0, p::ReverseDiff.TrackedArray,
args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray, p,
args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

# `AbstractArray{<:ReverseDiff.TrackedReal}`
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}, args...;
kwargs...)
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), reduce(vcat, p), args...;
kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0,
p::AbstractArray{<:ReverseDiff.TrackedReal},
args...; kwargs...)
DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p,
args...; kwargs...)
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...)
end

@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,
proto::ReverseDiff.TrackedArray)
ReverseDiff.track(convert.(eltype(proto.value), randn(rng, size(proto))))
end
@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG,
rand_vec::Array{<:ReverseDiff.TrackedReal
})
rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec))))
end
@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG,
rand_vec::AbstractArray{
<:ReverseDiff.TrackedReal
})
rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec))))
end

# Required becase ReverseDiff.@grad function DiffEqBase.solve_up is not supported!
import DiffEqBase: solve_up
ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...)
out = DiffEqBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0),
ReverseDiff.value(p),
SciMLBase.ReverseDiffOriginator(), args...; kwargs...)
function actual_adjoint(_args...)
original_adjoint = out[2](_args...)
if isempty(args) # alg is missing
tuple(original_adjoint[1:4]..., original_adjoint[6:end]...)
else
original_adjoint
end
end
Array(out[1]), actual_adjoint
end

end
95 changes: 95 additions & 0 deletions ext/DiffEqBaseTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
module DiffEqBaseTrackerExt

using DiffEqBase
import DiffEqBase: value
isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker)

DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T
DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N}
DiffEqBase.value(x::Tracker.TrackedReal) = x.data
DiffEqBase.value(x::Tracker.TrackedArray) = x.data

DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
p::Tracker.TrackedArray, t0)
u0
end
function DiffEqBase.promote_u0(u0::Tracker.TrackedArray,
p::AbstractArray{<:Tracker.TrackedReal}, t0)
u0
end
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
p::AbstractArray{<:Tracker.TrackedReal}, t0)
u0
end
DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0)
DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0)

@inline DiffEqBase.fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y
@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x))

# Support adaptive with non-tracked time
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t)
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t) = abs(DiffEqBase.value(u))

# Support TrackedReal time, don't drop tracking on the adaptivity there
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray,
t::Tracker.TrackedReal)
sqrt(sum(abs2, u) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N},
t::Tracker.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N},
t::Tracker.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u)

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::Tracker.TrackedArray,
p::Tracker.TrackedArray, args...; kwargs...)
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::Tracker.TrackedArray, p, args...;
kwargs...)
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0, p::Tracker.TrackedArray, args...;
kwargs...)
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

Tracker.@grad function DiffEqBase.solve_up(prob,
sensealg::Union{Nothing,
DiffEqBase.AbstractOverloadingSensitivityAlgorithm
},
u0, p, args...;
kwargs...)
DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p),
SciMLBase.TrackerOriginator(), args...; kwargs...)
end

end
2 changes: 1 addition & 1 deletion ext/UnitfulExt.jl → ext/DiffEqBaseUnitfulExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module UnitfulExt
module DiffEqBaseUnitfulExt

using DiffEqBase
import DiffEqBase: value
Expand Down
54 changes: 54 additions & 0 deletions ext/DiffEqBaseZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
module DiffEqBaseZygoteExt

using DiffEqBase
import DiffEqBase: value
isdefined(Base, :get_extension) ? (import Zygote) : (import ..Zygote)

function ∇tmap(cx, f, args...)
ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
function ∇tmap_internal(Δ)
Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ)
Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
ys, ∇tmap_internal
end
end

function ∇responsible_map(cx, f, args...)
ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...),
args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
ys,
function ∇responsible_map_internal(Δ)
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ),
Zygote._tryreverse(SciMLBase.responsible_map,
backs, Δ)...)
Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map,
Δf_and_args_zipped))
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
end
end

ZygoteRules.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
∇tmap(__context__, f, args...)
end

ZygoteRules.@adjoint function SciMLBase.responsible_map(f,
args::Union{AbstractArray, Tuple
}...)
∇responsible_map(__context__, f, args...)
end

end
2 changes: 1 addition & 1 deletion src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,6 @@ export SensitivityADPassThrough
export KeywordArgError, KeywordArgWarn, KeywordArgSilent

if !isdefined(Base, :get_extension)
include("../ext/DistributionsExt.jl")
include("../ext/DiffEqBaseDistributionsExt.jl")
end
end # module
Loading

0 comments on commit 8704130

Please sign in to comment.