
Commit

Replace unrolled foldl used to evaluate Chain with a better one (#1809)

* use foldl for Chain

* use foldl less often

* second derivative tests

* Revert "use foldl less often"

This reverts commit a74f86d.

* replace foldl with generated expression

* allow unstable Chain{Vector} too

* trailing comma

* fixup
mcabbott authored Feb 5, 2022
1 parent 4a3483e commit 7b56813
Showing 4 changed files with 69 additions and 10 deletions.
25 changes: 21 additions & 4 deletions src/layers/basic.jl
@@ -27,8 +27,12 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)),
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```
+
+For large models, there is a special type-unstable path which can reduce compilation
+times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
+This feature is somewhat experimental, beware!
"""
-struct Chain{T<:Union{Tuple, NamedTuple}}
+struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}}
  layers::T
end

@@ -44,10 +48,22 @@ end

@functor Chain

-applychain(::Tuple{}, x) = x
-applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
-(c::Chain)(x) = applychain(c.layers, x)
+@generated function applychain(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
+  symbols = vcat(:x, [gensym() for _ in 1:N])
+  calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
+  Expr(:block, calls...)
+end
+
+applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x)
+
+(c::Chain)(x) = applychain(Tuple(c.layers), x)
+function applychain(layers::AbstractVector, x)  # type-unstable path, helps compile times
+  for f in layers
+    x = f(x)
+  end
+  x
+end

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
@@ -60,6 +76,7 @@ function Base.show(io::IO, c::Chain)
end
_show_layers(io, layers::Tuple) = join(io, layers, ", ")
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
+_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))

# This is a temporary and naive implementation
# it might be replaced in the future for better performance
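To see what the `@generated` method above buys: it emits one assignment per layer, so for a tuple chain the compiler sees straight-line code instead of the old tail-recursive `applychain`. A minimal sketch of the N = 3 expansion (illustration only, not part of the diff; the real generated body uses `gensym`'d names):

```julia
using Flux

layers = (x -> 2x, x -> x + 1, x -> x^2)   # any three callables

y1 = layers[1](3)    # hand-unrolled equivalent of the generated block
y2 = layers[2](y1)
y3 = layers[3](y2)   # last expression, hence the block's return value

@assert Chain(layers...)(3) == y3 == 49
```

The `AbstractVector` method deliberately skips this unrolling: a plain `for` loop is type-unstable but compiles once, which is the compile-time saving the new docstring advertises.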
13 changes: 7 additions & 6 deletions src/layers/show.jl
@@ -14,11 +14,12 @@ for T in [
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
+  pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
  children = _show_children(obj)
  if all(_show_leaflike, children)
    _layer_show(io, obj, indent, name)
  else
-    println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
+    println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
    if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
      # then we insert names -- can this be done more generically?
      for k in Base.keys(obj)
@@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
      end
    end
    if indent == 0  # i.e. this is the outermost container
-      print(io, ")")
+      print(io, rpad(post, 2))
      _big_finale(io, obj)
    else
-      println(io, " "^indent, "),")
+      println(io, " "^indent, post, ",")
    end
  end
end
@@ -90,18 +91,18 @@ function _big_finale(io::IO, m)
    noncnt = _childarray_sum(_->1, m) - length(ps)
    if noncnt > 0
      nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps))
-      printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
+      printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
      println(io, pars, " parameters,")
      printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black)
      print(io, bytes, ".")
    else
-      printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black)
+      printstyled(io, " "^18, "# Total: ", length(ps), " arrays, "; color=:light_black)
      print(io, pars, " parameters, ", bytes, ".")
    end
  end
end

-_childarray_sum(f, x::AbstractArray) = f(x)
+_childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))

# utility functions
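The `pre`/`post` threading above is what makes a vector-built `Chain` print with square brackets, flagging the experimental path at a glance. An indicative REPL sketch (the parameter counts are exact for these layers; alignment and the byte count are machine-dependent and elided):

```julia
julia> Chain([Dense(10, 5, σ), Dense(5, 2)])
Chain([
  Dense(10, 5, σ),                      # 55 parameters
  Dense(5, 2),                          # 12 parameters
])                  # Total: 4 arrays, 67 parameters, … bytes.
```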
40 changes: 40 additions & 0 deletions test/layers/basic.jl
@@ -29,6 +29,8 @@ import Flux: activations
    @test m == fmap(identity, m) # does not forget names

    @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name
+
+    @test_nowarn Chain([Dense(10, 5, σ), Dense(5, 2)])(randn(Float32, 10)) # vector of layers
  end

@testset "Activations" begin
@@ -297,3 +299,41 @@ import Flux: activations
    @test_throws DimensionMismatch m(OneHotVector(3, 1000))
  end
end
+
+@testset "second derivatives" begin
+  m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
+  @test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3])
+
+  m1v = Chain([m1[1], m1[2]])  # vector of layers
+  @test Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_dual(sum∘m1, [1,2,3])
+  @test_broken Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1v, [1,2,3])
+
+  # NNlib's softmax gradient writes in-place
+  m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax)
+  @test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3])
+
+  # https://github.com/FluxML/NNlib.jl/issues/362
+  m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2))
+  x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3)
+  @test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
+end
+
+@testset "gradients of Chain{Vector}" begin
+  m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
+  m1v = Chain([m1[1], m1[2]])
+  @test sum(length, params(m1)) == sum(length, params(m1v))
+
+  x1 = randn(Float32,3,5)
+  @test m1(x1) ≈ m1v(x1)
+
+  y1 = rand(Bool,2,5)
+  g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1))
+  g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v))
+  @test g1[m1[1].weight] ≈ g1v[m1v[1].weight]
+  @test g1[m1[2].bias] ≈ g1v[m1v[2].bias]
+
+  @test Flux.destructure(m1)[1] ≈ Flux.destructure(m1v)[1]
+  z1 = rand(22);
+  @test Flux.destructure(m1)[2](z1)[1].weight ≈ Flux.destructure(m1v)[2](z1)[1].weight
+  # Note that Flux.destructure(m1v)[2](z) has a Chain{Tuple}, as does m1v[1:2]
+end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -5,6 +5,7 @@ using Flux: params
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
+using Zygote
using CUDA

Random.seed!(0)
