
Remove LogDensityProblemsAD; wrap adtype in LogDensityFunction #806

Merged: 11 commits into release-0.35 on Feb 19, 2025

Conversation

penelopeysm
Member

@penelopeysm penelopeysm commented Feb 10, 2025

Overview

This PR removes LogDensityProblemsAD.jl as a dependency of DynamicPPL.

Instead of constructing a LogDensityProblemsAD.ADgradient struct and calling logdensity_and_gradient on that, we now hand-roll the AD type and its gradient preparation into LogDensityFunction.

If LogDensityFunction is invoked without an adtype, the behaviour is identical to before, i.e. you can calculate the logp but not its gradient.

If LogDensityFunction is invoked with an adtype, then gradient preparation is carried out using DifferentiationInterface.jl, and the resulting gradient can be directly calculated.
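
As a quick illustration, here's a minimal usage sketch of the new interface (assuming the constructor as merged, with the adtype passed as a keyword argument to LogDensityFunction, as in the benchmark code further down; the model and values are placeholders):

using DynamicPPL, ADTypes, Distributions, LogDensityProblems, Random
import ForwardDiff

@model demo() = x ~ Normal()
model = demo()
vi = DynamicPPL.VarInfo(Xoshiro(468), model)
x = vi[:]  # parameters

# Without an adtype: logp only, as before
f = DynamicPPL.LogDensityFunction(model, vi)
LogDensityProblems.logdensity(f, x)

# With an adtype: gradient preparation happens via DifferentiationInterface,
# and the gradient can be computed directly
f_grad = DynamicPPL.LogDensityFunction(model, vi; adtype=AutoForwardDiff())
LogDensityProblems.logdensity_and_gradient(f_grad, x)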

See HISTORY.md for more info.

Motivation

The benefits of this PR are several-fold (suffice it to say I'm very excited by this!):

  1. LogDensityProblemsAD has a fair amount of custom code for different backends. For some backends (e.g. Mooncake), it delegates directly to DifferentiationInterface; for others (e.g. ReverseDiff), it contains code that is very similar to DifferentiationInterface's but differs in subtle ways. Going straight to DI reduces the number of code paths we need to investigate when something goes wrong with AD. This is directly beneficial for implementing AD testing functionality (see Add AD testing utilities #799).

  2. DifferentiationInterface is more actively maintained than LogDensityProblemsAD, and also than many of the backends themselves (e.g. ForwardDiff and ReverseDiff), so if bugs pop up in specific backends, DifferentiationInterface is more likely to be able to contain workarounds for them.

  3. Unifying two structs into one allows for code in upstream packages such as Turing that is simpler to understand, and can help us get rid of code like that in Clean up LogDensityFunctions interface code + setADtype Turing.jl#2473.

  4. Not going via ADgradient also affords us more control over specific behaviour. For example, this PR contains a line that emits a warning if a user attempts to use an AD backend that is not explicitly supported by Turing. This is trivially achieved inside the inner constructor of LogDensityFunction. Previously, to avoid method ambiguities, one would need to write something like:

    function _make_ad_gradient(adtype, ldf)
        is_supported(adtype) || @warn "..."
        return ADgradient(adtype, ldf)
    end

    and make sure that all upstream code only ever called _make_ad_gradient instead of directly creating an ADgradient, which is unreliable. (A concrete sketch of the inner-constructor approach follows this list.)

  5. More explicitly codifying the adtype as a field of the 'model' object (technically a LogDensityFunction) will allow for better separation between model and sampler, see Add getadtype function to AbstractSampler interface AbstractMCMC.jl#158 (comment).

  6. This paves the way for potential further upstream simplifications to Turing's Optimisation module (or even outright improvements by hooking into DifferentiationInterface).
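
To make point 4 concrete, here's a rough sketch of what the inner-constructor check can look like. The struct name, field layout, and the is_supported helper are purely illustrative and not the actual implementation in this PR:

using ADTypes

# Hypothetical list of "officially supported" backends, for illustration only.
is_supported(::ADTypes.AbstractADType) = false
is_supported(::Union{ADTypes.AutoForwardDiff,ADTypes.AutoReverseDiff,ADTypes.AutoMooncake}) = true

struct MyLogDensityFunction{M,A<:ADTypes.AbstractADType}
    model::M
    adtype::A
    function MyLogDensityFunction(model::M, adtype::A) where {M,A<:ADTypes.AbstractADType}
        # Because the check lives in the inner constructor, every construction
        # path hits it; no wrapper function needs to be policed upstream.
        is_supported(adtype) ||
            @warn "AD backend $adtype is not officially supported; gradients may be slow or incorrect"
        return new{M,A}(model, adtype)
    end
end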

Performance benchmarks

The time taken to evaluate logdensity_and_gradient on all the demo models has been benchmarked before/after this PR, and can be found at: #806 (comment)

@penelopeysm penelopeysm changed the base branch from master to release-0.35 February 10, 2025 13:19

codecov bot commented Feb 10, 2025

Codecov Report

Attention: Patch coverage is 90.38462% with 5 lines in your changes missing coverage. Please review.

Project coverage is 84.60%. Comparing base (f5e84f4) to head (04f640d).
Report is 1 commit behind head on release-0.35.

Files with missing lines Patch % Lines
ext/DynamicPPLForwardDiffExt.jl 70.00% 3 Missing ⚠️
src/contexts.jl 0.00% 1 Missing ⚠️
src/logdensityfunction.jl 97.56% 1 Missing ⚠️
Additional details and impacted files
@@               Coverage Diff                @@
##           release-0.35     #806      +/-   ##
================================================
+ Coverage         84.43%   84.60%   +0.16%     
================================================
  Files                34       34              
  Lines              3823     3832       +9     
================================================
+ Hits               3228     3242      +14     
+ Misses              595      590       -5     


@TuringLang TuringLang deleted 3 comments from github-actions bot Feb 10, 2025
@coveralls

coveralls commented Feb 10, 2025

Pull Request Test Coverage Report for Build 13417916562

Details

  • 29 of 51 (56.86%) changed or added relevant lines in 3 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage increased (+0.2%) to 84.692%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/contexts.jl 0 1 0.0%
ext/DynamicPPLForwardDiffExt.jl 7 10 70.0%
src/logdensityfunction.jl 22 40 55.0%
Files with Coverage Reduction New Missed Lines %
src/logdensityfunction.jl 1 56.82%
Totals Coverage Status
Change from base Build 13396312771: 0.2%
Covered Lines: 3242
Relevant Lines: 3828

💛 - Coveralls

@TuringLang TuringLang deleted 5 comments from github-actions bot Feb 13, 2025
@penelopeysm
Member Author

penelopeysm commented Feb 14, 2025

OK, there are some failing Mooncake tests on Julia 1.10 which I had re-enabled (they were previously blanket-disabled), but things pass on 1.11.

The benchmarks have been updated (see next comment). The performance depends on how exactly we implement the logdensity_and_gradient method. In the current latest commit, I used DI.Constant to pass the model:

function LogDensityProblems.logdensity_and_gradient(
    f::LogDensityFunctionWithGrad, x::AbstractVector
)
    x = map(identity, x) # Concretise type
    return DI.value_and_gradient(
        _flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf)
    )
end

But as explained in the DI docs:

Whenever contexts are provided, tape recording is deactivated in all cases, because otherwise the context values would be hardcoded into a tape.

This explains why there's a 3-4x slowdown for compiled ReverseDiff.

We can get around this by constructing a closure instead and avoiding the Constant context. This is benchmarked below too:

DI.value_and_gradient(Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x)

But the DI docs also note:

Contexts can be useful if you have a function y = f(x, a, b, c, ...) or f!(y, x, a, b, c, ...) and you want derivatives of y with respect to x only. Another option would be creating a closure, but that is sometimes undesirable.

Closures are apparently not good for (at least) Enzyme:

tpapp/LogDensityProblemsAD.jl#29 (comment)
JuliaDiff/DifferentiationInterface.jl#311

And in the benchmarks below we can see that the closure slows Mooncake down by ~20%.

@penelopeysm
Member Author

penelopeysm commented Feb 14, 2025

Benchmarks (updated)

  • old: time (µs) on release-0.35 branch, which uses LogDensityProblemsAD
  • new_constant: time (µs) with value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf))
  • new_closure: time (µs) with value_and_gradient(Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x)

Benchmarks do not include gradient preparation, only the calculation.

The next 2 columns, new_constant/old and new_closure/old, are the ratios of new to old times. The final two columns give the mean of these ratios across all models for a given backend.

model                                       adtype                          old    new_constant  new_closure  new_constant/old  new_closure/old  mean(new_constant/old)  mean(new_closure/old)
demo_dot_assume_dot_observe                 AutoForwardDiff()               1.54   1.63          1.26         1.06              0.82             1.03                    1.05
demo_assume_index_observe                   AutoForwardDiff()               1.12   1.12          1.29         1.00              1.15                                     
demo_assume_multivariate_observe            AutoForwardDiff()               1.01   1.05          1.04         1.04              1.03                                     
demo_dot_assume_observe_index               AutoForwardDiff()               1.46   1.59          1.59         1.09              1.09                                     
demo_assume_dot_observe                     AutoForwardDiff()               0.73   0.75          0.77         1.03              1.05                                     
demo_assume_multivariate_observe_literal    AutoForwardDiff()               1.01   1.04          1.08         1.03              1.07                                     
demo_dot_assume_observe_index_literal       AutoForwardDiff()               1.52   1.54          1.6          1.01              1.05                                     
demo_assume_dot_observe_literal             AutoForwardDiff()               0.74   0.75          0.82         1.01              1.11                                     
demo_assume_observe_literal                 AutoForwardDiff()               0.73   0.74          0.76         1.01              1.04                                     
demo_assume_submodel_observe_index_literal  AutoForwardDiff()               1.81   1.74          1.89         0.96              1.04                                     
demo_dot_assume_observe_submodel            AutoForwardDiff()               2.23   2.32          2.38         1.04              1.07                                     
demo_dot_assume_dot_observe_matrix          AutoForwardDiff()               1.62   1.69          1.69         1.04              1.04                                     
demo_dot_assume_matrix_dot_observe_matrix   AutoForwardDiff()               1.46   1.57          1.56         1.08              1.07                                     
demo_assume_matrix_dot_observe_matrix       AutoForwardDiff()               1.04   1.12          1.15         1.08              1.11                                     
demo_dot_assume_dot_observe                 AutoMooncake{Nothing}(nothing)  13.27  13.44         16.83        1.01              1.27             1.02                    1.20
demo_assume_index_observe                   AutoMooncake{Nothing}(nothing)  7.33   7.58          8.74         1.03              1.19                                     
demo_assume_multivariate_observe            AutoMooncake{Nothing}(nothing)  10.42  10.6          13.85        1.02              1.33                                     
demo_dot_assume_observe_index               AutoMooncake{Nothing}(nothing)  9.26   9.74          11           1.05              1.19                                     
demo_assume_dot_observe                     AutoMooncake{Nothing}(nothing)  10.21  9.96          13.73        0.98              1.34                                     
demo_assume_multivariate_observe_literal    AutoMooncake{Nothing}(nothing)  6.5    6.77          7.62         1.04              1.17                                     
demo_dot_assume_observe_index_literal       AutoMooncake{Nothing}(nothing)  8.78   9.08          10.17        1.03              1.16                                     
demo_assume_dot_observe_literal             AutoMooncake{Nothing}(nothing)  5.35   5.56          6.59         1.04              1.23                                     
demo_assume_observe_literal                 AutoMooncake{Nothing}(nothing)  5.43   6.19          6.5          1.14              1.20                                     
demo_assume_submodel_observe_index_literal  AutoMooncake{Nothing}(nothing)  9.22   9.4           10.29        1.02              1.12                                     
demo_dot_assume_observe_submodel            AutoMooncake{Nothing}(nothing)  14.17  13.1          15.17        0.92              1.07                                     
demo_dot_assume_dot_observe_matrix          AutoMooncake{Nothing}(nothing)  11.44  11            12.96        0.96              1.13                                     
demo_dot_assume_matrix_dot_observe_matrix   AutoMooncake{Nothing}(nothing)  14.96  15.04         17.46        1.01              1.17                                     
demo_assume_matrix_dot_observe_matrix       AutoMooncake{Nothing}(nothing)  10.98  11.27         14.17        1.03              1.29                                     
demo_dot_assume_dot_observe                 AutoReverseDiff()               19.25  22.13         17.62        1.15              0.92             1.12                    0.94
demo_assume_index_observe                   AutoReverseDiff()               21.67  23.71         19.75        1.09              0.91                                     
demo_assume_multivariate_observe            AutoReverseDiff()               18.96  20.29         17.79        1.07              0.94                                     
demo_dot_assume_observe_index               AutoReverseDiff()               19.79  22.88         18.5         1.16              0.93                                     
demo_assume_dot_observe                     AutoReverseDiff()               15.42  16.92         14.21        1.10              0.92                                     
demo_assume_multivariate_observe_literal    AutoReverseDiff()               19.08  20.71         17.92        1.09              0.94                                     
demo_dot_assume_observe_index_literal       AutoReverseDiff()               20.21  22.88         18.12        1.13              0.90                                     
demo_assume_dot_observe_literal             AutoReverseDiff()               15.54  17            14.02        1.09              0.90                                     
demo_assume_observe_literal                 AutoReverseDiff()               15.79  16.92         14.17        1.07              0.90                                     
demo_assume_submodel_observe_index_literal  AutoReverseDiff()               20.46  23.62         18.96        1.15              0.93                                     
demo_dot_assume_observe_submodel            AutoReverseDiff()               20.38  24.67         19.04        1.21              0.93                                     
demo_dot_assume_dot_observe_matrix          AutoReverseDiff()               19.25  22.5          17.88        1.17              0.93                                     
demo_dot_assume_matrix_dot_observe_matrix   AutoReverseDiff()               22.75  25.17         27.25        1.11              1.20                                     
demo_assume_matrix_dot_observe_matrix       AutoReverseDiff()               19.46  21.38         18.63        1.10              0.96                                     
demo_dot_assume_dot_observe                 AutoReverseDiff(compile=true)   6.25   21.92         7.67         3.51              1.23             3.43                    1.19
demo_assume_index_observe                   AutoReverseDiff(compile=true)   6.99   23.88         8.5          3.42              1.22                                     
demo_assume_multivariate_observe            AutoReverseDiff(compile=true)   5.86   20.21         6.48         3.45              1.11                                     
demo_dot_assume_observe_index               AutoReverseDiff(compile=true)   6.78   23.13         8.15         3.41              1.20                                     
demo_assume_dot_observe                     AutoReverseDiff(compile=true)   5.13   16.88         5.88         3.29              1.15                                     
demo_assume_multivariate_observe_literal    AutoReverseDiff(compile=true)   5.87   20.42         6.49         3.48              1.11                                     
demo_dot_assume_observe_index_literal       AutoReverseDiff(compile=true)   7.12   22.83         7.89         3.21              1.11                                     
demo_assume_dot_observe_literal             AutoReverseDiff(compile=true)   5.36   17.04         6.13         3.18              1.14                                     
demo_assume_observe_literal                 AutoReverseDiff(compile=true)   5.53   16.79         6.05         3.04              1.09                                     
demo_assume_submodel_observe_index_literal  AutoReverseDiff(compile=true)   6.66   23.5          8.19         3.53              1.23                                     
demo_dot_assume_observe_submodel            AutoReverseDiff(compile=true)   6.24   24.54         8.13         3.93              1.30                                     
demo_dot_assume_dot_observe_matrix          AutoReverseDiff(compile=true)   6.45   22.67         7.5          3.51              1.16                                     
demo_dot_assume_matrix_dot_observe_matrix   AutoReverseDiff(compile=true)   7.51   25.92         10.62        3.45              1.41                                     
demo_assume_matrix_dot_observe_matrix       AutoReverseDiff(compile=true)   6      21.29         6.91         3.55              1.15                                     

Code

**New version** -- run on this PR
using Test, DynamicPPL, ADTypes, LogDensityProblems, Chairmarks, StatsBase, Random, Printf
import ForwardDiff
import ReverseDiff
import Mooncake

cmarks = []
for m in DynamicPPL.TestUtils.DEMO_MODELS
    vi = DynamicPPL.VarInfo(Xoshiro(468), m)
    f = DynamicPPL.LogDensityFunction(m, vi)
    x = vi[:]
    for adtype in [AutoForwardDiff(), AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), AutoMooncake(; config=nothing)]
        ldfwg = DynamicPPL.LogDensityFunctionWithGrad(f, adtype)
        t = @be LogDensityProblems.logdensity_and_gradient($ldfwg, $x)
        push!(cmarks, (m.f, adtype, t))
    end
end
println("model,adtype,us_new")
for cmark in cmarks
    time = @sprintf("%.2f", median(cmark[3]).time * 1e6)
    println("$(cmark[1]),$(cmark[2]),$(time)")
end
**Old version** -- run on `release-0.35` branch.
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Chairmarks, StatsBase, Random, Printf
import ForwardDiff
import ReverseDiff
import Mooncake
import DifferentiationInterface

cmarks = []
for m in DynamicPPL.TestUtils.DEMO_MODELS
    vi = DynamicPPL.VarInfo(Xoshiro(468), m)
    f = DynamicPPL.LogDensityFunction(m, vi)
    x = vi[:]
    for adtype in [AutoForwardDiff(), AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), AutoMooncake(; config=nothing)]
        adg = DynamicPPL._make_ad_gradient(adtype, f)
        t = @be LogDensityProblems.logdensity_and_gradient($adg, $x)
        push!(cmarks, (m.f, adtype, t))
    end
end
println("model,adtype,us_old")
for cmark in cmarks
    time = @sprintf("%.2f", median(cmark[3]).time * 1e6)
    println("$(cmark[1]),$(cmark[2]),$(time)")
end

@penelopeysm
Member Author

It seems to me that for ForwardDiff, Mooncake, and non-compiled ReverseDiff we could potentially be good to go as long as we insert some backend-dependent code to decide whether to use a Fix1 closure or a constant. But I'm still a bit unsure about the 20% slowdown for compiled ReverseDiff.

The odd thing is that I genuinely can't see where this slowdown is coming from. It looks to me that both LogDensityProblemsAD (here) and DifferentiationInterface (here) use the same code.

@willtebbutt, do you perhaps have any idea? Pinging @gdalle too: for context, this is the PR that would let us use DI throughout Turing, as you suggested in TuringLang/Turing.jl#2187 :)
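
For concreteness, the backend-dependent selection mentioned above could look roughly like the sketch below. This is a hypothetical illustration rather than the code in this PR: use_closure and my_logdensity_and_gradient are made-up names, _flipped_logdensity is the helper from this PR, DI stands for DifferentiationInterface, and the prep object would of course have to be prepared in the matching form (with or without the Constant context).

# Hypothetical rule: prefer a closure for ReverseDiff, so that the Constant
# context doesn't disable tape recording/compilation; prefer DI.Constant for
# everything else, since closures are problematic for some backends (e.g. Enzyme)
# and slow Mooncake down.
use_closure(::ADTypes.AbstractADType) = false
use_closure(::ADTypes.AutoReverseDiff) = true

function my_logdensity_and_gradient(ldf, adtype, prep, x::AbstractVector)
    if use_closure(adtype)
        return DI.value_and_gradient(
            Base.Fix1(LogDensityProblems.logdensity, ldf), prep, adtype, x
        )
    else
        return DI.value_and_gradient(
            _flipped_logdensity, prep, adtype, x, DI.Constant(ldf)
        )
    end
end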

@TuringLang TuringLang deleted 4 comments from github-actions bot Feb 14, 2025
@penelopeysm penelopeysm force-pushed the py/no-ldp-ad branch 2 times, most recently from 43ed59b to 5b05ad3 Compare February 14, 2025 11:50
@gdalle

gdalle commented Feb 14, 2025

Very cool, thanks @penelopeysm ! I'll take a look at the code.
Could you give me an MWE to profile DI overhead? I could try to run it on this branch and on main to figure out where the slowdown comes from.

@penelopeysm
Member Author

penelopeysm commented Feb 14, 2025

@gdalle, sure! Here is a slimmed-down version of the benchmark code in the previous comment:

LogDensityProblemsAD - run on master branch

using DynamicPPL, Distributions, Chairmarks, ADTypes, Random, LogDensityProblems
import ReverseDiff
@model f() = x ~ Normal()
model = f()
vi = DynamicPPL.VarInfo(Xoshiro(468), model)
x = vi[:]  # parameters
ldf = DynamicPPL.LogDensityFunction(model, vi)
adg = DynamicPPL._make_ad_gradient(AutoReverseDiff(; compile=true), ldf)  # this is an ADgradient
@be LogDensityProblems.logdensity_and_gradient($adg, $x)

#= I got this:
julia> @be LogDensityProblems.logdensity_and_gradient($adg, $x)
Benchmark: 3507 samples with 20 evaluations
 min    1.238 μs (3 allocs: 96 bytes)
 median 1.310 μs (3 allocs: 96 bytes)
 mean   1.329 μs (3 allocs: 96 bytes)
 max    3.152 μs (3 allocs: 96 bytes)
=#

DifferentiationInterface - run on current PR

using DynamicPPL, Distributions, Chairmarks, ADTypes, Random, LogDensityProblems
import ReverseDiff
@model f() = x ~ Normal()
model = f()
vi = DynamicPPL.VarInfo(Xoshiro(468), model)
x = vi[:]  # parameters
ldf = DynamicPPL.LogDensityFunction(model, vi)
ldfwg = DynamicPPL.LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true))
@be LogDensityProblems.logdensity_and_gradient($ldfwg, $x)

#= I got this:
julia> @be LogDensityProblems.logdensity_and_gradient($ldfwg, $x)
Benchmark: 3275 samples with 12 evaluations
 min    2.055 μs (20 allocs: 736 bytes)
 median 2.191 μs (20 allocs: 736 bytes)
 mean   2.357 μs (20 allocs: 736 bytes, 0.03% gc time)
 max    507.670 μs (20 allocs: 736 bytes, 98.53% gc time)
=#

@penelopeysm
Member Author

penelopeysm commented Feb 14, 2025

I did manage to track down a little bit of it. Some of it comes from the line where y = f(x) is recomputed, and some of it comes from the copyto! call at line 98 too:

https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/58605dc2c91922c20cbefd24197bd6cd809a2383/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl#L87-L100

Some of it comes from our side, specifically this:

x = map(identity, x) # Concretise type

If we simplify all of that by (1) removing the call to map(identity, ..) in DynamicPPL, and (2) modifying DI.value_and_gradient! to do this instead (which seems very unsafe, but it does parallel the LogDensityProblemsAD code more closely)

-   y = f(x)  # TODO: ReverseDiff#251
-   result = DiffResult(y, (grad,))
+   S = eltype(x)
+   result = DiffResult(zero(S), (similar(x, S),))
    if compile
        result = gradient!(result, prep.tape, x)
    else
        result = gradient!(result, f, x, prep.config)
    end
-   y = DR.value(result)
-   grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
-   return y, grad
+   return DR.value(result), DR.gradient(result)

then I get this, which is closer to the old performance.

julia> @be LogDensityProblems.logdensity_and_gradient($ldfwg, $x)
Benchmark: 3075 samples with 19 evaluations
 min    1.421 μs (7 allocs: 272 bytes)
 median 1.491 μs (7 allocs: 272 bytes)
 mean   1.598 μs (7 allocs: 272 bytes, 0.03% gc time)
 max    227.597 μs (7 allocs: 272 bytes, 98.07% gc time)

I would assume, though, that for more complicated models the only part that scales is the calculation of the primal. If the reason we do that is correctness (as JuliaDiff/ReverseDiff.jl#251 suggests?), then maybe the performance hit is actually perfectly justified?

@gdalle

gdalle commented Feb 18, 2025

Any idea whether that one test is supposed to fail so ungracefully?

@penelopeysm
Member Author

@gdalle No, not at all. Trying to debug failures on x86 has often yielded little beyond a huge waste of time, so we've generally decided not to test x86 on CI any more. (TuringLang/Turing.jl#2486)

I'm sorry I haven't paid this any attention in the last two days! (Life stuff.) But I'll get round to it later this week :)

@penelopeysm
Member Author

current benchmarks

model                                       adtype                          old    new    new/old  mean(new/old)  stdev(new/old)
demo_dot_assume_observe                     AutoForwardDiff()               1.19   1.11   0.933    1.053          0.117
demo_assume_index_observe                   AutoForwardDiff()               1.22   1.19   0.975                   
demo_assume_multivariate_observe            AutoForwardDiff()               1.06   1.08   1.019                   
demo_dot_assume_observe_index               AutoForwardDiff()               1.14   1.14   1.000                   
demo_assume_dot_observe                     AutoForwardDiff()               0.75   0.9    1.200                   
demo_assume_multivariate_observe_literal    AutoForwardDiff()               1.1    1.1    1.000                   
demo_dot_assume_observe_index_literal       AutoForwardDiff()               1.13   1.22   1.080                   
demo_assume_dot_observe_literal             AutoForwardDiff()               0.76   0.95   1.250                   
demo_assume_observe_literal                 AutoForwardDiff()               0.77   0.9    1.169                   
demo_assume_submodel_observe_index_literal  AutoForwardDiff()               1.45   1.26   0.869                   
demo_dot_assume_observe_submodel            AutoForwardDiff()               1.76   1.89   1.074                   
demo_dot_assume_observe_matrix_index        AutoForwardDiff()               1.23   1.16   0.943                   
demo_assume_matrix_observe_matrix_index     AutoForwardDiff()               0.94   1.11   1.181                   
demo_dot_assume_observe                     AutoMooncake{Nothing}(nothing)  5.73   5.65   0.986    1.015          0.043
demo_assume_index_observe                   AutoMooncake{Nothing}(nothing)  5.42   5.38   0.993                   
demo_assume_multivariate_observe            AutoMooncake{Nothing}(nothing)  4.76   4.98   1.046                   
demo_dot_assume_observe_index               AutoMooncake{Nothing}(nothing)  5.18   5.19   1.002                   
demo_assume_dot_observe                     AutoMooncake{Nothing}(nothing)  3.91   3.71   0.949                   
demo_assume_multivariate_observe_literal    AutoMooncake{Nothing}(nothing)  4.83   5.27   1.091                   
demo_dot_assume_observe_index_literal       AutoMooncake{Nothing}(nothing)  5.03   5.25   1.044                   
demo_assume_dot_observe_literal             AutoMooncake{Nothing}(nothing)  4.03   3.96   0.983                   
demo_assume_observe_literal                 AutoMooncake{Nothing}(nothing)  3.68   3.84   1.043                   
demo_assume_submodel_observe_index_literal  AutoMooncake{Nothing}(nothing)  5.43   5.8    1.068                   
demo_dot_assume_observe_submodel            AutoMooncake{Nothing}(nothing)  8.06   8.24   1.022                   
demo_dot_assume_observe_matrix_index        AutoMooncake{Nothing}(nothing)  5.69   5.77   1.014                   
demo_assume_matrix_observe_matrix_index     AutoMooncake{Nothing}(nothing)  5.83   5.56   0.954                   
demo_dot_assume_observe                     AutoReverseDiff()               22.21  20.08  0.904    0.955          0.040
demo_assume_index_observe                   AutoReverseDiff()               24.29  21.25  0.875                   
demo_assume_multivariate_observe            AutoReverseDiff()               19.92  19.33  0.970                   
demo_dot_assume_observe_index               AutoReverseDiff()               22.96  22.71  0.989                   
demo_assume_dot_observe                     AutoReverseDiff()               16.13  15.38  0.954                   
demo_assume_multivariate_observe_literal    AutoReverseDiff()               20.04  19.46  0.971                   
demo_dot_assume_observe_index_literal       AutoReverseDiff()               22.92  22.75  0.993                   
demo_assume_dot_observe_literal             AutoReverseDiff()               16     15.71  0.982                   
demo_assume_observe_literal                 AutoReverseDiff()               16.29  15.83  0.972                   
demo_assume_submodel_observe_index_literal  AutoReverseDiff()               23.5   23.25  0.989                   
demo_dot_assume_observe_submodel            AutoReverseDiff()               23.96  21.54  0.899                   
demo_dot_assume_observe_matrix_index        AutoReverseDiff()               23.25  21.58  0.928                   
demo_assume_matrix_observe_matrix_index     AutoReverseDiff()               20.42  20.17  0.988                   
demo_dot_assume_observe                     AutoReverseDiff(compile=true)   6.27   5.98   0.954    1.077          0.079
demo_assume_index_observe                   AutoReverseDiff(compile=true)   6.55   6.46   0.986                   
demo_assume_multivariate_observe            AutoReverseDiff(compile=true)   4.96   5.31   1.071                   
demo_dot_assume_observe_index               AutoReverseDiff(compile=true)   6.52   7.47   1.146                   
demo_assume_dot_observe                     AutoReverseDiff(compile=true)   4.58   5.13   1.120                   
demo_assume_multivariate_observe_literal    AutoReverseDiff(compile=true)   5.37   5.34   0.994                   
demo_dot_assume_observe_index_literal       AutoReverseDiff(compile=true)   6.46   7.21   1.116                   
demo_assume_dot_observe_literal             AutoReverseDiff(compile=true)   4.6    5.44   1.183                   
demo_assume_observe_literal                 AutoReverseDiff(compile=true)   5.08   5.28   1.039                   
demo_assume_submodel_observe_index_literal  AutoReverseDiff(compile=true)   6.45   7.49   1.161                   
demo_dot_assume_observe_submodel            AutoReverseDiff(compile=true)   6.26   6.06   0.968                   
demo_dot_assume_observe_matrix_index        AutoReverseDiff(compile=true)   5.94   6.64   1.118                   
demo_assume_matrix_observe_matrix_index     AutoReverseDiff(compile=true)   5.14   5.86   1.140                   

Member

@willtebbutt willtebbutt left a comment


LGTM.

@penelopeysm I just ask that you make the code that you used to generate the benchmarks available to me, so I can figure out why Mooncake has such uninspiring performance on these examples 😂

@penelopeysm
Member Author

Will and I chatted about this ^ on Slack, but for the sake of having this be open to all, here is the updated benchmark code. It differs from the previous ones in that this measures the time taken for the primal, and is only meant to work on the code on this branch / PR:

using DynamicPPL, ADTypes, LogDensityProblems, Chairmarks, StatsBase, Random, Printf
import ForwardDiff
import ReverseDiff
import Mooncake

cmarks = []
for m in DynamicPPL.TestUtils.DEMO_MODELS
    vi = DynamicPPL.VarInfo(Xoshiro(468), m)
    x = vi[:]
    for adtype in [AutoForwardDiff(), AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), AutoMooncake(; config=nothing)]
        f = DynamicPPL.LogDensityFunction(m, vi; adtype=adtype)
        t = @be LogDensityProblems.logdensity_and_gradient($f, $x)
        push!(cmarks, (m.f, adtype, t))
    end
end
for cmark in cmarks
    time = @sprintf("%.2f", median(cmark[3]).time * 1e6)
    println("$(cmark[1]),$(cmark[2]),$(time)")
end

@penelopeysm penelopeysm merged commit 90c7b26 into release-0.35 Feb 19, 2025
17 of 18 checks passed
@penelopeysm penelopeysm deleted the py/no-ldp-ad branch February 19, 2025 18:36
@gdalle

gdalle commented Feb 19, 2025

Is this it? Can I tell my mom that DI is in Turing now?

@penelopeysm
Member Author

Technically we need to release a new version of DynamicPPL, and then make Turing compatible with that, but otherwise yes 😄

@penelopeysm
Member Author

Happy to ping you when that happens too if you'd like

@gdalle

gdalle commented Feb 19, 2025

Note that this opens some fun avenues for experimentation. For instance, you can now use symbolic backends like AutoFastDifferentiation() in probabilistic programming and see where that gets you.

@torfjelde
Member

Note that this opens some fun avenues for experimentation. For instance, you can now use symbolic backends like AutoFastDifferentiation() in probabilistic programming and see where that gets you.

sick
