From 02ca329e6989e12256dad92cf55394dbae7d5093 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 8 Aug 2022 12:00:45 -0700 Subject: [PATCH] Faster filldist() (#227) * fix testset_zygote_broken() define vars used by error() * logpdf(arraydist): use mapreduce * logpdf(filldist): use mapreduce * remove filldist(Zygote) from broken * improve mapreduce Co-authored-by: David Widmann * improve mapreduce() invocation Co-authored-by: David Widmann * tests: exclude Chernoff from Zygote filldist tests * simplify mapreduce -> sum Co-authored-by: David Widmann * explicitly broadcast since it looks like `mapreduce()` still allocates Co-authored-by: David Widmann * require ChainRulesTestUtils >= 1.9.2 some gradient tests require test_approx(::Array{<:Array}, ::Zero) * _flat_logpdf(): explicit lazy broadcasting * filldist tests: enable Skellam * use product_distribution() to fix deprecation * eliminate unnecessary intermediate var * replace some anonymous funcs with Base.Fix1 * replace sum(lambda, zip(...)) with lazy broadcast * Update src/arraydist.jl * Update test/ad/distributions.jl Co-authored-by: David Widmann Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/arraydist.jl | 17 ++++++++--------- src/filldist.jl | 12 +++++------- test/Project.toml | 2 +- test/ad/distributions.jl | 4 +--- test/ad/utils.jl | 8 ++++++-- 5 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/arraydist.jl b/src/arraydist.jl index f4d5cdca..28e9e2b4 100644 --- a/src/arraydist.jl +++ b/src/arraydist.jl @@ -21,15 +21,14 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution}) return MatrixOfUnivariate(dists) end function Distributions._logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) - # return sum(((d, xi),) -> logpdf(d, xi), zip(dist.dists, x)) - # Broadcasting here breaks Tracker for some reason - return sum(map(logpdf, dist.dists, x)) + # Lazy broadcast to avoid allocations and use pairwise summation + return 
sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, x))) end function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(x -> logpdf(dist, x), x) + return map(Base.Fix1(logpdf, dist), x) end function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(x -> logpdf(dist, x), x) + return map(Base.Fix1(logpdf, dist), x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) @@ -52,16 +51,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution}) end function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) - return sum(((di, xi),) -> logpdf(di, xi), zip(dist.dists, eachcol(x))) + return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, eachcol(x)))) end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(x -> logpdf(dist, x), x) + return map(Base.Fix1(logpdf, dist), x) end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(x -> logpdf(dist, x), x) + return map(Base.Fix1(logpdf, dist), x) end function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) init = reshape(rand(rng, dist.dists[1]), :, 1) - return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init) + return mapreduce(Base.Fix1(rand, rng), hcat, view(dist.dists, 2:length(dist)); init = init) end diff --git a/src/filldist.jl b/src/filldist.jl index e4c63084..67b758a3 100644 --- a/src/filldist.jl +++ b/src/filldist.jl @@ -30,21 +30,19 @@ end function _flat_logpdf(dist, x) if toflatten(dist) f, args = flatten(dist) - return sum(f.(args..., x)) + # Lazy broadcast to avoid allocations and use pairwise summation + return sum(Broadcast.instantiate(Broadcast.broadcasted(xi -> f(args..., xi), x))) else - return sum(map(x) do x - logpdf(dist, x) - end) + return 
sum(Broadcast.instantiate(Broadcast.broadcasted(Base.Fix1(logpdf, dist), x))) end end function _flat_logpdf_mat(dist, x) if toflatten(dist) f, args = flatten(dist) - return vec(sum(f.(args..., x), dims = 1)) + return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1)) else - temp = map(x -> logpdf(dist, x), x) - return vec(sum(temp, dims = 1)) + return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1)) end end diff --git a/test/Project.toml b/test/Project.toml index 33d7bb7d..129e08bf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ChainRulesCore = "1" -ChainRulesTestUtils = "1" +ChainRulesTestUtils = "1.9.2" Combinatorics = "1.0.2" Distributions = "0.25.15" FiniteDifferences = "0.11.3, 0.12" diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index ef5ee181..d3ce2eb5 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -408,9 +408,7 @@ # PoissonBinomial fails with Zygote # Matrix case does not work with Skellam: # https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493 - filldist_broken = if D <: Skellam - ((d.broken..., :Zygote, :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff)) - elseif D <: PoissonBinomial + filldist_broken = if D <: PoissonBinomial ((d.broken..., :Zygote), (d.broken..., :Zygote)) elseif D <: Chernoff # Zygote is not broken with `filldist` diff --git a/test/ad/utils.jl b/test/ad/utils.jl index ecf4a61a..a4dcd6e5 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -396,12 +396,16 @@ function testset_zygote(distspec, unpack_x_θ, args...; kwargs...) end end -function testset_zygote_broken(args...; kwargs...) +function testset_zygote_broken(distspec, args...; kwargs...) # don't show test errors - tests are known to be broken :) testset = suppress_stdout() do - testset_zygote(args...; kwargs...) + testset_zygote(distspec, args...; kwargs...) 
end + f = distspec.f + θ = distspec.θ + x = distspec.x + # change errors and fails to broken results, and count number of errors and fails efs = errors_to_broken!(testset)