Skip to content

Commit

Permalink
fixes + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Sep 18, 2022
1 parent 5877d36 commit 7f06b80
Show file tree
Hide file tree
Showing 8 changed files with 587 additions and 124 deletions.
13 changes: 13 additions & 0 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ end
@ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer)
@ChainRules.non_differentiable Base.throw(err)
@ChainRules.non_differentiable Core.Compiler.return_type(args...)

# Canonicalizing a NoTangent is a no-op: it is already canonical.
function ChainRulesCore.canonicalize(::NoTangent)
    return NoTangent()
end

# Disable thunking at higher order (TODO: These should go into ChainRulesCore)
Expand All @@ -259,3 +260,15 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
end

Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581

# Work around https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495:
# suppress `_backing_error` for the (Base.Pairs primal, NamedTuple backing,
# AbstractDict expected-backing) combination by returning `nothing` instead
# of presumably raising — TODO confirm against ChainRulesCore's definition.
function ChainRulesCore._backing_error(::Type{<:Base.Pairs}, ::Type{<:NamedTuple}, ::Type{<:AbstractDict})
    return nothing
end

# For gradient(pow_simd, 2, 3)[1] in zygote_features.jl
# Loop bookkeeping inside @simd lowering carries no derivative information,
# so mark it non-differentiable rather than trying to differentiate through it.
ChainRulesCore.@non_differentiable Base.SimdLoop.simd_inner_length(::Any, ::Any)

# Allow converting numeric zeros into the ZeroTangent singleton, so that
# e.g. fill!(similar([1,2,3], ZeroTangent), false) works.
# Throws InexactError for any nonzero number, mirroring Base's convert contract.
function Base.convert(::Type{ZeroTangent}, x::Number)
    if !iszero(x)
        throw(InexactError(:convert, ZeroTangent, x))
    end
    return ZeroTangent()
end
6 changes: 6 additions & 0 deletions src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing
# Wrap a backing value `z` as a structural Tangent for primal type `T`.
_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
# An empty NamedTuple backing carries no information: collapse it to NoTangent.
_tangent(::Type, ::NamedTuple{()}) = NoTangent()
# A backing whose fields are all AbstractZero is likewise zero: collapse to NoTangent.
_tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent()

# Accumulate when one gradient arrived as a structural Tangent over a Tuple
# and the other as a plain Tuple: merge the Tuple into the Tangent's backing,
# then re-wrap the result as a Tangent via `_tangent`.
function accum(x::Tangent{T}, y::Tuple) where {T<:Tuple}
    # @warn "gradient is both a Tangent and a Tuple" x y
    merged = accum(backing(x), y)
    return _tangent(T, merged)
end
# Symmetric case: delegate to the method above with the arguments swapped
# (the delegation assumes accumulation is order-independent).
accum(x::Tuple, y::Tangent{<:Tuple}) = accum(y, x)
48 changes: 44 additions & 4 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@

# This file has integration tests for some rules defined in ChainRules.jl,
# especially those which aim to support higher derivatives, as properly
# testing those is difficult.
# testing those is difficult. Organised according to the files in CR.jl.

using Diffractor, ForwardDiff, ChainRulesCore
using Test, LinearAlgebra

using Test: Threw, eval_test

using Diffractor, ChainRulesCore, ForwardDiff

#####
##### Base/array.jl
Expand All @@ -13,7 +17,6 @@ using Diffractor, ChainRulesCore, ForwardDiff




#####
##### Base/arraymath.jl
#####
Expand All @@ -33,21 +36,58 @@ using Diffractor, ChainRulesCore, ForwardDiff
##### Base/indexing.jl
#####

# Second-order tests for indexing rules (Base/indexing.jl in ChainRules).
# NOTE(review): the `≈` on the `sqrtfirst` line was dropped in transit; restored
# here — an approximate comparison is required since the expected value is a float.
@testset "getindex, first" begin
    @test_broken gradient(x -> sum(abs2, gradient(first, x)[1]), [1,2,3])[1] == [0, 0, 0] # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent})
    @test_broken gradient(x -> sum(abs2, gradient(sqrtfirst, x)[1]), [1,2,3])[1] ≈ [-0.25, 0, 0] # error() in perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{2}}, args::Any)
    @test gradient(x -> sum(abs2, gradient(x -> x[1]^2, x)[1]), [1,2,3])[1] == [8, 0, 0]
    @test_broken gradient(x -> sum(abs2, gradient(x -> sum(x[1:2])^2, x)[1]), [1,2,3])[1] == [48, 0, 0] # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent})
end


# First- and second-order gradients through column iteration (eachcol).
@testset "eachcol etc" begin
@test gradient(m -> sum(prod, eachcol(m)), [1 2 3; 4 5 6])[1] == [4 5 6; 1 2 3]
@test gradient(m -> sum(first, eachcol(m)), [1 2 3; 4 5 6])[1] == [1 1 1; 0 0 0]
@test gradient(m -> sum(first(eachcol(m))), [1 2 3; 4 5 6])[1] == [1 0 0; 1 0 0]
# Second-order case skipped rather than marked broken: it errors deep in broadcasting.
@test_skip gradient(x -> sum(sin, gradient(m -> sum(first(eachcol(m))), x)[1]), [1 2 3; 4 5 6])[1] # MethodError: no method matching one(::Base.OneTo{Int64}), unzip_broadcast, split_bc_forwards
end

