Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Remove AD weak deps" #785

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ abstract type TransformedFunction end

import SciMLBase: unwrapped_f

import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm,
AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm,
AbstractSecondOrderSensitivityAlgorithm, AbstractShadowingSensitivityAlgorithm

include("hasbranching.jl")
include("sensitivity_algorithms.jl")
include("derivative_wrappers.jl")
Expand All @@ -55,6 +51,11 @@ include("steadystate_adjoint.jl")
include("sde_tools.jl")
include("staticarrays.jl")

# AD Extensions
include("reversediff.jl")
include("tracker.jl")
include("zygote.jl")

export extract_local_sensitivities

export ODEForwardSensitivityFunction, ODEForwardSensitivityProblem, SensitivityFunction,
Expand Down
151 changes: 151 additions & 0 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Piracy that used to be requires, allowing ReverseDiff.jl to be specialized for SciML

DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value

DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::ReverseDiff.TrackedArray, t0)
u0
end
function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray,
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
u0
end
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
u0
end
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)

# Support adaptive with non-tracked time
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t)
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t)
abs(DiffEqBase.value(u))
end

# Support TrackedReal time, don't drop tracking on the adaptivity there
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray,
t::ReverseDiff.TrackedReal)
sqrt(sum(abs2, u) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
t::ReverseDiff.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
t::ReverseDiff.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal,
t::ReverseDiff.TrackedReal)
abs(u)
end

# `ReverseDiff.TrackedArray`
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray,
p::ReverseDiff.TrackedArray, args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0, p::ReverseDiff.TrackedArray,
args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray, p,
args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

# `AbstractArray{<:ReverseDiff.TrackedReal}`
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}, args...;
kwargs...)
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), reduce(vcat, p), args...;
kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0,
p::AbstractArray{<:ReverseDiff.TrackedReal},
args...; kwargs...)
DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p,
args...; kwargs...)
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...)
end

@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,
proto::ReverseDiff.TrackedArray)
ReverseDiff.track(convert.(eltype(proto.value), randn(rng, size(proto))))
end
@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG,
rand_vec::Array{<:ReverseDiff.TrackedReal
})
rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec))))
end
@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG,
rand_vec::AbstractArray{
<:ReverseDiff.TrackedReal
})
rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec))))
end

# Required becase ReverseDiff.@grad function DiffEqBase.solve_up is not supported!
import DiffEqBase: solve_up
ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...)
out = DiffEqBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0),
ReverseDiff.value(p),
SciMLBase.ReverseDiffOriginator(), args...; kwargs...)
function actual_adjoint(_args...)
original_adjoint = out[2](_args...)
if isempty(args) # alg is missing
tuple(original_adjoint[1:4]..., original_adjoint[6:end]...)
else
original_adjoint
end
end
Array(out[1]), actual_adjoint
end

# PreallocationTools https://github.com/SciML/PreallocationTools.jl/issues/39
function Base.getindex(b::LazyBufferCache, u::ReverseDiff.TrackedArray)
s = b.sizemap(size(u)) # required buffer size
T = ReverseDiff.TrackedArray
buf = get!(b.bufs, (T, s)) do
# declare type since b.bufs dictionary is untyped
similar(u, s)
end
return buf
end
12 changes: 12 additions & 0 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@ function SensitivityAlg(args...; kwargs...)
@error("The SensitivtyAlg choice mechanism was completely overhauled. Please consult the local sensitivity documentation for more information")
end

# This is to easily ignore AD passthrough
abstract type AbstractOverloadingSensitivityAlgorithm{CS, AD, FDT} <:
DiffEqBase.AbstractSensitivityAlgorithm{CS, AD, FDT} end
abstract type AbstractForwardSensitivityAlgorithm{CS, AD, FDT} <:
AbstractOverloadingSensitivityAlgorithm{CS, AD, FDT} end
abstract type AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} <:
AbstractOverloadingSensitivityAlgorithm{CS, AD, FDT} end
abstract type AbstractSecondOrderSensitivityAlgorithm{CS, AD, FDT} <:
AbstractOverloadingSensitivityAlgorithm{CS, AD, FDT} end
abstract type AbstractShadowingSensitivityAlgorithm{CS, AD, FDT} <:
AbstractOverloadingSensitivityAlgorithm{CS, AD, FDT} end

"""
```julia
ForwardSensitivity{CS, AD, FDT} <: AbstractForwardSensitivityAlgorithm{CS, AD, FDT}
Expand Down
101 changes: 101 additions & 0 deletions src/tracker.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Piracy that used to be requires, allowing Tracker.jl to be specialized for SciML

function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {
T <:
Tracker.TrackedArray,
T2 <:
Tracker.TrackedArray,
N}
@inbounds for i in eachindex(a)
b[i] = copy(a[i])
end
end

DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T
DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N}
DiffEqBase.value(x::Tracker.TrackedReal) = x.data
DiffEqBase.value(x::Tracker.TrackedArray) = x.data

DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
p::Tracker.TrackedArray, t0)
u0
end
function DiffEqBase.promote_u0(u0::Tracker.TrackedArray,
p::AbstractArray{<:Tracker.TrackedReal}, t0)
u0
end
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
p::AbstractArray{<:Tracker.TrackedReal}, t0)
u0
end
DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0)
DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0)

@inline DiffEqBase.fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y
@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x))

# Support adaptive with non-tracked time
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t)
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t) = abs(DiffEqBase.value(u))

# Support TrackedReal time, don't drop tracking on the adaptivity there
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray,
t::Tracker.TrackedReal)
sqrt(sum(abs2, u) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N},
t::Tracker.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N},
t::Tracker.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u)

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::Tracker.TrackedArray,
p::Tracker.TrackedArray, args...; kwargs...)
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::Tracker.TrackedArray, p, args...;
kwargs...)
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0, p::Tracker.TrackedArray, args...;
kwargs...)
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

Tracker.@grad function DiffEqBase.solve_up(prob,
sensealg::Union{Nothing,
AbstractOverloadingSensitivityAlgorithm
},
u0, p, args...;
kwargs...)
DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p),
SciMLBase.TrackerOriginator(), args...; kwargs...)
end
48 changes: 48 additions & 0 deletions src/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Piracy that used to be requires, allowing Zyogote.jl to be specialized for SciML

function ∇tmap(cx, f, args...)
ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
function ∇tmap_internal(Δ)
Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ)
Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
ys, ∇tmap_internal
end
end

function ∇responsible_map(cx, f, args...)
ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...),
args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
ys,
function ∇responsible_map_internal(Δ)
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ),
Zygote._tryreverse(SciMLBase.responsible_map,
backs, Δ)...)
Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map,
Δf_and_args_zipped))
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
end
end

ZygoteRules.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
∇tmap(__context__, f, args...)
end

ZygoteRules.@adjoint function SciMLBase.responsible_map(f,
args::Union{AbstractArray, Tuple
}...)
∇responsible_map(__context__, f, args...)
end