Skip to content

Commit

Permalink
allow unstable Chain{Vector} too
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 5, 2022
1 parent 3c883c1 commit 9284bd7
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
14 changes: 14 additions & 0 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)),
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```
For large models, there is a special type-unstable path which can reduce compilation
times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
This feature is somewhat experimental, beware!
"""
struct Chain{T}
layers::T
Expand All @@ -36,6 +40,7 @@ struct Chain{T}
isempty(kw) && return new{Tuple{}}(())
new{typeof(values(kw))}(values(kw))
end
Chain(xs::AbstractVector) = new{typeof(xs)}(xs) # unstable path, to help compile times
end

@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Expand All @@ -53,9 +58,17 @@ end

applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x)

function applychain(layers::AbstractVector, x)
for f in layers
x = f(x)
end
x
end

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
Base.getindex(c::Chain{<:AbstractVector}, i::AbstractArray) = Chain(c.layers[i])

function Base.show(io::IO, c::Chain)
print(io, "Chain(")
Expand All @@ -64,6 +77,7 @@ function Base.show(io::IO, c::Chain)
end
_show_layers(io, layers::Tuple) = join(io, layers, ", ")
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))

# This is a temporary and naive implementation
# it might be replaced in the future for better performance
Expand Down
11 changes: 6 additions & 5 deletions src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ for T in [
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ") ")
children = trainable(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
# then we insert names -- can this be done more generically?
for k in Base.keys(obj)
Expand All @@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
end
end
if indent == 0 # i.e. this is the outermost container
print(io, ")")
print(io, post)
_big_finale(io, obj)
else
println(io, " "^indent, "),")
println(io, " "^indent, post)
end
end
end
Expand Down Expand Up @@ -85,12 +86,12 @@ function _big_finale(io::IO, m)
noncnt = _childarray_sum(_->1, m) - length(ps)
if noncnt > 0
nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps))
printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
println(io, pars, " parameters,")
printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black)
print(io, bytes, ".")
else
printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black)
printstyled(io, " "^18, "# Total: ", length(ps), " arrays, "; color=:light_black)
print(io, pars, " parameters, ", bytes, ".")
end
end
Expand Down
25 changes: 25 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import Flux: activations
@test m[1:2] == m

@test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name

@test_nowarn Chain([Dense(10, 5, σ), Dense(5, 2)])(randn(Float32, 10)) # vector of layers
end

@testset "Activations" begin
Expand Down Expand Up @@ -274,6 +276,10 @@ end
m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
@test Zygote.hessian_dual(summ1, [1,2,3]) Zygote.hessian_reverse(summ1, [1,2,3])

m1v = Chain([m1[1], m1[2]]) # vector of layers
@test Zygote.hessian_dual(summ1v, [1,2,3]) Zygote.hessian_dual(summ1, [1,2,3])
@test_broken Zygote.hessian_dual(summ1v, [1,2,3]) Zygote.hessian_reverse(summ1v, [1,2,3])

# NNlib's softmax gradient writes in-place
m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax)
@test_broken Zygote.hessian_dual(summ2, [1,2,3]) Zygote.hessian_reverse(summ2, [1,2,3])
Expand All @@ -284,3 +290,22 @@ end
@test_broken Zygote.hessian_dual(summ3, x3) Zygote.hessian_reverse(summ3, x3)
end

@testset "gradients of Chain{Vector}" begin
m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
m1v = Chain([m1[1], m1[2]])
@test sum(length, params(m1)) == sum(length, params(m1v))

x1 = randn(Float32,3,5)
@test m1(x1) m1v(x1)

y1 = rand(Bool,2,5)
g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1))
g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v))
@test g1[m1[1].weight] g1v[m1v[1].weight]
@test g1[m1[2].bias] g1v[m1v[2].bias]

@test Flux.destructure(m1)[1] Flux.destructure(m1v)[1]
z1 = rand(22);
@test Flux.destructure(m1)[2](z1)[1].weight Flux.destructure(m1v)[2](z1)[1].weight
# Note that Flux.destructure(m1v)[2](z) has a Chain{Tuple}, as does m1v[1:2]
end

0 comments on commit 9284bd7

Please sign in to comment.