#####
##### Base/mapreduce.jl
#####

# Second-order gradients through reductions: the outer gradient differentiates
# a function that itself calls `gradient`.
@testset "sum" begin
# gradient(sum, x)[1] is constant in x, so the second-order result is zero.
@test gradient(x -> sum(abs2, gradient(sum, x)[1]), [1,2,3])[1] == [0,0,0]
@test gradient(x -> sum(abs2, gradient(x -> sum(abs2, x), x)[1]), [1,2,3])[1] == [8,16,24]

# Same, but with a broadcast feeding the inner gradient call.
@test gradient(x -> sum(abs2, gradient(sum, x .^ 2)[1]), [1,2,3])[1] == [0,0,0]
@test gradient(x -> sum(abs2, gradient(sum, x .^ 3)[1]), [1,2,3])[1] == [0,0,0]
end

# Gradients of sequential folds, over arrays and over tuples.
@testset "foldl" begin

@test gradient(x -> foldl(*, x), [1,2,3,4])[1] == [24.0, 12.0, 8.0, 6.0]
@test gradient(x -> foldl(*, x; init=5), [1,2,3,4])[1] == [120.0, 60.0, 40.0, 30.0]
@test gradient(x -> foldr(*, x), [1,2,3,4])[1] == [24, 12, 8, 6]

# Tuple input: the gradient should come back as a structural Tangent.
@test gradient(x -> foldl(*, x), (1,2,3,4))[1] == Tangent{NTuple{4,Int}}(24.0, 12.0, 8.0, 6.0)
# These two currently return a plain tuple rather than a Tangent; the
# following `|> Tuple` lines check the values are right regardless of wrapper.
@test_broken gradient(x -> foldl(*, x; init=5), (1,2,3,4))[1] == Tangent{NTuple{4,Int}}(120.0, 60.0, 40.0, 30.0) # does not return a Tangent
@test gradient(x -> foldl(*, x; init=5), (1,2,3,4)) |> only |> Tuple == (120.0, 60.0, 40.0, 30.0)
@test_broken gradient(x -> foldr(*, x), (1,2,3,4))[1] == Tangent{NTuple{4,Int}}(24, 12, 8, 6)
@test gradient(x -> foldr(*, x), (1,2,3,4)) |> only |> Tuple == (24, 12, 8, 6)

end


#####
##### LinearAlgebra/dense.jl
#####


# First- and second-order gradients of LinearAlgebra.dot.
@testset "dot" begin

@test gradient(x -> dot(x, [1,2,3])^2, [4,5,6])[1] == [64,128,192]
# NOTE(review): the error comment below was truncated in the source
# ("...Tangent{ChainRules.var" ends mid-token); the gist is a `+` MethodError
# on mismatched tangent tuple types — confirm against a live run.
@test_broken gradient(x -> sum(gradient(x -> dot(x, [1,2,3])^2, x)[1]), [4,5,6])[1] == [12,24,36] # MethodError: no method matching +(::Tuple{Tangent{ChainRules.var

end


#####
Expand Down
23 changes: 15 additions & 8 deletions test/diffractor_01.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# The rest of this file is unchanged, except the very end,
# but IMO we should move these tests to a new file.
# This file has tests written specifically for Diffractor v0.1,
# which were in runtests.jl before PR 73 moved them all.

# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint.
using Test

using Diffractor
using Diffractor: ∂⃖, DiffractorRuleConfig

using Diffractor: var"'", ∂⃖, DiffractorRuleConfig
using ChainRules
using ChainRulesCore
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
using Symbolics
using LinearAlgebra

using LinearAlgebra

# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint.
const fwd = Diffractor.PrimeDerivativeFwd
const bwd = Diffractor.PrimeDerivativeBack

Expand Down Expand Up @@ -48,8 +51,10 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent()
@test isequal(simplify(x8), simplify((η +*ζ) +*ϵ) +*+*β))))*exp(ω)))

# Minimal 2-nd order forward smoke test
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
let var"'" = Diffractor.PrimeDerivativeBack
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
end

function simple_control_flow(b, x)
if b
Expand Down Expand Up @@ -269,7 +274,7 @@ end
@testset "broadcast, 2nd order" begin
@test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk
@test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
@test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order
@test gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12]

@test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
@test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
Expand All @@ -283,3 +288,5 @@ end
@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
end

# Issue 67, due to https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495
# Regression test: `identitysqrt` is presumably defined elsewhere in the test
# suite (not visible here) — d/dx sqrt(x) at 4.0 is 0.25.
@test gradient(identitysqrt, 4.0) == (0.25,)
Loading

0 comments on commit 7f06b80

Please sign in to comment.