diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 1bc591785a..3e22895e82 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -27,8 +27,12 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)),
 julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
 true
 ```
+
+For large models, there is a special type-unstable path which can reduce compilation
+times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
+This feature is somewhat experimental, beware!
 """
-struct Chain{T<:Union{Tuple, NamedTuple}}
+struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}}
   layers::T
 end
 
@@ -44,10 +48,22 @@ end
 
 @functor Chain
 
-applychain(::Tuple{}, x) = x
-applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
+(c::Chain)(x) = applychain(c.layers, x)
+
+@generated function applychain(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
+  symbols = vcat(:x, [gensym() for _ in 1:N])
+  calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
+  Expr(:block, calls...)
+end
+
+applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x)
 
-(c::Chain)(x) = applychain(Tuple(c.layers), x)
+function applychain(layers::AbstractVector, x)  # type-unstable path, helps compile times
+  for f in layers
+    x = f(x)
+  end
+  x
+end
 
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
 Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
@@ -60,6 +76,7 @@ function Base.show(io::IO, c::Chain)
 end
 _show_layers(io, layers::Tuple) = join(io, layers, ", ")
 _show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
+_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))
 
 # This is a temporary and naive implementation
 # it might be replaced in the future for better performance
diff --git a/src/layers/show.jl b/src/layers/show.jl
index 85faec3c59..a37af36065 100644
--- a/src/layers/show.jl
+++ b/src/layers/show.jl
@@ -14,11 +14,12 @@ for T in [
 end
 
 function _big_show(io::IO, obj, indent::Int=0, name=nothing)
+  pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
   children = _show_children(obj)
   if all(_show_leaflike, children)
     _layer_show(io, obj, indent, name)
   else
-    println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
+    println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
     if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
       # then we insert names -- can this be done more generically?
       for k in Base.keys(obj)
@@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
       end
     end
     if indent == 0  # i.e. this is the outermost container
-      print(io, ")")
+      print(io, rpad(post, 2))
       _big_finale(io, obj)
     else
-      println(io, " "^indent, "),")
+      println(io, " "^indent, post, ",")
     end
   end
 end
@@ -90,18 +91,18 @@ function _big_finale(io::IO, m)
     noncnt = _childarray_sum(_->1, m) - length(ps)
     if noncnt > 0
       nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps))
-      printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
+      printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
       println(io, pars, " parameters,")
       printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black)
       print(io, bytes, ".")
     else
-      printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black)
+      printstyled(io, " "^18, "# Total: ", length(ps), " arrays, "; color=:light_black)
       print(io, pars, " parameters, ", bytes, ".")
     end
   end
 end
 
-_childarray_sum(f, x::AbstractArray) = f(x)
+_childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
 _childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))
 
 # utility functions
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 291d04f304..11f95f023a 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -29,6 +29,8 @@ import Flux: activations
     @test m == fmap(identity, m)  # does not forget names
 
     @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity)  # reserved name
+
+    @test_nowarn Chain([Dense(10, 5, σ), Dense(5, 2)])(randn(Float32, 10))  # vector of layers
   end
 
   @testset "Activations" begin
@@ -297,3 +299,41 @@ import Flux: activations
     @test_throws DimensionMismatch m(OneHotVector(3, 1000))
   end
 end
+
+@testset "second derivatives" begin
+  m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
+  @test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3])
+
+  m1v = Chain([m1[1], m1[2]])  # vector of layers
+  @test Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_dual(sum∘m1, [1,2,3])
+  @test_broken Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1v, [1,2,3])
+
+  # NNlib's softmax gradient writes in-place
+  m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax)
+  @test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3])
+
+  # https://github.com/FluxML/NNlib.jl/issues/362
+  m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2))
+  x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3)
+  @test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
+end
+
+@testset "gradients of Chain{Vector}" begin
+  m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
+  m1v = Chain([m1[1], m1[2]])
+  @test sum(length, params(m1)) == sum(length, params(m1v))
+
+  x1 = randn(Float32,3,5)
+  @test m1(x1) ≈ m1v(x1)
+
+  y1 = rand(Bool,2,5)
+  g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1))
+  g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v))
+  @test g1[m1[1].weight] ≈ g1v[m1v[1].weight]
+  @test g1[m1[2].bias] ≈ g1v[m1v[2].bias]
+
+  @test Flux.destructure(m1)[1] ≈ Flux.destructure(m1v)[1]
+  z1 = rand(22);
+  @test Flux.destructure(m1)[2](z1)[1].weight ≈ Flux.destructure(m1v)[2](z1)[1].weight
+  # Note that Flux.destructure(m1v)[2](z1) has a Chain{Tuple}, as does m1v[1:2]
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 3286e881bd..706f126451 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -5,6 +5,7 @@ using Flux: params
 using Test
 using Random, Statistics, LinearAlgebra
 using IterTools: ncycle
+using Zygote
 using CUDA
 
 Random.seed!(0)
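
As a rough usage sketch of the vector-of-layers path introduced above (the model, layer sizes, and variable names here are illustrative, not part of the diff): the same list of layers can be wrapped either as a `Chain{<:AbstractVector}` or splatted into the usual `Chain{<:Tuple}`, and both give the same forward result.

```julia
using Flux

# Hypothetical layer sizes, just to illustrate the new constructor.
layers = [Dense(10, 64, relu), Dense(64, 2)]

mv = Chain(layers)      # Chain{<:AbstractVector}: type-unstable, but cheaper to compile
mt = Chain(layers...)   # Chain{<:Tuple}: fully inferred, unrolled by applychain

x = randn(Float32, 10, 16)
@assert mv(x) ≈ mt(x)   # both paths apply the same layers in the same order
```

The vector form trades inferability for compile time, which is why the new docstring flags it as experimental.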
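For reference, a hand-written sketch of the kind of expression the `@generated` `applychain` builds for a three-layer tuple; the real method uses `gensym`'d names rather than `y1`/`y2`/`y3`, and this helper name is hypothetical.

```julia
# Roughly equivalent to the unrolled code generated for N = 3:
function applychain_unrolled(layers, x)
  y1 = layers[1](x)   # layers[1] applied to the input
  y2 = layers[2](y1)  # each result feeds the next layer
  y3 = layers[3](y2)
  y3                  # value of the last assignment is returned
end
```

Unrolling keeps the tuple path type-stable and replaces the recursive `tail`/`first` definitions that this PR removes.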