Gradient with DI is 2x slower than native Mooncake #272

Closed
gdalle opened this issue Sep 27, 2024 · 2 comments
Labels
enhancement (performance) Would reduce the time it takes to run some bit of the code

Comments

@gdalle

gdalle commented Sep 27, 2024

Following the conversation on Slack around StaticArrays, I compared benchmark results between DI and native Mooncake, and I see a roughly 2x slowdown with DI that I can't explain. Any idea what could be causing this, @willtebbutt?

using BenchmarkTools
using DifferentiationInterface
using Mooncake: Mooncake
using StaticArrays

# setup

struct Foo{T} <: FieldVector{4,T}
    field1::T
    field2::T
    field3::T
    field4::T
end

function func1(foo1::Foo, foo2::Foo, par1::Real, par2::Real)
    var1 = foo1.field1
    var2 = foo2.field1
    var3 = var1 * var2
    var4 = foo1.field3 + var3 * (foo1.field4 - foo1.field3)
    var5 = foo2.field3 + var3 * (foo2.field4 - foo2.field3)
    var6 = 0.5f0 * (foo1.field2 * var5 + var4 * foo2.field2)
    var7 = 0.5f0 * (foo1.field2 * (var4 - foo1.field3) + foo2.field2 * (var5 - foo2.field3))
    return (par1 * var7 + par2 * var6)
end

func1_merged(foo12::Tuple{Foo,Foo}, par1, par2) = func1(foo12[1], foo12[2], par1, par2)

function bench_function(foo1, foo2, par1, par2)
    @btime func1($foo1, $foo2, $par1, $par2)
    return nothing
end

function bench_mooncake_native(foo1, foo2, par1, par2)
    fargs = (func1, foo1, foo2, par1, par2)
    rule = Mooncake.build_rrule(fargs...)
    Mooncake.value_and_gradient!!(rule, fargs...)
    @btime Mooncake.value_and_gradient!!($rule, $fargs...)
    return nothing
end

function bench_mooncake_di(foo1, foo2, par1, par2)
    backend = AutoMooncake(; config=nothing)
    foo12 = (foo1, foo2)
    prep = prepare_pullback(
        func1_merged, backend, foo12, (1.0,), Constant(par1), Constant(par2)
    )
    @btime value_and_pullback(
        $func1_merged, $prep, $backend, $foo12, (1.0,), Constant($par1), Constant($par2)
    )
    return nothing
end

foo1 = Foo(1.0, 2.0, 3.0, 4.0)
foo2 = Foo(5.0, 6.0, 7.0, 8.0)
par1 = 2.7
par2 = 3.1

julia> bench_function(foo1, foo2, par1, par2)
  3.952 ns (0 allocations: 0 bytes)

julia> bench_mooncake_native(foo1, foo2, par1, par2)
  71.322 ns (0 allocations: 0 bytes)

julia> bench_mooncake_di(foo1, foo2, par1, par2)
  136.347 ns (0 allocations: 0 bytes)
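
For reference, the slowdown factor implied by these timings can be computed directly. This is only a minimal sketch that hard-codes the numbers reported above; the variable names are made up for illustration and are not part of DI or Mooncake. With these values the DI path comes out at roughly 1.9x the native Mooncake time.

```julia
# Timings copied from the @btime output above, in nanoseconds.
t_primal = 3.952    # plain call to func1
t_native = 71.322   # Mooncake.value_and_gradient!!
t_di = 136.347      # DifferentiationInterface.value_and_pullback

# Ratio of the DI-wrapped Mooncake call to the native Mooncake call.
slowdown_di_vs_native = t_di / t_native

# Overhead of native Mooncake relative to the primal function call.
overhead_native_vs_primal = t_native / t_primal

println("DI vs native Mooncake: ", round(slowdown_di_vs_native; digits=2), "x")
println("native Mooncake vs primal: ", round(overhead_native_vs_primal; digits=1), "x")
```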

willtebbutt added the "enhancement (performance)" label Sep 28, 2024
@willtebbutt
Member

Hmm, I'm struggling to replicate this. On my system, I see:

julia> bench_function(foo1, foo2, par1, par2)
  4.096 ns (0 allocations: 0 bytes)

julia> bench_mooncake_native(foo1, foo2, par1, par2)
  67.660 ns (0 allocations: 0 bytes)

julia> bench_mooncake_di(foo1, foo2, par1, par2)
  67.296 ns (0 allocations: 0 bytes)

My versions are:

  [6e4b80f9] BenchmarkTools v1.5.0
  [a0c0ee7d] DifferentiationInterface v0.6.1
  [da2b9cff] Mooncake v0.4.4
  [90137ffa] StaticArrays v1.9.7

Could you let me know if these are the same as yours?
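
One way to produce a comparable version listing is via the `Pkg` standard library. This is a sketch that assumes the benchmark is run in the currently active environment and that the four packages named above are the relevant ones:

```julia
using Pkg

# Print the installed versions of the packages relevant to this benchmark,
# restricted to the currently active environment.
Pkg.status(["BenchmarkTools", "DifferentiationInterface", "Mooncake", "StaticArrays"])
```

The same listing is available from the Pkg REPL mode (press `]`) with `status` (or `st`).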

@gdalle
Author

gdalle commented Sep 30, 2024

With DI v0.6.1 and Mooncake v0.4.5 this is indeed resolved. I may have been on Mooncake v0.4.3 when I first ran the test.

@gdalle gdalle closed this as completed Sep 30, 2024

2 participants