From 6918b0e83dc7c76aacb374abf30c6e7b94d69746 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 13 Dec 2021 15:57:14 -0500 Subject: [PATCH 1/8] use foldl for Chain --- src/layers/basic.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 42310d0b7c..ae135d8326 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -44,10 +44,7 @@ end @functor Chain -applychain(::Tuple{}, x) = x -applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) - -(c::Chain)(x) = applychain(Tuple(c.layers), x) +(c::Chain)(x) = foldl((y,f) -> f(y), (x, c.layers...)) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]) Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = From 00ba1242fd7d3a101afd121aa75696bd0a31cdfc Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 2 Jan 2022 03:17:40 -0500 Subject: [PATCH 2/8] use foldl less often --- src/layers/basic.jl | 12 +++++++++++- src/utils.jl | 28 ++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ae135d8326..754323177b 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -44,7 +44,17 @@ end @functor Chain -(c::Chain)(x) = foldl((y,f) -> f(y), (x, c.layers...)) +function (c::Chain)(x) + if order() < 2 + foldl((y,f) -> f(y), (x, c.layers...)) + else + # This hand-written foldl causes high latency + applychain(Tuple(c.layers), x) + end +end + +applychain(::Tuple{}, x) = x +applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]) Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = diff --git a/src/utils.jl b/src/utils.jl index 035798b5c0..d25da61ec2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -914,3 +914,31 @@ function plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6) return patience(is_plateau, width) end + + +""" + order() + +Returns `1` inside a call to `Zygote.gradient`, `2` inside nested such calls. + +# Examples +```jldoctest; setup = :(using Flux, Zygote) +julia> Flux.order() +0 + +julia> gradient(x -> (@show(Flux.order()); x^3), 1) +Flux.order() = 1 +(3.0,) + +julia> gradient(y -> gradient(x -> (@show(Flux.order()); x^3), y)[1], 1) +Flux.order() = 2 +(6.0,) + +julia> Zygote.hessian(x -> (@show(Flux.order()); x^3), 1) # uses ForwardDiff over Zygote +Flux.order() = 1 +6 +``` +""" +order(::Val{n} = Val(0)) where {n} = n + +Zygote.@adjoint order(::Val{n}) where {n} = order(Val(n+1)), Returns(nothing) From a9bbb0c0c649d3e849c468fccab865233abfaa07 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 2 Jan 2022 03:17:52 -0500 Subject: [PATCH 3/8] second derivative tests --- test/layers/basic.jl | 15 +++++++++++++++ test/runtests.jl | 1 + 2 files changed, 16 insertions(+) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 968ddd506f..906a713c17 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -297,3 +297,18 @@ import Flux: activations @test_throws DimensionMismatch m(OneHotVector(3, 1000)) end end + +@testset "second derivatives" begin + m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2)) + @test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3]) + + # NNlib's softmax gradient writes in-place + m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax) + @test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3]) + + # https://github.com/FluxML/NNlib.jl/issues/362 + m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2)) + x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3) + @test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3) +end + diff --git a/test/runtests.jl b/test/runtests.jl index a6abd609d2..f75174e447 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Flux: OneHotArray, OneHotMatrix, OneHotVector using Test using Random, Statistics, LinearAlgebra using IterTools: ncycle +using Zygote using CUDA Random.seed!(0) From 585043dbd5784dd1538fd8acc7147556da3bd49c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 10 Jan 2022 19:03:53 -0500 Subject: [PATCH 4/8] Revert "use foldl less often" This reverts commit a74f86d312c30b26ad4a49483216777bf4a871ff. --- src/layers/basic.jl | 12 +----------- src/utils.jl | 28 ---------------------------- 2 files changed, 1 insertion(+), 39 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 754323177b..ae135d8326 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -44,17 +44,7 @@ end @functor Chain -function (c::Chain)(x) - if order() < 2 - foldl((y,f) -> f(y), (x, c.layers...)) - else - # This hand-written foldl causes high latency - applychain(Tuple(c.layers), x) - end -end - -applychain(::Tuple{}, x) = x -applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) +(c::Chain)(x) = foldl((y,f) -> f(y), (x, c.layers...)) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]) Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = diff --git a/src/utils.jl b/src/utils.jl index d25da61ec2..035798b5c0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -914,31 +914,3 @@ function plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6) return patience(is_plateau, width) end - - -""" - order() - -Returns `1` inside a call to `Zygote.gradient`, `2` inside nested such calls. - -# Examples -```jldoctest; setup = :(using Flux, Zygote) -julia> Flux.order() -0 - -julia> gradient(x -> (@show(Flux.order()); x^3), 1) -Flux.order() = 1 -(3.0,) - -julia> gradient(y -> gradient(x -> (@show(Flux.order()); x^3), y)[1], 1) -Flux.order() = 2 -(6.0,) - -julia> Zygote.hessian(x -> (@show(Flux.order()); x^3), 1) # uses ForwardDiff over Zygote -Flux.order() = 1 -6 -``` -""" -order(::Val{n} = Val(0)) where {n} = n - -Zygote.@adjoint order(::Val{n}) where {n} = order(Val(n+1)), Returns(nothing) From fcb09da119ec90e7cca5de547b7ca21275da089c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 10 Jan 2022 20:02:39 -0500 Subject: [PATCH 5/8] replace foldl with generated expression --- src/layers/basic.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ae135d8326..ca11c45daa 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -44,7 +44,15 @@ end @functor Chain -(c::Chain)(x) = foldl((y,f) -> f(y), (x, c.layers...)) +(c::Chain)(x) = applychain(c.layers, x) + +@generated function applychain(layers::Tuple{Vararg{<:Any,N}}, x) where {N} + symbols = vcat(:x, [gensym() for _ in 1:N]) + calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N] + Expr(:block, calls...) +end + +applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]) Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = From d99d7ac64d18221e8674e6682fe86df980477cfa Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 10 Jan 2022 20:13:06 -0500 Subject: [PATCH 6/8] allow unstable Chain{Vector} too --- src/layers/basic.jl | 14 +++++++++++++- src/layers/show.jl | 11 ++++++----- test/layers/basic.jl | 25 +++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ca11c45daa..9ed50cd02a 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -27,8 +27,12 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)), julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x) true ``` + +For large models, there is a special type-unstable path which can reduce compilation +times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`. +This feature is somewhat experimental, beware! """ -struct Chain{T<:Union{Tuple, NamedTuple}} +struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}} layers::T end @@ -54,6 +58,13 @@ end applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x) +function applychain(layers::AbstractVector, x) # type-unstable path, helps compile times + for f in layers + x = f(x) + end + x +end + Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]) Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])) @@ -65,6 +76,7 @@ function Base.show(io::IO, c::Chain) end _show_layers(io, layers::Tuple) = join(io, layers, ", ") _show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") +_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]")) # This is a temporary and naive implementation # it might be replaced in the future for better performance diff --git a/src/layers/show.jl b/src/layers/show.jl index 85faec3c59..288e97acd6 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -14,11 +14,12 @@ for T in [ end function _big_show(io::IO, obj, indent::Int=0, name=nothing) + pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ") ") children = _show_children(obj) if all(_show_leaflike, children) _layer_show(io, obj, indent, name) else - println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(") + println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre) if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers) # then we insert names -- can this be done more generically? for k in Base.keys(obj) @@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) end end if indent == 0 # i.e. this is the outermost container - print(io, ")") + print(io, post) _big_finale(io, obj) else - println(io, " "^indent, "),") + println(io, " "^indent, post) end end end @@ -90,12 +91,12 @@ function _big_finale(io::IO, m) noncnt = _childarray_sum(_->1, m) - length(ps) if noncnt > 0 nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps)) - printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, "; color=:light_black) + printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black) println(io, pars, " parameters,") printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black) print(io, bytes, ".") else - printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black) + printstyled(io, " "^18, "# Total: ", length(ps), " arrays, "; color=:light_black) print(io, pars, " parameters, ", bytes, ".") end end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 906a713c17..9852b1c1b7 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -29,6 +29,8 @@ import Flux: activations @test m == fmap(identity, m) # does not forget names @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name + + @test_nowarn Chain([Dense(10, 5, σ), Dense(5, 2)])(randn(Float32, 10)) # vector of layers end @testset "Activations" begin @@ -302,6 +304,10 @@ end m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2)) @test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3]) + m1v = Chain([m1[1], m1[2]]) # vector of layers + @test Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_dual(sum∘m1, [1,2,3]) + @test_broken Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1v, [1,2,3]) + # NNlib's softmax gradient writes in-place m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax) @test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3]) @@ -312,3 +318,22 @@ end @test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3) end +@testset "gradients of Chain{Vector}" begin + m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2)) + m1v = Chain([m1[1], m1[2]]) + @test sum(length, params(m1)) == sum(length, params(m1v)) + + x1 = randn(Float32,3,5) + @test m1(x1) ≈ m1v(x1) + + y1 = rand(Bool,2,5) + g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1)) + g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v)) + @test g1[m1[1].weight] ≈ g1v[m1v[1].weight] + @test g1[m1[2].bias] ≈ g1v[m1v[2].bias] + + @test Flux.destructure(m1)[1] ≈ Flux.destructure(m1v)[1] + z1 = rand(22); + @test Flux.destructure(m1)[2](z1)[1].weight ≈ Flux.destructure(m1v)[2](z1)[1].weight + # Note that Flux.destructure(m1v)[2](z) has a Chain{Tuple}, as does m1v[1:2] +end From 4dfd5510ca5eac8d05d573f43b70dc2687bdcff4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 5 Feb 2022 14:20:22 -0500 Subject: [PATCH 7/8] trailing comma --- src/layers/show.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index 288e97acd6..7772b2f49b 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -14,7 +14,7 @@ for T in [ end function _big_show(io::IO, obj, indent::Int=0, name=nothing) - pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ") ") + pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")") children = _show_children(obj) if all(_show_leaflike, children) _layer_show(io, obj, indent, name) @@ -36,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) end end if indent == 0 # i.e. this is the outermost container - print(io, post) + print(io, rpad(post, 2)) _big_finale(io, obj) else - println(io, " "^indent, post) + println(io, " "^indent, post, ",") end end end From f60da1af12ab267c6c8895e4ee535529a6a4c06e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 5 Feb 2022 14:36:29 -0500 Subject: [PATCH 8/8] fixup --- src/layers/show.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index 7772b2f49b..a37af36065 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -102,7 +102,7 @@ function _big_finale(io::IO, m) end end -_childarray_sum(f, x::AbstractArray) = f(x) +_childarray_sum(f, x::AbstractArray{<:Number}) = f(x) _childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x)) # utility functions