Skip to content

Commit

Permalink
use foldl less often
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jan 2, 2022
1 parent 058d59d commit a74f86d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,17 @@ end

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)

(c::Chain)(x) = foldl((y,f) -> f(y), (x, c.layers...))
function (c::Chain)(x)
if order() < 2
foldl((y,f) -> f(y), (x, c.layers...))
else
# This hand-written foldl causes high latency
applychain(Tuple(c.layers), x)
end
end

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Expand Down
28 changes: 28 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -890,3 +890,31 @@ function plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)

return patience(is_plateau, width)
end


"""
order()
Returns `1` inside a call to `Zygote.gradient`, `2` inside nested such calls.
# Examples
```jldoctest; setup = :(using Flux, Zygote)
julia> Flux.order()
0
julia> gradient(x -> (@show(Flux.order()); x^3), 1)
Flux.order() = 1
(3.0,)
julia> gradient(y -> gradient(x -> (@show(Flux.order()); x^3), y)[1], 1)
Flux.order() = 2
(6.0,)
julia> Zygote.hessian(x -> (@show(Flux.order()); x^3), 1) # uses ForwardDiff over Zygote
Flux.order() = 1
6
```
"""
order(::Val{n} = Val(0)) where {n} = n

Zygote.@adjoint order(::Val{n}) where {n} = order(Val(n+1)), Returns(nothing)

0 comments on commit a74f86d

Please sign in to comment.