Skip to content

Commit

Permalink
allow unstable Chain{Vector} too
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jan 11, 2022
1 parent 657e267 commit 300c4ad
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
12 changes: 12 additions & 0 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)),
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```
For large models, there is a special type-unstable path which can reduce compilation
times. This can be used by supplying a vector of layers `Chain([layer, layer])`.
"""
struct Chain{T}
layers::T
Expand All @@ -36,6 +39,7 @@ struct Chain{T}
isempty(kw) && return new{Tuple{}}(())
new{typeof(values(kw))}(values(kw))
end
Chain(xs::AbstractVector) = new{typeof(xs)}(xs) # unstable path, to help compile times
end

@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Expand All @@ -53,6 +57,13 @@ end

applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x)

function applychain(layers::AbstractVector, x)
for f in layers
x = f(x)
end
x
end

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
Expand All @@ -64,6 +75,7 @@ function Base.show(io::IO, c::Chain)
end
_show_layers(io, layers::Tuple) = join(io, layers, ", ")
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
_show_layers(io, layers::AbstractVector) = (print(io, "["); _show_layers(io, Tuple(layers)); print(io, "]"))

# This is a temporary and naive implementation
# it might be replaced in the future for better performance
Expand Down
7 changes: 4 additions & 3 deletions src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ for T in [
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
pre, post = obj isa Chain{<:AbstractVector} ? ("[", "]") : ("", "")
children = trainable(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(", pre)
if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
# then we insert names -- can this be done more generically?
for k in Base.keys(obj)
Expand All @@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
end
end
if indent == 0 # i.e. this is the outermost container
print(io, ")")
print(io, post, ")")
_big_finale(io, obj)
else
println(io, " "^indent, "),")
println(io, " "^indent, post, "),")
end
end
end
Expand Down

0 comments on commit 300c4ad

Please sign in to comment.