Remove LogDensityProblemsAD; wrap adtype in LogDensityFunction #806
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##           release-0.35     #806      +/-   ##
================================================
+ Coverage         84.43%   84.60%   +0.16%
================================================
  Files                34       34
  Lines              3823     3832       +9
================================================
+ Hits               3228     3242      +14
+ Misses              595      590       -5
```

View full report in Codecov by Sentry.
OK, there are some failing Mooncake tests on Julia 1.10 which I had re-enabled (they were previously blanket disabled), but things pass on 1.11. The benchmarks have been updated (see next comment).

The performance depends on how exactly we implement this (see DynamicPPL.jl/src/logdensityfunction.jl, lines 156 to 163 at 0f247e9). But as explained in the DI docs:

Which explains why there's a 3-4x slowdown for compiled ReverseDiff. We can get around this by instead constructing a closure, and avoiding using the constant:

```julia
DI.value_and_gradient(Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x)
```

But the DI docs write:

Closures are apparently not good for (at least) Enzyme: tpapp/LogDensityProblemsAD.jl#29 (comment). And in the benchmarks below we can see that the closure slows Mooncake down by ~20%.
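For concreteness, a rough sketch of the two call shapes in question (the field names `f.ldf`, `f.prep`, `f.adtype` follow the snippet above; the `logdensity_at` helper and the wrapper functions are hypothetical, and this assumes DI's `Constant` context):

```julia
import DifferentiationInterface as DI
import LogDensityProblems

# (a) Constant-context call: differentiate w.r.t. x only, passing the
#     LogDensityFunction along as a non-differentiated DI.Constant.
#     f.prep must have been created by DI.prepare_gradient with the same context.
logdensity_at(x, ldf) = LogDensityProblems.logdensity(ldf, x)
value_and_grad_constant(f, x) =
    DI.value_and_gradient(logdensity_at, f.prep, f.adtype, x, DI.Constant(f.ldf))

# (b) Closure-style call: capture the LogDensityFunction in Base.Fix1, as in the
#     snippet above. Here f.prep must have been prepared for the closure instead.
value_and_grad_closure(f, x) = DI.value_and_gradient(
    Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x
)
```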
Benchmarks (updated)
Benchmarks do not include gradient preparation, only the calculation.
**New version** -- run on this PR:

```julia
using Test, DynamicPPL, ADTypes, LogDensityProblems, Chairmarks, StatsBase, Random, Printf
import ForwardDiff
import ReverseDiff
import Mooncake

cmarks = []
for m in DynamicPPL.TestUtils.DEMO_MODELS
    vi = DynamicPPL.VarInfo(Xoshiro(468), m)
    f = DynamicPPL.LogDensityFunction(m, vi)
    x = vi[:]
    for adtype in [AutoForwardDiff(), AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), AutoMooncake(; config=nothing)]
        ldfwg = DynamicPPL.LogDensityFunctionWithGrad(f, adtype)
        t = @be LogDensityProblems.logdensity_and_gradient($ldfwg, $x)
        push!(cmarks, (m.f, adtype, t))
    end
end

println("model,adtype,us_new")
for cmark in cmarks
    time = @sprintf("%.2f", median(cmark[3]).time * 1e6)
    println("$(cmark[1]),$(cmark[2]),$(time)")
end
```

**Old version** -- run on `release-0.35` branch:

```julia
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Chairmarks, StatsBase, Random, Printf
import ForwardDiff
import ReverseDiff
import Mooncake
import DifferentiationInterface

cmarks = []
for m in DynamicPPL.TestUtils.DEMO_MODELS
    vi = DynamicPPL.VarInfo(Xoshiro(468), m)
    f = DynamicPPL.LogDensityFunction(m, vi)
    x = vi[:]
    for adtype in [AutoForwardDiff(), AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), AutoMooncake(; config=nothing)]
        adg = DynamicPPL._make_ad_gradient(adtype, f)
        t = @be LogDensityProblems.logdensity_and_gradient($adg, $x)
        push!(cmarks, (m.f, adtype, t))
    end
end

println("model,adtype,us_old")
for cmark in cmarks
    time = @sprintf("%.2f", median(cmark[3]).time * 1e6)
    println("$(cmark[1]),$(cmark[2]),$(time)")
end
```
It seems to me that for ForwardDiff, Mooncake, and non-compiled ReverseDiff we could potentially be good to go, as long as we insert some backend-dependent code to decide whether to use a Fix1 closure or a constant (sketched below). But I'm still a bit unsure about the 20% slowdown for compiled ReverseDiff. The odd thing is that I genuinely can't see where this slowdown is coming from. It looks to me like both LogDensityProblemsAD (here) and DifferentiationInterface (here) use the same code. @willtebbutt, do you have any idea?

Pinging @gdalle too: for context, this is the PR that would let us use DI throughout Turing, as you suggested in TuringLang/Turing.jl#2187 :)
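Purely as a sketch of what that backend-dependent switch could look like (the `use_closure` trait name and the per-backend choices are illustrative, nothing is decided yet):

```julia
import ADTypes

# Illustrative trait: does this backend prefer a Base.Fix1 closure over passing
# the LogDensityFunction as a DI.Constant?
use_closure(::ADTypes.AbstractADType) = false   # default: pass the constant
use_closure(::ADTypes.AutoReverseDiff) = true   # compiled ReverseDiff seems happier with the closure
# Mooncake and Enzyme would keep the default: the closure costs Mooncake ~20%,
# and Enzyme reportedly dislikes closures (see the linked LogDensityProblemsAD issue).
```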
43ed59b
to
5b05ad3
Compare
Very cool, thanks @penelopeysm! I'll take a look at the code.
@gdalle, sure! Here is a slimmed-down version of the benchmark code in the previous comment:

**LogDensityProblemsAD** -- run on `master` branch:

```julia
using DynamicPPL, Distributions, Chairmarks, ADTypes, Random, LogDensityProblems
import ReverseDiff

@model f() = x ~ Normal()
model = f()
vi = DynamicPPL.VarInfo(Xoshiro(468), model)
x = vi[:] # parameters
ldf = DynamicPPL.LogDensityFunction(model, vi)
adg = DynamicPPL._make_ad_gradient(AutoReverseDiff(; compile=true), ldf) # this is an ADgradient
@be LogDensityProblems.logdensity_and_gradient($adg, $x)

#= I got this:
julia> @be LogDensityProblems.logdensity_and_gradient($adg, $x)
Benchmark: 3507 samples with 20 evaluations
 min    1.238 μs (3 allocs: 96 bytes)
 median 1.310 μs (3 allocs: 96 bytes)
 mean   1.329 μs (3 allocs: 96 bytes)
 max    3.152 μs (3 allocs: 96 bytes)
=#
```

**DifferentiationInterface** -- run on current PR:

```julia
using DynamicPPL, Distributions, Chairmarks, ADTypes, Random, LogDensityProblems
import ReverseDiff

@model f() = x ~ Normal()
model = f()
vi = DynamicPPL.VarInfo(Xoshiro(468), model)
x = vi[:] # parameters
ldf = DynamicPPL.LogDensityFunction(model, vi)
ldfwg = DynamicPPL.LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true))
@be LogDensityProblems.logdensity_and_gradient($ldfwg, $x)

#= I got this:
julia> @be LogDensityProblems.logdensity_and_gradient($ldfwg, $x)
Benchmark: 3275 samples with 12 evaluations
 min    2.055 μs (20 allocs: 736 bytes)
 median 2.191 μs (20 allocs: 736 bytes)
 mean   2.357 μs (20 allocs: 736 bytes, 0.03% gc time)
 max    507.670 μs (20 allocs: 736 bytes, 98.53% gc time)
=#
```
I did manage to find a little bit. Some of it comes from the line where the primal is recomputed (the `y = f(x)` call in the diff below); some of it comes from our side, specifically DynamicPPL.jl/src/logdensityfunction.jl, line 205 at 5b05ad3.

If we simplify all of that by removing the extra primal call and the copy into a preallocated gradient, i.e.

```diff
- y = f(x) # TODO: ReverseDiff#251
- result = DiffResult(y, (grad,))
+ S = eltype(x)
+ result = DiffResult(zero(S), (similar(x, S),))
  if compile
      result = gradient!(result, prep.tape, x)
  else
      result = gradient!(result, f, x, prep.config)
  end
- y = DR.value(result)
- grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
- return y, grad
+ return DR.value(result), DR.gradient(result)
```

then I get this, which is closer to the old performance:

```julia
julia> @be LogDensityProblems.logdensity_and_gradient($ldfwg, $x)
Benchmark: 3075 samples with 19 evaluations
 min    1.421 μs (7 allocs: 272 bytes)
 median 1.491 μs (7 allocs: 272 bytes)
 mean   1.598 μs (7 allocs: 272 bytes, 0.03% gc time)
 max    227.597 μs (7 allocs: 272 bytes, 98.07% gc time)
```

I would assume though that for more complicated models, the only part that scales is the calculation of the primal. If the reason why we do that is for correctness (JuliaDiff/ReverseDiff.jl#251?), then maybe the performance hit is actually perfectly justified?
Any idea whether that one test is supposed to fail so ungracefully?
@gdalle No, not at all, and trying to debug failures on x86 has often yielded little to show for it beyond a huge waste of time, so we've generally decided not to test x86 on CI any more (TuringLang/Turing.jl#2486).

I'm sorry I haven't paid this any attention in the last two days! (Life stuff.) But I'll get round to it later this week :)
current benchmarks
LGTM.
@penelopeysm I just ask that you make the code that you used to generate the benchmarks available to me, so I can figure out why Mooncake has such uninspiring performance on these examples 😂
Will and I chatted about this ^ on Slack, but for the sake of having this be open to all, here is the updated benchmark code. It differs from the previous ones in that this measures the time taken for the primal, and is only meant to work on the code on this branch / PR:

```julia
using DynamicPPL, ADTypes, LogDensityProblems, Chairmarks, StatsBase, Random, Printf
import ForwardDiff
import ReverseDiff
import Mooncake

cmarks = []
for m in DynamicPPL.TestUtils.DEMO_MODELS
    vi = DynamicPPL.VarInfo(Xoshiro(468), m)
    x = vi[:]
    for adtype in [AutoForwardDiff(), AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), AutoMooncake(; config=nothing)]
        f = DynamicPPL.LogDensityFunction(m, vi; adtype=adtype)
        t = @be LogDensityProblems.logdensity_and_gradient($f, $x)
        push!(cmarks, (m.f, adtype, t))
    end
end

for cmark in cmarks
    time = @sprintf("%.2f", median(cmark[3]).time * 1e6)
    println("$(cmark[1]),$(cmark[2]),$(time)")
end
```
Is this it? Can I tell my mom that DI is in Turing now?
Technically we need to release a new version of DynamicPPL, and then make Turing compatible with that, but otherwise yes 😄
Happy to ping you when that happens too, if you'd like.
Note that this opens some fun avenues for experimentation. For instance, you can now use symbolic backends.
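(Untested sketch, just to show the shape of it; `AutoFastDifferentiation` is one example of a symbolic backend from ADTypes, not necessarily the one meant above, and whether a given model actually traces through symbolically is a separate question.)

```julia
using DynamicPPL, Distributions, ADTypes, LogDensityProblems, Random
import FastDifferentiation

@model demo() = x ~ Normal()
model = demo()
vi = DynamicPPL.VarInfo(Xoshiro(468), model)
# Same constructor as before, just with a symbolic adtype plugged in.
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=AutoFastDifferentiation())
LogDensityProblems.logdensity_and_gradient(ldf, vi[:])
```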
sick
Overview
This PR removes LogDensityProblemsAD.jl as a dependency of DynamicPPL.
In place of constructing a `LogDensityProblemsAD.ADgradient` struct and calculating `logdensity_and_gradient` on that, we now hand-roll our AD type and its preparation into `LogDensityFunction`.

If `LogDensityFunction` is invoked without an adtype, the behaviour is identical to before, i.e. you can calculate the logp but not its gradient.

If `LogDensityFunction` is invoked with an adtype, then gradient preparation is carried out using DifferentiationInterface.jl, and the resulting gradient can be directly calculated.

See HISTORY.md for more info.
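For illustration, a minimal usage sketch (mirroring the benchmark snippets later in this thread; the toy model is just an example):

```julia
using DynamicPPL, Distributions, ADTypes, LogDensityProblems, Random
import ForwardDiff

@model demo() = x ~ Normal()
model = demo()
vi = DynamicPPL.VarInfo(Xoshiro(468), model)
x = vi[:]

# Without an adtype: behaves as before, i.e. logp only.
ldf = DynamicPPL.LogDensityFunction(model, vi)
LogDensityProblems.logdensity(ldf, x)

# With an adtype: gradient preparation happens up front (via DifferentiationInterface),
# and the gradient can then be computed directly.
ldf_grad = DynamicPPL.LogDensityFunction(model, vi; adtype=AutoForwardDiff())
LogDensityProblems.logdensity_and_gradient(ldf_grad, x)
```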
Motivation
The benefits of this PR are several-fold (suffice it to say I'm very excited by this!):
- LogDensityProblemsAD has a fair amount of custom code for different backends. For some backends (e.g. Mooncake), it delegates directly to DifferentiationInterface; for others (e.g. ReverseDiff) it contains code that can be very similar to DifferentiationInterface but can differ in subtle ways. Going straight to DI allows us to reduce the number of code paths we need to investigate when something goes wrong with AD. This is directly beneficial for implementing AD testing functionality (see Add AD testing utilities #799).
- DifferentiationInterface is more actively maintained than both LogDensityProblemsAD and many backends (e.g. ForwardDiff and ReverseDiff), so if bugs pop up with specific backends, it is more likely that DifferentiationInterface will be able to contain workarounds for them.
- Unifying two structs into one allows for code that is simpler to understand in upstream packages such as Turing, and can help us get rid of code that looks like this: Clean up LogDensityFunctions interface code + `setADtype` (Turing.jl#2473).
- Not going via `ADgradient` also affords us more control over specific behaviour. For example, this PR contains a line that emits a warning if a user attempts to use an AD backend that is not explicitly supported by Turing. This is trivially achieved inside the inner constructor of `LogDensityFunction`. Previously, to avoid method ambiguities, one would need to write a wrapper function (roughly like the hypothetical sketch after this list) and make sure that all upstream code only ever called `_make_ad_gradient` instead of directly creating an `ADgradient`, which is unreliable.
- More explicitly codifying the adtype as a field of the 'model' object (technically a LogDensityFunction) will allow for better separation between model and sampler, see Add `getadtype` function to AbstractSampler interface (AbstractMCMC.jl#158 (comment)).
- This paves the way for potential further upstream simplifications to Turing's Optimisation module (or even outright improvements by hooking into DifferentiationInterface).
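To make that `ADgradient` point concrete, here is the hypothetical sketch referred to above (the real `_make_ad_gradient` in the old code may have looked different, and the backend list here is purely illustrative):

```julia
import ADTypes, LogDensityProblemsAD

# Illustrative list only, not the actual set of supported backends.
const KNOWN_ADTYPES = (:AutoForwardDiff, :AutoReverseDiff, :AutoMooncake)

# Hypothetical wrapper: custom behaviour such as the warning has to live here,
# because adding it to ADgradient itself risks method ambiguities with
# LogDensityProblemsAD's own constructors.
function _make_ad_gradient(adtype::ADTypes.AbstractADType, ℓ)
    if nameof(typeof(adtype)) ∉ KNOWN_ADTYPES
        @warn "AD backend $(typeof(adtype)) is not explicitly supported; it may still work, but use it with caution."
    end
    return LogDensityProblemsAD.ADgradient(adtype, ℓ)
end
```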
Performance benchmarks

The time taken to evaluate `logdensity_and_gradient` on all the demo models has been benchmarked before/after this PR, and can be found at: #806 (comment)