Skip to content

Commit

Permalink
Use broadcasting rules from ChainRules (#89)
Browse files Browse the repository at this point in the history
* remove reverse mode broadcasting rules

* fix some other rules

* tests

* update tests, CR version

* delete commented lines

* rm comment
  • Loading branch information
mcabbott authored Sep 17, 2022
1 parent bd4da5f commit 55d2871
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 72 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
ChainRules = "1.5"
ChainRulesCore = "1.2"
ChainRules = "1.44.6"
ChainRulesCore = "1.15.3"
Combinatorics = "1"
StaticArrays = "1"
StatsBase = "0.33"
Expand Down
34 changes: 12 additions & 22 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, g::∇getindex, Δ)
g(Δ), Δ′′->(nothing, Δ′′[1][g.i...])
end

function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array, i...)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array{<:Number}, i...)
xs[i...], ∇getindex(xs, i)
end

Expand Down Expand Up @@ -150,12 +150,6 @@ end

ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()

# Skip AD'ing through the axis computation
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
return Base.Broadcast.instantiate(bc), Δ->begin
Core.tuple(NoTangent(), Δ)
end
end


using StaticArrays
Expand Down Expand Up @@ -187,10 +181,6 @@ end

@ChainRulesCore.non_differentiable StaticArrays.promote_tuple_eltype(T)

function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args...)
getindex(A, args...), getindex(∂A, args...)
end

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), ::typeof(+), A::AbstractArray, B::AbstractArray)
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
end
Expand Down Expand Up @@ -225,27 +215,28 @@ struct BackMap{T}
f::T
end
(f::BackMap{N})(args...) where {N} = ∂⃖¹(getfield(f, :f), args...)
back_apply(x, y) = x(y)
back_apply_zero(x) = x(Zero())
back_apply(x, y) = x(y) # this is just |> with arguments reversed
back_apply_zero(x) = x(Zero()) # Zero is not defined

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple)
a, b = unzip_tuple(map(BackMap(f), args))
function back(Δ)
function map_back(Δ)
(fs, xs) = unzip_tuple(map(back_apply, b, Δ))
(NoTangent(), sum(fs), xs)
end
function back::ZeroTangent)
(fs, xs) = unzip_tuple(map(back_apply_zero, b))
(NoTangent(), sum(fs), xs)
end
a, back
map_back::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
a, map_back
end

ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple{}) = (), _ -> (NoTangent(), NoTangent(), NoTangent())

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.ntuple), f, n)
a, b = unzip_tuple(ntuple(BackMap(f), n))
a, function (Δ)
function ntuple_back(Δ)
(NoTangent(), sum(map(back_apply, b, Δ)), NoTangent())
end
ntuple_back(::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
a, ntuple_back
end

function ChainRules.frule(::DiffractorRuleConfig, _, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T}
Expand All @@ -267,5 +258,4 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
val, Δ->(NoTangent(), NoTangent(), Δ)
end

Base.real(z::ZeroTangent) = z # TODO should be in CRC
Base.real(z::NoTangent) = z
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581
1 change: 1 addition & 0 deletions src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing

_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
_tangent(::Type, ::NamedTuple{()}) = NoTangent()
_tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent()
43 changes: 0 additions & 43 deletions src/stage1/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,46 +28,3 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
end
return r
end

# Broadcast over one element is just map
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
∂⃖ₙ(map, f, a)
end

# The below is from Zygote: TODO: DO we want to do something better here?

accum_sum(xs::Nothing; dims = :) = NoTangent()
accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent()
accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims)
accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims)
accum_sum(xs::Number; dims = :) = xs

# https://github.com/FluxML/Zygote.jl/issues/594
function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region)
Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)})
end

trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

unbroadcast(x::AbstractArray, x̄) =
size(x) == size(x̄) ?:
length(x) == length(x̄) ? trim(x, x̄) :
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))

unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)

unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()

const Numeric = Union{Number, AbstractArray{<:Number, N} where N}

function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(+), xs::Numeric...)
broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
end

ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end

ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
-> let=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
5 changes: 3 additions & 2 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,13 @@ function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N}
end

# TODO: Temporary - make better
function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N}
function (::∂⃖{N})(::typeof(Base.getindex), a::Array{<:Number}, inds...) where {N}
getindex(a, inds...), let
EvenOddOdd{1, c_order(N)}(
(@Base.constprop :aggressive Δ->begin
Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...)
BB = zero(a)
BB[inds...] = Δ
BB[inds...] = unthunk(Δ)
(NoTangent(), BB, map(x->NoTangent(), inds)...)
end),
(@Base.constprop :aggressive (_, Δ, _)->begin
Expand All @@ -334,6 +334,7 @@ struct tuple_back{M}; end
(::tuple_back)(Δ::Tuple) = Core.tuple(NoTangent(), Δ...)
(::tuple_back{N})(Δ::AbstractZero) where {N} = Core.tuple(NoTangent(), ntuple(i->Δ, N)...)
(::tuple_back{N})(Δ::Tangent) where {N} = Core.tuple(NoTangent(), ntuple(i->lifted_getfield(Δ, i), N)...)
(t::tuple_back)(Δ::AbstractThunk) = t(unthunk(Δ))

function (::∂⃖{N})(::typeof(Core.tuple), args::Vararg{Any, M}) where {N, M}
Core.tuple(args...),
Expand Down
68 changes: 65 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
@test @inferred(sin'(1.0)) == cos(1.0)
@test @inferred(sin''(1.0)) == -sin(1.0)
@test sin'''(1.0) == -cos(1.0)
@test sin''''(1.0) == sin(1.0) broken = VERSION >= v"1.8"
@test sin'''''(1.0) == cos(1.0) broken = VERSION >= v"1.8"
@test sin''''''(1.0) == -sin(1.0) broken = VERSION >= v"1.8"
@test sin''''(1.0) == sin(1.0)
@test sin'''''(1.0) == cos(1.0) # broken = VERSION >= v"1.8"
@test sin''''''(1.0) == -sin(1.0) # broken = VERSION >= v"1.8"

f_getfield(x) = getfield((x,), 1)
@test f_getfield'(1) == 1
Expand Down Expand Up @@ -219,6 +219,68 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
@test z45 2.0
@test delta45 1.0

# PR #82 - getindex on non-numeric arrays
@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1}

@testset "broadcast" begin
@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output
@test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] [0.2338, -0.0177, -0.0661] atol=1e-3
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)

@test gradient(x -> sum((explog).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad
exp_log(x) = exp(log(x))
@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
@test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure

@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays
@test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] [-4.1666, 0.3333, 1.1666] atol=1e-3
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] [-4.1666, 0.3333, 1.1666] atol=1e-3
@test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] [12, 12, 12] # must not take the * fast path

@test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],)
@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule
@test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule
@test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule

@test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),)
@test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),)
@test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),)

@test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero # Bool output
@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero
@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent())
@test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero # Bool input
@test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero
@test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero

tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5]))
@test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0)
@test tup_adj[2] [0.6666666666666666 0.5 0.4]
@test tup_adj[2] isa Transpose
@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal

@test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure
end

@testset "broadcast, 2nd order" begin
@test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk
@test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
@test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order

@test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
@test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
@test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1]

@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] exp.(1:3) # MethodError: no method matching copy(::Nothing)
@test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] [0,0,0]
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}})
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] [27.675925925925927, -0.824074074074074, -2.1018518518518516]

@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
end

# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
#include("pinn.jl")

Expand Down

0 comments on commit 55d2871

Please sign in to comment.