Faster arraydist with LazyArrays.jl #231

Closed
torfjelde wants to merge 19 commits

Conversation

torfjelde
Member

This PR is basically an accumulation of the discussion in TuringLang/Turing.jl#1934 and #230.

It's a hack to make reverse-mode AD packages that use ForwardDiff for broadcasting much faster when used in combination with LazyArrays.jl.

Unfortunately, this requires a rather ugly hack, namely `make_closure` (maybe there's a more elegant solution? @devmotion pls halp!), but it buys us a lot of performance.
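
For context, the kind of call this targets looks roughly like the following (a sketch with made-up inputs, not code from the PR):

```julia
using Distributions, DistributionsAD, LazyArrays

μ, σ = randn(100), rand(100) .+ 0.1
x = randn(100)

# One `Normal(μ[i], σ[i])` per element, constructed lazily instead of
# materializing a `Vector{Normal{Float64}}` up front.
dist = arraydist(BroadcastArray(Normal, μ, σ))
logpdf(dist, x)
```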

)
size(x, 1) == length(dist) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
return vec(sum(copy(logpdf.(dists, x)), dims = 1))
Member Author

This even has a bug in it: `dists` isn't defined...

@torfjelde torfjelde requested a review from devmotion January 17, 2023 13:54
src/lazyarrays.jl Outdated (resolved)
src/common.jl Outdated
Comment on lines 88 to 96
# Notes
To really go "vrooom!" one needs to specialize on the arguments, e.g. if one
has a function `myfunc` then we need to define

```julia
make_closure(::typeof(myfunc), ::Type{D}) where {D} = (x, args...) -> myfunc(D(args...), x)
```

This can also be done using `DistributionsAD.@specialize_make_closure`:
Member

Maybe just missing type parameters in the definition below?

src/common.jl Outdated
Comment on lines 118 to 159


"""
has_specialized_make_closure(f, g)

Return `true` if there exists a specialized `make_closure(f, g)` implementation.
"""
has_specialized_make_closure(f, g) = false

# To go vroooom we need to specialize on the first argument, thus ensuring that
# a different closure is constructed for each method.
"""
@specialize_make_closure(f)

Define `make_closure` and `has_specialized_make_closure` for the first argument being `f`
and the second argument being a type.
"""
macro specialize_make_closure(f)
    return quote
        $(DistributionsAD).make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = (x, args...) -> $(esc(f))(D(args...), x)
        $(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = true
    end
end

"""
@specialize_make_closure(f, g)

Define `make_closure` and `has_specialized_make_closure` for the first argument being `f`
and the second argument being `g`.
"""
macro specialize_make_closure(f, g)
    return quote
        $(DistributionsAD).make_closure(::typeof($(esc(f))), ::typeof($(esc(g)))) = (x, args...) -> $(esc(f))($(esc(g))(args...), x)
        $(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::typeof($(esc(g)))) = true
    end
end

@specialize_make_closure Distributions.pdf
@specialize_make_closure Distributions.logpdf
@specialize_make_closure Distributions.loglikelihood
@specialize_make_closure Distributions.cdf
@specialize_make_closure Distributions.logcdf
Member

I think it should be possible to remove all of this code. Maybe type parameters are already sufficient. Or using a callable struct might help.
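
For reference, the one-argument form above expands to definitions of roughly this shape for `Distributions.logpdf` (a sketch obtained by unrolling the macro body, shown only for illustration):

```julia
DistributionsAD.make_closure(::typeof(Distributions.logpdf), ::Type{D}) where {D} =
    (x, args...) -> Distributions.logpdf(D(args...), x)
DistributionsAD.has_specialized_make_closure(::typeof(Distributions.logpdf), ::Type{D}) where {D} = true
```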

Comment on lines 3 to 6
# Necessary to make `BroadcastArray` work nicely with Zygote.
function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::Type{BroadcastArray}, f, args...)
    return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
end
Member

ChainRules definitions for LazyArrays.jl should really not be part of DistributionsAD. The other type piracy is already bad but at least Distributions-specific. But these lines seem really inappropriate here.

Member Author

Agreed! See JuliaArrays/LazyArrays.jl#232

Think the plan is to make a glue-package for now.

Btw, this rrule is also not good since, for example, logpdf(arraydist(BroadcastArray(Normal, x)), data) will then be separated into two broadcast statements again, which is the opposite of what we want 😕

We could of course define adjoint rules for logpdf, etc. that specialize on the BroadcastArray scenario, but that is all non-ideal 😕

Member

I don't think a glue package is a good long-term solution. If one does not want to load ChainRulesCore for all users (even though I think it's loaded anyway in almost all realistic scenarios), a weak dependency seems the best way IMO.

Member Author

Btw, got any good ideas for how to properly define the rrule for the `BroadcastArray` constructor?

Member Author

So far I have:

function ChainRulesCore.rrule(::Type{BroadcastArray}, f, args...)
    function BroadcastArray_pullback(Δ::ChainRulesCore.Tangent)
        return (ChainRulesCore.NoTangent(), Δ.f, Δ.args...)
    end
    return BroadcastArray(f, args...), BroadcastArray_pullback
end

function ChainRulesCore.rrule(
    config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
    ::typeof(Distributions.logpdf),
    dist::Distributions.Product{V,D,A}, x::AbstractVector{<:Real}
) where {V,D<:Distribution,A<:BroadcastArray}
    cl = DistributionsAD.Closure(logpdf, DistributionsAD._inner_constructor(typeof(dist.v)))
    y, dy = ChainRulesCore.rrule_via_ad(config, broadcast, cl, x, dist.v.args...)
    z, dz = ChainRulesCore.rrule_via_ad(config, sum, y)

    f = dist.v.f
    function logpdf_adjoint(Δ...)
        # 1st argument is `sum` -> nothing.
        (_, sum_Δ...) = dz(Δ...)
        # 1st argument is `broadcast` -> nothing.
        # 2nd argument is `cl` -> `nothing`.
        # 3rd argument is `x` -> something.
        # Rest is `dist` arguments -> something
        (_, _, x_Δ, args_Δ...) = dy(sum_Δ...)
        # Construct the structural tangents.
        ba_tangent = ChainRulesCore.Tangent{A}(f=f, args=args_Δ)
        dist_tangent = ChainRulesCore.Tangent{typeof(dist)}(v=ba_tangent)

        return (ChainRulesCore.NoTangent(), dist_tangent, x_Δ)
    end

    return z, logpdf_adjoint
end

which does indeed do the trick but does not look great 😕

  1. I don't like how I have to call back into AD, but I don't atm see a way around that.
  2. I'm not too familiar with structural `Tangent`s, and so I don't know if nesting them is a bad idea. For example, I noticed that when we hit the pullback for BroadcastArray, I'm looking at a Tangent{Any} despite this not being the case in logpdf_adjoint (though this might just be Zygote doing something with it in the meantime?).
  3. This isn't full support for BroadcastArray since in most cases it won't receive a Tangent.

Any more?

Member

I'll try to look into it a bit this evening (currently a bit busy since we are working on a paper for ICML but it's in such an early stage I'm not sure we will make the deadline 😄). ProjectTo reminds me of this PR which I just happened to comment on this morning: JuliaArrays/FillArrays.jl#153 Maybe the code there could be helpful if you want to define ProjectTo for BroadcastArray (not sure if that was what you were asking about).

Member Author

Awesome:) Btw, really appreciate your responses over these past days; been incredibly helpful:) I'll be doing some workshop talks at a winter school on Turing next week and I would have loved to pin this down before that but we'll see 👍

Member Author

Btw I just pushed the resulting impl so you can look at it properly when you have time. It does seem to work awfully well but it's very hacky 😕

Member

A PhD student from Gothenburg actually told me about the winter school and thought it might be interesting for me based on my research interests. I would have liked to go to Norway since I have only been to Bergen so far but unfortunately I don't have time (and I assume I'm not the intended audience for a workshop about Turing 😛).

Member Author

Ooooh that would have been awesome!:( Well, hope you make it at some other point. Haha, maybe not 😅

Comment on lines +65 to +67

lazyarray(f, x...) = BroadcastArray(f, x...)
export lazyarray
Member

It's not clear to me why this is needed. It doesn't seem much shorter and it makes it less clear that everything is based on LazyArrays.

Member Author

Oh no, this was already in DistributionsAD.jl 🤷 Not something I put in here. I was also unaware of this method's existence.

Member

Should we deprecate it?

Member Author

Happy to!
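
A deprecation could be as simple as the following (a sketch):

```julia
Base.@deprecate lazyarray(f, x...) BroadcastArray(f, x...)
```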

src/common.jl Outdated
Comment on lines 116 to 117
make_closure(f, g) = (x, args...) -> f(g(args...), x)
make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x)
Member

I wonder if there are any possible performance/compiler benefits to not closing over the variables but using (more Julian) callable structs that capture `f` and `g`. In any case, I think you want:

Suggested change
make_closure(f, g) = (x, args...) -> f(g(args...), x)
make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x)
make_closure(f::F, g::G) where {F,G} = (x, args...) -> f(g(args...), x)
make_closure(f::F, ::Type{D}) where {F,D} = (x, args...) -> f(D(args...), x)

Member Author

I unfortunately tried that but AFAIK it ends up with this issue of closing over a UnionAll type again, which is exactly what we're trying to avoid (because of the issues it causes with some AD backends) 😕

I might have not done it correctly though.

I have tried your suggestion, though, and it unfortunately doesn't have an effect. If you just look at the returned closures, they're all the same 😕

Member Author

Maybe something like

struct Closure{F,G} end

Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}()

for f in [pdf, logpdf, cdf, logcdf]
    @eval (::$(Closure){typeof($f),G})(x, args...) where {G} = $f(G(args...), x)
end

?

Member Author

Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()

and others avoid the UnionAll issue.

Member Author

Have a look at what I've done now 👍
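
For concreteness, the intended call pattern for such a singleton `Closure` is roughly the following (a sketch assuming the `Closure` definitions proposed a few comments above; the pushed code may differ in details):

```julia
using Distributions

x, μ, σ = randn(10), randn(10), rand(10) .+ 0.1

cl = Closure(logpdf, Normal)  # a singleton instance: nothing is stored in fields
cl(x[1], μ[1], σ[1])          # same as logpdf(Normal(μ[1], σ[1]), x[1])
sum(cl.(x, μ, σ))             # the broadcast that replaces sum(logpdf.(Normal.(μ, σ), x))
```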

Comment on lines +95 to +96
f = Base.issingletontype(F) ? F.instance : F
g = Base.issingletontype(G) ? G.instance : G
Member Author

I'm not 100% certain about this. Need to think about it when I've had some sleep.

# 2nd argument is `cl` -> `nothing`.
# 3rd argument is `x` -> something.
# Rest is `dist` arguments -> something
(_, _, x_Δ, args_Δ...) = dy(sum_Δ...)
Member Author

One thing I'm worried about: what if `f` is a closure containing variables to differentiate wrt.? 😬

Member Author

Okay, so that should be addressed with the `is_diff_safe` above. We will now only hit this faster path if we're certain we don't need to take derivatives wrt. anything "in" the constructor itself.

Member

julia> is_diff_safe(typeof(Closure(logpdf, makef([1.0]))))
false
"""
@inline is_diff_safe(_) = false
Member Author

Yeah basically, though I was looking at Zygote's _dual_purefun because ReverseDiff already hits the broadcast_forward, even without the custom adjoint I defined (ReverseDiff also doesn't support calling back into AD).

Member
@devmotion devmotion left a comment

I had a rather quick look. There are many things that should not be here (but maybe some of them could be moved or fixed somewhere else). Specializations for LazyArrays could maybe be added to Distributions as weak dependencies (or maybe we just fix the use of Broadcasted). Nevertheless it might be OK for some time to have fixes here, but we should make sure to test everything.
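
As a sketch of the weak-dependency idea (module and file names here are illustrative, not from this PR):

```julia
# Hypothetical package extension that would own the LazyArrays-specific rules.
module DistributionsADLazyArraysExt

using DistributionsAD
using LazyArrays: BroadcastArray
import ChainRulesCore

# The rrules currently living in src/lazyarrays.jl would move here.

end
```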

src/common.jl Outdated (resolved)
Comment on lines +87 to +92
struct Closure{F,G} end

Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}()
Member

This is not really what I had in mind. More something like Base.Fix1:

Suggested change
struct Closure{F,G} end
Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}()
struct Closure{F,G}
f::F
g::G
end
Closure(f::F, g::G) where {F,G} = Closure{F,G}(f, g)
Closure(f::F, ::Type{G}) where {F,G} = Closure{F,Type{G}}(f, G)
Closure(::Type{F}, g::G) where {F,G} = Closure{Type{F},G}(F, g)
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{Type{F},Type{G}}(F, G)

Generally, just storing the type of e.g. f is not sufficient: if f is e.g. a callable struct, F does not provide enough information.

However, with fields in the struct, the performance with ReverseDiff is bad since then we hit https://github.com/JuliaDiff/ReverseDiff.jl/blob/d522508aa6fea16e9716607cdd27d63453bb61e6/src/derivatives/broadcast.jl#L27. This can be fixed by defining

ReverseDiff.mayhavetracked(c::Closure) = ReverseDiff.mayhavetracked(c.f) || ReverseDiff.mayhavetracked(c.g)

I wonder if we can just improve the heuristics in ReverseDiff to use a similar check for structs/types with multiple fields.

Member Author

So I originally had fields but yeah this resulted in bad computation paths. Might be something that should be changed in the AD instead, true.

Member Author

Worth pointing out that

  1. Closure is not something that's meant to be used heavily for arbitrary callables (e.g. the checks in the adjoint explicitly exclude the scenario where we have fields).
  2. Closure should not be used by the end-user.

Comment on lines +134 to +138
@generated function (closure::Closure{F,G})(x, args...) where {F,G}
    f = Base.issingletontype(F) ? F.instance : F
    g = Base.issingletontype(G) ? G.instance : G
    return :($f($g(args...), x))
end
Member

I'm not a fan of @generated functions if there is a different way to get good performance; they have too many limitations IMO.

If we want to keep Closure without fields (not sure, maybe it would be better to change the heuristics in ReverseDiff), then the following seems to work at least in the example above:

julia> struct Closure{F,G} end

julia> Closure(::F, ::G) where {F,G} = Closure{F,G}()

julia> Closure(::F, ::Type{G}) where {F,G} = Closure{F,Type{G}}()

julia> Closure(::Type{F}, ::G) where {F,G} = Closure{Type{F},G}()

julia> Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{Type{F},Type{G}}()

julia> (::Closure{F,G})(x, args...) where {F,G} = F.instance(G.instance(args...), x)

julia> (::Closure{F,Type{G}})(x, args...) where {F,G} = F.instance(G(args...), x)

julia> (::Closure{Type{F},G})(x, args...) where {F,G} = F(G.instance(args...), x)

julia> (::Closure{Type{F},Type{G}})(x, args...) where {F,G} = F(G(args...), x)

But somehow this version and the one in the PR seem all a bit hacky...

Member Author

But somehow this version and the one in the PR seem all a bit hacky...

As mentioned before, I 100% agree with you. But this performance issue is literally the cause of several Slack and Discourse threads of people going "why is Turing so slow for this simple model?", and so IMO we should just get this fixed despite its hackiness and then we make it less hacky as we go + maybe improve ReverseDiff and Zygote.

Member Author

Oh and regarding the @generated, I'm happy to do away with it. I'll try your suggestion 👍

)
# TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once
# we've addressed performance issues in ReverseDiff.jl.
constructor = _inner_constructor(typeof(dist.v))
Member

I'm sure this will be problematic in some cases and break. It's not guaranteed that _inner_constructor returns a proper constructor. Something safer would be https://github.com/JuliaObjects/ConstructionBase.jl I assume.

Member

A simple example:

julia> struct A{X,Y}
           x::X
           y::Y
           A(x::X, y::Y) where {X,Y} = new{X,Y}(x, y)
       end

julia> _constructor(::Type{D}) where {D} = D
_constructor (generic function with 1 method)

julia> x, y = 1, 2.0
(1, 2.0)

julia> a = A(x, y)
A{Int64, Float64}(1, 2.0)

julia> _constructor(typeof(a))(x, y)
ERROR: MethodError: no method matching A{Int64, Float64}(::Int64, ::Float64)
Stacktrace:
 [1] top-level scope
   @ REPL[31]:1

Member Author

So I was actually using ConstructionBase locally for this before:) But I removed it because I figured this will only be used for a very simple subset of constructors, so uncertain if it's worth it. But I'll add it back again:)
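
For reference, a sketch of the `ConstructionBase.constructorof` approach mentioned above (it also handles the `A{Int64, Float64}` example, since by default it falls back to the unparameterized type):

```julia
using ConstructionBase, Distributions

d = Normal(0.0, 1.0)
C = constructorof(typeof(d))  # Normal (the UnionAll), not Normal{Float64}
C(0.0, 1.0)                   # Normal{Float64}(μ=0.0, σ=1.0)
```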

Comment on lines +50 to +60
function ChainRulesCore.rrule(::Type{BroadcastArray}, f, args...)
    function BroadcastArray_pullback(Δ::ChainRulesCore.Tangent)
        return (ChainRulesCore.NoTangent(), Δ.f, Δ.args...)
    end
    return BroadcastArray(f, args...), BroadcastArray_pullback
end

ChainRulesCore.ProjectTo(ba::BroadcastArray) = ProjectTo{typeof(ba)}((f=ba.f,))
function (p::ChainRulesCore.ProjectTo{BA})(args...) where {BA<:BroadcastArray}
    return ChainRulesCore.Tangent{BA}(f=p.f, args=args)
end
Member

I'm surprised this is needed. Feels like that's the default for Tangents anyway?

Member Author

Yeah, so we can also just close over the function f and construct the Tangent directly in the adjoint (in fact, this is what I did originally), but I thought maybe ProjectTo was the more "proper" way to do it. I can go back to directly constructing the Tangent though:)

src/lazyarrays.jl Outdated (resolved)

# If it's not safe to ignore the `constructor` in the pullback, then we fall back
# to the default implementation.
is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> sum(logpdf.(d.v, x)), dist, x)
Member

Could we just use

Suggested change
is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> sum(logpdf.(d.v, x)), dist, x)
is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> logpdf(d, x), dist, x)

to avoid making any assumptions about how logpdf(dist, x) is implemented?

Member

Alternatively maybe just return nothing?

Suggested change
is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> sum(logpdf.(d.v, x)), dist, x)
is_diff_safe(constructor) || return nothing

(https://juliadiff.org/ChainRulesCore.jl/stable/ad_author/opt_out.html)

Member Author

to avoid making any assumptions about how logpdf(dist, x) is implemented?

But won't this hit the same rrule once you get to logpdf?

Alternatively maybe just return nothing?

I don't completely understand this. So opting out means it will fall back to the underlying AD? I thought it meant "nothing to differentiate here"

return ChainRulesCore.Tangent{BA}(f=p.f, args=args)
end

function ChainRulesCore.rrule(
Member

BTW the annoying thing about these kinds of general rules is that they might break (and it has happened to me multiple times) code that would have worked without the rule, if one had just let the AD system perform its default differentiation. One can fix these issues though by using @opt_out (but that also has some problems...).

Member Author

We could also just define this using ZygoteRules.@adjoint if that helps?

Unfortunately there's no way around this because we have to stop Zygote from trying to differentiate through the broadcasted constructor.

# 2nd argument is `cl` -> `nothing`.
# 3rd argument is `x` -> something.
# Rest is `dist` arguments -> something
(_, _, x_Δ, args_Δ...) = dy(sum_Δ...)
Member

project_broadcastarray = ChainRulesCore.ProjectTo(dist.v)
function logpdf_adjoint(Δ...)
# 1st argument is `sum` -> nothing.
(_, sum_Δ...) = dz(Δ...)
Member

Generally you might have to deal with unthunk I assume.

Member Author

But should the other pullbacks also deal with this?

Member Author

Seems like we only need to deal with unthunk in the pullback for the BroadcastArray constructor, no?
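
A sketch of the unthunking being discussed, applied in the `BroadcastArray` constructor pullback (assumed placement; not necessarily the final code):

```julia
import ChainRulesCore

function BroadcastArray_pullback(Δ)
    Δu = ChainRulesCore.unthunk(Δ)  # cotangents from ChainRules-aware consumers may be thunked
    return (ChainRulesCore.NoTangent(), Δu.f, Δu.args...)
end
```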

torfjelde and others added 2 commits January 19, 2023 00:05
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
@torfjelde
Member Author

I had a rather quick look. There are many things that should not be here (but maybe some of them could be moved or fixed somewhere else). Specializations for LazyArrays could maybe be added to Distributions as weak dependencies (or maybe we just fix the use of Broadcasted). Nevertheless it might be OK for some time to have fixes here, but we should make sure to test everything.

I 100% agree with you, but realistically tackling these issues is going to take a lot of time and effort. In the meantime, I think this hacky approach will have to do 😕 This solves so many Slack and Discourse threads of people going "why is this simple model so slow in Turing.jl?"...

@svilupp

svilupp commented Feb 6, 2023

I was wondering what is stopping this PR from being merged?
I've had this page in my open tabs for a few weeks now, anxiously awaiting the purple merged badge 😄

I was hoping to re-do my logistic regression benchmark with this PR once it's merged, but I jumped the gun here.

In short, the benefits are incredible and the best part is that it would be easy even for new users (it just needs an update in the tutorials).

Comment on lines +73 to +78
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
    using .Zygote: Zygote
    # HACK: Make Zygote (correctly) recognize that it should use `ForwardDiff` for broadcasting.
    # See `is_diff_safe` for more information.
    Zygote._dual_purefun(::Type{C}) where {C<:Closure} = is_diff_safe(C)
end


Make into a weak dep?

Member

Haha yeah that should be done with basically everything in this repo. Or rather, most things should be moved to Distributions and added as a weak dep there to fix the type piracy issues.

@ToucheSir ToucheSir Feb 7, 2023

Ideally is_diff_safe would be a function ForwardDiff or one of its dependencies owns. That would avoid the need for a package extension or Requires block overriding an internal function like _dual_purefun.

@torfjelde
Member Author

I was wondering what is stopping this PR from being merged?

I've had this page in my open tabs for a few weeks now, anxiously awaiting the purple merged badge 😄

The approach taken here is quite hacky, and, tbh, it's unfortunate that we even have to do this. It's really just working around type-inference issues for calling a UnionAll.

So if we accept it, we have to be on high alert in case something breaks (due to its hackiness), and ideally this should be fixed elsewhere, so maybe we should spend some more time thinking to see if we can address this in a broader manner.

I was hoping to re-do my logistic regression benchmark with this PR once it's merged, but I jumped the gun here.

In short, the benefits are incredible and the best part is that it would be easy even for new users (it just needs an update in the tutorials).

Glad to see it's working though!:)

@svilupp

svilupp commented Mar 9, 2023

Since this PR is unlikely to be merged, do we have tips/snippets that intermediate users could opt into in their code?

How would I recognize that the UnionAll is slowing me down (besides the slowness itself)?
I can see that you identified it with this call by running @code_warntype f.makeargs.f and looking for Body::Any.

However, it's not clear to me "where" to break into.
How would I set a breakpoint for it with Debugger?

using Debugger, Turing
# as per your example
@model function irt(y, i, p; I = maximum(i), P = maximum(p))
    theta ~ filldist(Normal(), P)
    beta ~ filldist(Normal(), I)
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(theta[p] - beta[i]), y))

    return (; theta, beta)
end
model = irt(y, i, p); 
# what is the `some_func` at which we should break to see the context
@bp some_func
@run chn=sample(model, NUTS(), 100)

In the absence of this PR, what's the best way to overcome the AD taking the slow path?
My understanding of your thread was that it comes from broadcasting over structs, so we want the compiler to remove them.

Eg, define a wrapper

BernoulliLogitF(x) = BernoulliLogit(x)

# to be used like
Turing.@addlogprob! sum(logpdf.(BernoulliLogitF.(theta[p] - beta[i]), y))

# instead of this in Turing.@addlogprob! 
Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(theta[p] - beta[i]), y))

and check as per the above point if you get Body::Any or a type

Or would you go about it differently?

EDIT: I found that a better/simpler workaround is to use BernoulliLogitF(x, c) = BernoulliLogit(x, c) as per the above
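
To see the inference failure that the wrapper works around, here is a minimal illustration (hypothetical example, not code from this thread):

```julia
using Distributions, InteractiveUtils

# `typeof(BernoulliLogit)` (like `typeof(Normal)`) is `UnionAll`, so any closure or
# broadcast machinery that stores the constructor in a `UnionAll`-typed slot cannot
# infer the constructor call.
make(D) = x -> D(x)       # `D` is captured as a field of type `UnionAll`
cl = make(BernoulliLogit)
@code_warntype cl(0.5)    # Body::Any

wrapped(x) = BernoulliLogit(x)  # a plain function has a singleton type: fully inferable
@code_warntype wrapped(0.5)     # Body::BernoulliLogit{Float64}
```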

@yebai
Member

yebai commented Apr 17, 2024

This should ideally be fixed by the AD packages themselves, e.g. Tapir and Enzyme.

@yebai yebai closed this Apr 17, 2024
@yebai yebai deleted the torfjelde/lazy-array-perf branch October 3, 2024 20:51