diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f51f8911aa..556ab59b9c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -27,6 +27,9 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)), julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x) true ``` + +For large models, there is a special type-unstable path which can reduce compilation +times. This can be used by supplying a vector of layers `Chain([layer, layer])`. """ struct Chain{T} layers::T @@ -36,6 +39,7 @@ struct Chain{T} isempty(kw) && return new{Tuple{}}(()) new{typeof(values(kw))}(values(kw)) end + Chain(xs::AbstractVector) = new{typeof(xs)}(xs) # unstable path, to help compile times end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, @@ -53,6 +57,13 @@ end applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x) +function applychain(layers::AbstractVector, x) + for f in layers + x = f(x) + end + x +end + Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...) @@ -64,6 +75,7 @@ function Base.show(io::IO, c::Chain) end _show_layers(io, layers::Tuple) = join(io, layers, ", ") _show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") +_show_layers(io, layers::AbstractVector) = (print(io, "["); _show_layers(io, Tuple(layers)); print(io, "]")) # This is a temporary and naive implementation # it might be replaced in the future for better performance diff --git a/src/layers/show.jl b/src/layers/show.jl index 791d2511ca..819ef82d75 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -14,11 +14,12 @@ for T in [ end function _big_show(io::IO, obj, indent::Int=0, name=nothing) + pre, post = obj isa Chain{<:AbstractVector} ? ("[", "]") : ("", "") children = trainable(obj) if all(_show_leaflike, children) _layer_show(io, obj, indent, name) else - println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(") + println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(", pre) if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers) # then we insert names -- can this be done more generically? for k in Base.keys(obj) @@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) end end if indent == 0 # i.e. this is the outermost container - print(io, ")") + print(io, post, ")") _big_finale(io, obj) else - println(io, " "^indent, "),") + println(io, " "^indent, post, "),") end end end