-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Addressing performance issues with broadcasting #230
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy about the performance improvements but at the same time very unhappy about introducing additional type piracy in DistributionsAD. We have been removing more and more of the type piracy in DistributionsAD (with the ultimate goal of making the package completely obsolete at some point) and transferring fixes and distributions to other AD packages and Distributions. For instance, it would also be good to start adapting product_distribution
and ProductDistribution and improving its AD compatibility and performance in Distributions instead of the custom filldist
, arraydist
, and all the custom structs in DistributionsAD. The logpdf
type piracies in this package have also already caused problems (IIRC there is also at least one open issue).
|
||
make_logpdf_closure(::Type{D}) where {D} = (x, args...) -> logpdf(D(args...), x) | ||
|
||
function Distributions.logpdf(dist::Product{<:Any,D,<:StructArrays.StructArray}, x::AbstractVector{<:Real}) where {D} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Product
is deprecated.
I am of course in complete agreement with you in everything you say here, but do you think such a "hacky" approach would be accepted into Distributions? That is, do you foresee another solution in the near- or even medium-term future?:/ |
Could we add performance and AD improvements to |
I'm not completlely understanding what you mean here 😕 Do you mean we should make something like product_distribution(Broadcast.broadcasted(BernoulliLogit, logitp)) work and have a similar implementation as in this PR? |
Maybe something like (I know using Distributions: UnivariateDistribution, MultivariateDistribution, ValueSupport
struct LazyProduct{
S<:ValueSupport,
T<:UnivariateDistribution{S},
V,
} <: MultivariateDistribution{S}
v::V
function LazyProduct{S,T,V}(v::V) where {S<:ValueSupport,T<:UnivariateDistribution{S},V}
return new{S,T,V}(v)
end
end
function LazyProduct(
v::V
) where {S<:ValueSupport,T<:UnivariateDistribution{S},V<:Broadcast.Broadcasted{<:Any,<:Any,Type{T}}}
return LazyProduct{S, T, V}(v)
end
Base.length(d::LazyProduct) = length(d.v)
function Base.eltype(::Type{<:LazyProduct{S,T}}) where {S<:ValueSupport,
T<:UnivariateDistribution{S}}
return eltype(Broadcast.combine_eltypes(bc.f, bc.args))
end
Distributions._rand!(rng::Distributions.AbstractRNG, d::LazyProduct, x::AbstractVector{<:Real}) =
map!(Base.Fix1(rand, rng), x, d.v)
function Distributions._logpdf(d::LazyProduct, x::AbstractVector{<:Real})
bc = d.v
f = make_logpdf_closure(bc.f)
return sum(f.(x, bc.args...))
end This does the trick for ReverseDiff but Zygote complains (it tries to call |
Basically yes but for |
This is much slower for ReverseDiff vs. actually allocating, i.e. EDIT: In the particular example above, we're talking 10X slower without compilation and 4X slower with compilation. EDIT: Zygote handles |
IMO |
Fair 😕 Speaking of LazyArrays; I was just looking at maybe using this for the product distribution. Then it just becomes a matter of doing |
Just for the record, ReverseDiff supports this and all. It's just that |
That is, we can just change DistributionsAD.jl/src/DistributionsAD.jl Lines 91 to 96 in c30bcd9
to something like make_logpdf_closure(::Type{<:BroadcastVector{<:Any,Type{F}}}) where {F} = make_logpdf_closure(F)
function Distributions._logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractVector{<:Real},
)
f = DistributionsAD.make_logpdf_closure(typeof(dist.v))
# TODO: Fix ReverseDiff performance on this.
return sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, dist.v.args...)))
end with make_logpdf_closure(::Type{D}) where {D} = (x, args...) -> logpdf(D(args...), x) |
Regarding ReverseDiff's issue with |
This seems like it's somewhat non-trivial to address 😕 The issue is that we want to allocate in the case we're using ReverseDiff since it allows us to vectorize using ForwardDiff.. |
Also, is julia> using BenchmarkTools
julia> v = rand(100);
julia> f(v) = sum(v .* v')
f (generic function with 1 method)
julia> g(v) = mapreduce(identity, +, Broadcast.broadcasted(*, v, v'))
g (generic function with 1 method)
julia> h(v) = mapreduce(identity, +, Base.Broadcast.materialize(Broadcast.broadcasted(*, v, v')))
h (generic function with 1 method)
julia> @benchmark f($v)
BenchmarkTools.Trial: 10000 samples with 5 evaluations.
Range (min … max): 5.846 μs … 401.951 μs ┊ GC (min … max): 0.00% … 95.71%
Time (median): 6.519 μs ┊ GC (median): 0.00%
Time (mean ± σ): 8.044 μs ± 15.636 μs ┊ GC (mean ± σ): 12.52% ± 6.39%
▃██▆▄▂▁ ▁▁▂ ▂
███████▇█▇██████▇▇▆▆▆▆▆▇▆▆▆▆▅▅▅▄▄▆▄▃▅▃▄▄▁▄▃▁▃▁▁▁▁▃▁▁▃▄▃▅▇▇▇ █
5.85 μs Histogram: log(frequency) by time 24.5 μs <
Memory estimate: 78.17 KiB, allocs estimate: 2.
julia> @benchmark g($v)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 25.326 μs … 105.242 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 25.986 μs ┊ GC (median): 0.00%
Time (mean ± σ): 27.306 μs ± 4.657 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▅█▄▂▂ ▁▃▁▂ ▃ ▃ ▂ ▁
█████▇████▅█▁█▃██▆█▁▄▁▁▃▁▄▁▃▃▁▁▁▁▁▁▁▁▄▄▄▄▆▅▅▄▆▅▅▅▄▆▄▆▅▅▄▄▄▄▄ █
25.3 μs Histogram: log(frequency) by time 52.6 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark h($v)
BenchmarkTools.Trial: 10000 samples with 5 evaluations.
Range (min … max): 5.058 μs … 309.694 μs ┊ GC (min … max): 0.00% … 91.66%
Time (median): 6.459 μs ┊ GC (median): 0.00%
Time (mean ± σ): 7.756 μs ± 14.452 μs ┊ GC (mean ± σ): 12.16% ± 6.39%
▇█▆▄▁ ▂
▅▄███████▆▄▄▃▆█▇▆▆▅▆▆▅▅▅▆▄▆▅▅▁▅▄▅▃▄▁▃▁▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▇▆▇ █
5.06 μs Histogram: log(frequency) by time 24.4 μs <
Memory estimate: 78.17 KiB, allocs estimate: 2. I.e. materializing is faster (at least in this scenario, also involving So seems to indicate if we want speed on |
IMO performance issues are bugs as well since I think usually you should be able to be as fast (or slow) as regular broadcasting by materializing. It would be better to support Broadcast.instantiate... in Distributions if possible than adding a dependency on LazyArrays there, I think. In the example it seems you did not instantiate the broadcasting object. That's necessary for good performance in general. |
But isn't this an issue of how |
Doesn't matter here: julia> using BenchmarkTools
julia> v = rand(100);
julia> f(v) = sum(v .* v')
f (generic function with 1 method)
julia> g(v) = mapreduce(identity, +, Broadcast.instantiate(Broadcast.broadcasted(*, v, v')))
g (generic function with 1 method)
julia> h(v) = mapreduce(identity, +, Base.Broadcast.materialize(Broadcast.instantiate(Broadcast.broadcasted(*, v, v'))))
h (generic function with 1 method)
julia> @benchmark f($v)
BenchmarkTools.Trial: 10000 samples with 7 evaluations.
Range (min … max): 4.927 μs … 64.068 μs ┊ GC (min … max): 0.00% … 65.85%
Time (median): 7.004 μs ┊ GC (median): 0.00%
Time (mean ± σ): 7.705 μs ± 3.979 μs ┊ GC (mean ± σ): 5.13% ± 8.93%
▄█▇▅▂▂▂▁ ▂
▅▄██████████▅▃▄▄▃▁▁▃▁▁▁▁▃▁▃▁▁▁▁▁▁▅▇▇▆▆▁▃▄▃▁▁▁▁▁▃▁▁▁▁▁▁▃▄▅▆ █
4.93 μs Histogram: log(frequency) by time 36.3 μs <
Memory estimate: 78.17 KiB, allocs estimate: 2.
julia> @benchmark g($v)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 13.625 μs … 60.730 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 14.411 μs ┊ GC (median): 0.00%
Time (mean ± σ): 14.894 μs ± 1.880 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▃ ▄▂ █▇ ▂▆▂ ▆▃ ▄▄ ▁▄▂ ▁▂ ▂
█▇▁▁▁▁▇██▁▁▁▁▁██▄▁▁▁▁███▃▁▁▁▁▆██▁▁▃▃▁▁██▆▁▁▄▃▄▅███▅▃▄▁▃▁▁██ █
13.6 μs Histogram: log(frequency) by time 16.7 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark h($v)
BenchmarkTools.Trial: 10000 samples with 6 evaluations.
Range (min … max): 4.328 μs … 85.797 μs ┊ GC (min … max): 0.00% … 90.22%
Time (median): 6.613 μs ┊ GC (median): 0.00%
Time (mean ± σ): 7.486 μs ± 5.520 μs ┊ GC (mean ± σ): 7.09% ± 8.95%
▁█▆▂▂▂▁ ▁
▆███████▇▆▅▄▃▄▃▃▁▁▁▁▁▁▅▇█▄▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅ █
4.33 μs Histogram: log(frequency) by time 52.2 μs <
Memory estimate: 78.17 KiB, allocs estimate: 2. |
In both cases the same pairwise algorithm is used: https://github.com/JuliaLang/julia/blob/0371bf44bf6bfd6ee9fbfc32d478c2ff4c97b08b/base/reduce.jl#L426-L449 and https://github.com/JuliaLang/julia/blob/0371bf44bf6bfd6ee9fbfc32d478c2ff4c97b08b/base/reduce.jl#L251-L275 It just computes the required entries of the array or the broadcasting object when they are required. If there's an optimized way for computing the array (e.g. by hitting BLAS calls, I assume), then the allocations might be acceptable and still result in better performance - but the summation itself uses the same code in both cases. I just ran the following on Julia 1.8.5: julia> using BenchmarkTools
julia> A = vcat(1f0, fill(1f-8, 10^8));
julia> B = ones(Float32, length(A));
julia> f1(A, B) = reduce(+, A .* B);
julia> f2(A, B) = sum(A .* B);
julia> g1(A, B) = reduce(+, Broadcast.instantiate(Broadcast.broadcasted(*, A, B)));
julia> g2(A, B) = sum(Broadcast.instantiate(Broadcast.broadcasted(*, A, B)));
julia> h1(A, B) = sum(a * b for (a, b) in zip(A, B));
julia> h2(A, B) = sum(zip(A, B)) do (a, b)
return a * b
end;
julia> f1(A, B)
1.9999989f0
julia> f2(A, B)
1.9999989f0
julia> g1(A, B)
1.9999989f0
julia> g2(A, B)
1.9999989f0
julia> h1(A, B)
1.0f0
julia> h2(A, B)
1.0f0
julia> @btime f1($A, $B);
185.125 ms (2 allocations: 381.47 MiB)
julia> @btime f2($A, $B);
190.276 ms (2 allocations: 381.47 MiB)
julia> @btime g1($A, $B);
46.049 ms (0 allocations: 0 bytes)
julia> @btime g2($A, $B);
45.674 ms (0 allocations: 0 bytes)
julia> @btime h1($A, $B);
143.151 ms (0 allocations: 0 bytes)
julia> @btime h2($A, $B);
141.105 ms (0 allocations: 0 bytes) Generally, it does not matter if we use Regarding your benchmark: julia> using BenchmarkTools
julia> v = rand(100);
julia> g1(v) = sum(Broadcast.instantiate(Broadcast.broadcasted(*, v, v')));
julia> g2(v) = reduce(+, Broadcast.instantiate(Broadcast.broadcasted(*, v, v')));
julia> g3(v) = mapreduce(identity, +, Broadcast.instantiate(Broadcast.broadcasted(*, v, v')));
julia> h1(v) = sum(v .* v');
julia> h2(v) = reduce(+, v .* v');
julia> h3(v) = mapreduce(identity, +, v .* v');
julia> @btime g1($v);
16.961 μs (0 allocations: 0 bytes)
julia> @btime g2($v);
15.575 μs (0 allocations: 0 bytes)
julia> @btime g3($v);
15.670 μs (0 allocations: 0 bytes)
julia> @btime h1($v);
6.872 μs (2 allocations: 78.17 KiB)
julia> @btime h2($v);
7.173 μs (2 allocations: 78.17 KiB)
julia> @btime h3($v);
6.849 μs (2 allocations: 78.17 KiB)
julia> v = rand(1_000);
julia> @btime g1($v);
1.626 ms (0 allocations: 0 bytes)
julia> @btime g2($v);
1.469 ms (0 allocations: 0 bytes)
julia> @btime g3($v);
^[[ 1.469 ms (0 allocations: 0 bytes)
julia> @btime h1($v);
1.030 ms (2 allocations: 7.63 MiB)
julia> @btime h2($v);
1.031 ms (2 allocations: 7.63 MiB)
julia> @btime h3($v);
1.035 ms (2 allocations: 7.63 MiB)
julia> v = rand(10_000);
julia> @btime g1($v);
162.012 ms (0 allocations: 0 bytes)
julia> @btime g2($v);
160.191 ms (0 allocations: 0 bytes)
julia> @btime g3($v);
160.172 ms (0 allocations: 0 bytes)
julia> @btime h1($v);
303.014 ms (2 allocations: 762.94 MiB)
julia> @btime h2($v);
309.235 ms (2 allocations: 762.94 MiB)
julia> @btime h3($v);
316.549 ms (2 allocations: 762.94 MiB) |
Ah, interesting! But what do we do about the AD issues? In particular, how do we tell ReverseDiff to hit the faster path, i.e. not trace through the entire thing using |
Btw, related to all of this: JuliaArrays/LazyArrays.jl#232 |
Tracing should be fine if everything is lazy, shouldn't it? There's not much happening there in the forward pass. I would assume the problem is rather that for most functions (such as |
But tracing using |
In particular here, we're looking at a full trace through everything vs. a tape containing only two statements ( |
But that's another issue in ReverseDiff, isn't it? Similar to Zygote it should define adjoints mainly for |
In broadcasting using ForwardDiff is probably unavoidable (and Zygote and Tracker do the same) but IMO in general ReverseDiff uses |
Yep! I'm just saying that I'm personally not familiar enough with ReverseDiff to address this issue (and I'm worried it's going to take too much time before someone gets around to it 😕), and given the performance difference I'm not a big fan of just "leaving it" as is. We can always do something hacky like _inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D
function Distributions._logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractVector{<:Real},
)
# TODO: Implement chain rule for `LazyArray` constructor to support Zygote.
f = DistributionsAD.make_closure(logpdf, _inner_constructor(typeof(dist.v)))
args = dist.v.args
return if ReverseDiff.istracked(args) || ReverseDiff.istracked(x)
sum(f.(x, args...))
else
sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))
end
end but that ain't particularly nice.. |
Closing in favour of #231 |
This PR introduces a "lazy" version of
arraydist
which allows us to hack around performance issues with broadcasting over constructors in ReverseDiff and Zygote. Based on "insights" from TuringLang/Turing.jl#1934.More specifically, in broadcasting:
Union{Real,Complex}
, hence we miss the fast branch.Normal
withNormalF(args...) = Normal(args...)
in a broadcast will actually help ReverseDiff (Performance regression for BernoulliLogit Turing.jl#1934 (comment)).Here's an example of the result of the "lazy" array dist:
I also thought maybe this is where LazyArrays.jl could be useful, but preliminary attempts didn't end up being fruitful so uncertain.