
Use the cache less often #39

Merged
merged 7 commits into from
Oct 20, 2022
Conversation

mcabbott (Member)

One of the things #32 proposes is to disable the use of the cache for some types, so that e.g. the number 4 appearing at two nodes is not regarded as a form of parameter sharing, just a coincidence. This PR wants to make that change alone.

But what exactly is the right rule here? If (x = [1,2,3], y = 4) appears twice, then the shared [1,2,3] should be cached, but I think the 4 should still not be.

m1 = (a = (x = [1,2,3], y = 4), b = (x = [1,2,3], y = 4))
m1.a.x !== m1.b.x  # distinct arrays, hence not tied
m1.a.y === m1.b.y  # these are isbits, should not be tied

shared = [1,2,3]
m2 = (a = (x = shared, y = 4), b = (x = shared, y = 4))
m2.a.x === m2.b.x  # these really are tied
m2.a === m2.b  # but these should not be, since they would tie 4 === 4

m3 = (a = [[1,2,3], 4], b = [[1,2,3], 4])
m4 = (a = [shared, 4], b = [shared, 4])  # should still not tie 4 === 4
m4.a !== m4.b  # but now the container won't tempt you

This PR's position is that the right test is to use the cache only on leaf nodes. #32 instead tested !isbits(x) && ismutable(x), which also works on these examples. Where they differ is on an immutable container enclosing a mutable array:

shared = NamedDimsArray([1,2,3], :tri)  # a Functors-unaware array wrapper struct
shared = TiedArray(SA[1,2,3], Ref(0))   # the idea we had for marking StaticArrays as tied
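To make the disagreement concrete, here is a hedged sketch: a minimal immutable wrapper standing in for NamedDimsArray (the real package type is richer than this), on which the two candidate predicates give different answers.

```julia
# Minimal stand-in for a Functors-unaware array wrapper (NOT the real
# NamedDimsArray): the struct itself is immutable, but it wraps a
# mutable Array, so two occurrences can alias the same state.
struct Wrapper{T,A<:AbstractArray{T}} <: AbstractVector{T}
    data::A
    name::Symbol
end
Base.size(w::Wrapper) = size(w.data)
Base.getindex(w::Wrapper, i::Int) = w.data[i]

w = Wrapper([1, 2, 3], :tri)

# #32's test caches only values that are themselves mutable:
!isbits(w) && ismutable(w)   # false: the immutable wrapper is skipped,
                             # even though mutating w.data is visible
                             # through every occurrence of w

# A leaf-only rule instead asks whether w is a leaf and whether anything
# it *contains* is mutable, so the shared Array inside is respected.
```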

Right now this uses the exclude keyword, not the fixed isleaf. I think that makes sense, but I haven't thought about it too hard.

ToucheSir (Member)

I thought about only considering leaves for caching, but one hiccup is when a mutable struct is used as a shared non-leaf node. Think Flux.BatchNorm. We could say that non-leaf nodes will always be untied, but that needs to be a) decided and b) documented with a prominent warning banner.
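As an illustration of that hiccup, here is a hypothetical simplified layer (not Flux's actual BatchNorm) used as a shared non-leaf node:

```julia
# Hypothetical mutable struct used as a shared *non-leaf* node,
# loosely modelled on Flux.BatchNorm.
mutable struct MiniNorm
    γ::Vector{Float64}
    β::Vector{Float64}
end

layer = MiniNorm([1.0], [0.0])
model = (enc = layer, dec = layer)   # one layer tied into two branches
model.enc === model.dec              # true: genuinely shared

# If only leaf nodes go through the cache, reconstruction builds a fresh
# MiniNorm at each occurrence: the leaf arrays γ and β may still come out
# tied, but the rebuilt model has enc !== dec, silently untying the
# mutable struct itself.
```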

mcabbott (Member Author) commented Feb 12, 2022

It seems that by making your type non-leaf, you have opted in to having Functors traverse into it. What's gained by not doing so? Can you clarify what problem traversing it will cause?

My NamedDimsArray example is my complaint about ismutable, distilled. Could be OffsetArray, too. This ought to work just like an Array.

Since TiedArray is more explicitly about Functors, it could overload things.

ToucheSir (Member)

Traversing it won't cause any issues; the question is how it should be re-assembled post-traversal. Specifically, if you run into the same mutable non-leaf multiple times, whether it should be untied as part of the reassembly process (vs. fetched from the cache on subsequent occurrences).

The question is whether we're ok with breaking the following behaviour: given x′ = fmap(identity, x), any nodes in x that are === should also be so in x′. Maybe this is a reasonable thing to not care about, but if so it ought to be documented.
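That invariant can be stated as a minimal check (assuming the cache-preserving behaviour currently under discussion):

```julia
using Functors

shared = [1, 2, 3]
x = (a = shared, b = shared)
x.a === x.b          # true: both fields alias one array

x′ = fmap(identity, x)
# The invariant in question: aliasing in x should survive into x′,
# i.e. x′.a === x′.b should also hold after the walk.
```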

> My NamedDimsArray example is my complaint about ismutable, distilled. Could be OffsetArray, too. This ought to work just like an Array.

My thought for this was to lean on parent for array wrappers, viz.:

function usecache(x::AbstractArray)
  p = parent(x)
  p !== x && return usecache(p) # alt. typeof(p) !== typeof(x), etc. if we wanted type stability or to avoid potentially expensive `===` methods.
  return ismutable(x) # fallback 
end

usecache(x::Array) = true
...

mcabbott (Member Author)

I guess " any nodes in x that are === should also be so in x′." is a simple-to-explain policy. Would be nice if the policy for how often f will be called was as simple. Will think a bit.

It seems a bit fragile to depend on the right parent methods existing (in one case) and not existing (in the other), since nothing requires them.

ToucheSir (Member)

What wrapper types don't expose a parent method? Worst case they are considered not cacheable due to the fallback in https://github.com/JuliaLang/julia/blob/master/base/abstractarray.jl#L1398.

> Would be nice if the policy for how often f will be called was as simple. Will think a bit.

Aye, this is much easier in a purer functional language where you're only traversing over trees. Sometimes I'm tempted to try representing the object graph as an actual digraph for this reason, but that's a little far off the deep end :P

mcabbott force-pushed the usecache branch 3 times, most recently from fe0fe80 to a1050b2, on September 24, 2022
mcabbott (Member Author)

Now updated with a different rule:

  • non-leaf objects are cached only if they are mutable (as immutable structs can be perfectly reconstructed).
  • leaf objects are cached if they contain any mutable objects (by recursing through fieldnames(x) etc.; nothing special for arrays).

Weird cases:

  • If you use exclude to have f act on some non-leaf types, which are immutable but contain mutable objects, then it will not cache the result. Should it?
  • If an immutable struct has fields which are not children, it still won't be cached. However, the reconstruction will re-use the same original non-child, that's OK.
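The rule above can be sketched roughly as follows (isleaf and anymutable here stand for whatever the PR's real predicates are called; this is not the PR's exact code):

```julia
# Rough sketch of the updated caching rule, assuming hypothetical
# helpers `isleaf(x)` and `anymutable(x)`.
function usecache(x)
    if isleaf(x)
        return anymutable(x)  # leaf: cache iff it contains anything mutable
    else
        return ismutable(x)   # non-leaf: cache iff the node itself is mutable
                              # (immutable structs reconstruct perfectly)
    end
end
```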

src/functor.jl Outdated
Comment on lines 45 to 50
# function _anymutable(x::T) where {T}
# ismutable(x) && return true
# fs = fieldnames(T)
# isempty(fs) && return false
# return any(f -> anymutable(getfield(x, f)), fs)
# end
ToucheSir (Member)

Does this fail to constant fold sometimes?

Otherwise LGTM. About the weird cases, we could argue it's more conservative to not cache in both. A false positive seems much worse than a false negative here IMO. Asking users of a higher-level isleaf to take on additional responsibility for caching is also fine. Incidentally, this is why I think extracting caching out of Functors and making callbacks handle memoization themselves would be nice.

mcabbott (Member Author)

It fails to be instant on surprisingly simple functions; I didn't try to dig into why:

julia> @btime fun_anymutable((x=(1,2), y=3))
  min 36.500 ns, mean 38.715 ns (1 allocation, 32 bytes)
false

julia> @btime gen_anymutable((x=(1,2), y=3))
  min 0.001 ns, mean 0.014 ns (0 allocations)
false

Perhaps more surprisingly, the generated one is also not free, e.g. here:

julia> @btime fun_anymutable($(Metalhead.ResNet()))
  min 275.685 ns, mean 323.217 ns (9 allocations, 320 bytes)
true

julia> @btime gen_anymutable($(Metalhead.ResNet()))
  min 147.536 ns, mean 161.010 ns (1 allocation, 32 bytes)
true

That contains Chain([...]) which... should just stop the recursion?

mcabbott (Member Author)

Smaller example, in which the number of layers seems to matter:

julia> model = Chain(
         Conv((3, 3), 1 => 16),                # 160 parameters
         Conv((3, 3), 16 => 16),               # 2_320 parameters
         Conv((3, 3), 16 => 32),               # 4_640 parameters
         Conv((3, 3), 32 => 32),               # 9_248 parameters
         Conv((3, 3), 32 => 64),               # 18_496 parameters
         Conv((3, 3), 64 => 64),               # 36_928 parameters
         Dense(16384 => 10),                   # 163_850 parameters
       );

julia> @btime fun_anymutable($model)
  min 327.851 ns, mean 448.404 ns (10 allocations, 3.17 KiB)
true

julia> @btime gen_anymutable($model)
  min 215.463 ns, mean 238.700 ns (8 allocations, 608 bytes)
true

julia> model = Chain(
         Conv((3, 3), 1 => 16),                # 160 parameters
         Conv((3, 3), 16 => 16),               # 2_320 parameters
         # Conv((3, 3), 16 => 32),               # 4_640 parameters
         # Conv((3, 3), 32 => 32),               # 9_248 parameters
         # Conv((3, 3), 32 => 64),               # 18_496 parameters
         Conv((3, 3), 64 => 64),               # 36_928 parameters
         Dense(16384 => 10),                   # 163_850 parameters
       );

julia> @btime fun_anymutable($model)
  min 344.818 ns, mean 391.967 ns (10 allocations, 1.75 KiB)
true

julia> @btime gen_anymutable($model)
  min 0.001 ns, mean 0.014 ns (0 allocations)
true

ToucheSir (Member)

For the Metalhead example at least, that one allocation is coming from https://github.com/FluxML/Metalhead.jl/blob/7827ca6ec4ef7c5e07d04cd6d84a1a3b11289dc0/src/convnets/resnets/resnet.jl#L17. For the longer Chain, Cthulhu tells me

┌ Info: Inference didn't cache this call information because of imprecise analysis due to recursion:
└ Cthulhu nevertheless is trying to descend into it for further inspection.

If I add a guard against the possible missing from any in gen_anymutable and assert the return value like so:

@generated function gen_anymutable(x::T) where {T}
  ismutabletype(T) && return true
  fs = fieldnames(T)
  isempty(fs) && return false
  subs =  [:(gen_anymutable(getfield(x, $f))) for f in QuoteNode.(fs)]
  return :(coalesce(|($(subs...)), false)::Bool)
end

That eliminates all but 6 of the allocations. I believe these correspond to the 6 Conv layers because the check on the Dense layer appears to be fully const folded (why only the Dense? Not sure).

mcabbott (Member Author)

Oh nice. Just the ::Bool seems to be enough, and should be safe I think.

Weirdly, it is instant for 5 and 7 conv layers; only exactly 6 causes it to fail and take ~100 ns.

ToucheSir (Member)

That is absolutely bizarre. It also works for 6 Conv layers if I remove the final Dense, and up to at least 32 with or without it. Granted, this is on nightly; I couldn't get close to your timings on 1.8.2, IIRC.

ToucheSir (Member)

Ok, less bizarre. Trying to simplify things a bit, it looks like the first call is always taking a hit here, but subsequent calls are fine.

using BenchmarkTools

struct Conv{N,M,F,A,V}
  σ::F
  weight::A
  bias::V
  stride::NTuple{N,Int}
  pad::NTuple{M,Int}
  dilation::NTuple{N,Int}
  groups::Int
end

struct Dense{F, M, B}
  weight::M
  bias::B
  σ::F
end

@generated function anymutable(x::T) where {T}
  ismutabletype(T) && return true
  fs = fieldnames(T)
  isempty(fs) && return false
  subs =  [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fs)]
  return :(|($(subs...))::Bool)
end

function test()
  for N in (5, 6, 7)
    @info N
    layers = ntuple(_ -> Conv(identity, ones(1), ones(1), (1,), (1,), (1,), 1), N)
    layers = (layers..., Dense(ones(1), ones(1), identity))
    @btime anymutable($layers)
  end
end

test()

Perhaps that has something to do with the generated function?

mcabbott (Member Author) commented Oct 9, 2022

It's weird.

I don't suggest doing this, but this does seem to compile away:

julia> Base.@assume_effects :total function fun_anymutable3(x::T) where {T}
          ismutable(x) && return true
          fs = fieldnames(T)
          isempty(fs) && return false
          return any(f -> fun_anymutable3(getfield(x, f)), fs)::Bool
       end
fun_anymutable3 (generic function with 1 method)

julia> function test_3()
         for N in (5, 6, 7)
           @info N
           layers = ntuple(_ -> Conv(identity, ones(1), ones(1), (1,), (1,), (1,), 1), N)
           layers = (layers..., Dense(ones(1), ones(1), identity))
           @btime fun_anymutable3($layers)
         end
       end
test_3 (generic function with 1 method)

julia> test_3()
[ Info: 5
  min 0.083 ns, mean 0.185 ns (0 allocations)
[ Info: 6
  min 0.083 ns, mean 0.208 ns (0 allocations)
[ Info: 7
  min 0.083 ns, mean 0.229 ns (0 allocations)

julia> VERSION
v"1.9.0-DEV.1528"

(Edit -- inserted results)

ToucheSir (Member) commented Oct 9, 2022

Hmm, that still allocates (more, as a matter of fact) for me on nightly and 1.8. At least performance is consistent though.

[ Info: 5
  176.931 ns (7 allocations: 1.05 KiB)
[ Info: 6
  177.480 ns (7 allocations: 1.23 KiB)
[ Info: 7
  188.679 ns (7 allocations: 1.34 KiB)

julia> VERSION
v"1.9.0-DEV.1547"

src/functor.jl Outdated (resolved)
mcabbott (Member Author)

Am going to merge this so that master has the new behaviour, but won't rush to tag it.

The code here will be changed by #43, but the tests may (I think) survive.

@mcabbott mcabbott merged commit 981c866 into FluxML:master Oct 20, 2022
@mcabbott mcabbott deleted the usecache branch October 20, 2022 19:23
darsnack added a commit to darsnack/Functors.jl that referenced this pull request Oct 31, 2022
@mcabbott mcabbott added this to the v0.4 milestone Nov 15, 2022