Performance optimization for DifferentiationInterfaceJAX #42

Open
gdalle opened this issue Oct 18, 2024 · 1 comment

Comments


gdalle commented Oct 18, 2024

Hi @pabloferz!

Following your kind invitation, here's a prototype of what I would like to achieve in DifferentiationInterfaceJAX.jl: call a function defined in Python (fp) on Julia arrays (xj) with minimal overhead. I'm curious whether there is a faster way to do this.

using BenchmarkTools
using DLPack: share, from_dlpack
using PythonCall

jax = pyimport("jax")
jnp = pyimport("jax.numpy")

fp(xp) = jnp.sum(jnp.square(xp))

function fj(xj)
    # zero-copy: expose the Julia array to JAX through DLPack, then call fp
    xp = share(xj, jax.dlpack.from_dlpack)
    return fp(xp)
end

function fjp!(xp_scratch, xj_scratch, xj)
    # assumes xp_scratch and xj_scratch alias the same memory (pre-shared via DLPack)
    copyto!(xj_scratch, xj)  # one copy into the shared scratch buffer
    return fp(xp_scratch)
end

xj = Float32.(1:10^5)
xp = share(xj, jax.dlpack.from_dlpack)

xj_scratch = Vector{Float32}(undef, 10^5)
xp_scratch = share(xj_scratch, jax.dlpack.from_dlpack)

Benchmark results:

julia> @btime fj($xj)
  88.455 μs (56 allocations: 2.52 KiB)
Python: Array(3.3333832e+14, dtype=float32)

julia> @btime fp($xp)
  14.754 μs (22 allocations: 368 bytes)
Python: Array(3.3333832e+14, dtype=float32)

julia> @btime fjp!($xp_scratch, $xj_scratch, $xj)
  28.563 μs (22 allocations: 368 bytes)
Python: Array(3.3333832e+14, dtype=float32)
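For readers unfamiliar with the DLPack protocol that share and from_dlpack build on, the scratch-buffer trick in fjp! can be sketched in pure Python with NumPy (an illustrative analogue I'm adding for context, not the package's actual code path; np.from_dlpack requires NumPy ≥ 1.22, and JAX is not needed):

```python
import numpy as np

# Pre-shared scratch pair, standing in for xj_scratch / xp_scratch:
# np.from_dlpack consumes the buffer via DLPack, so both names alias
# the same memory — no copy is made at sharing time.
xj_scratch = np.empty(10**5, dtype=np.float32)
xp_scratch = np.from_dlpack(xj_scratch)

def fjp(xp_scratch, xj_scratch, xj):
    # One memcpy into the shared scratch buffer; the "consumer side"
    # then sees the fresh data through its existing DLPack view.
    np.copyto(xj_scratch, xj)
    return np.sum(np.square(xp_scratch))

xj = np.arange(1, 10**5 + 1, dtype=np.float32)
total = fjp(xp_scratch, xj_scratch, xj)
print(total)  # ≈ 3.3334e14, matching the benchmark above up to float32 rounding
```

The point of the pattern is that sharing (the DLPack handshake) happens once at setup, so the per-call cost is a single copy into an already-shared buffer.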
@gdalle gdalle changed the title Performance optimization for DifferentiationInterfaceJAX.jl Performance optimization for DifferentiationInterfaceJAX Oct 18, 2024

gdalle commented Oct 27, 2024

Here's a benchmark for the other half of the overhead: moving back from JAX tensors to Julia arrays:

using BenchmarkTools
using DLPack: share, from_dlpack
using PythonCall

jax = pyimport("jax")
jnp = pyimport("jax.numpy")

fp(xp) = jnp.square(xp)

function fj(xp)
    yp = fp(xp)
    yj = from_dlpack(yp)  # wrap the JAX result as a Julia array (zero copy)
    return yj
end

xj = Float32.(1:10^5)
xp = share(xj, jax.dlpack.from_dlpack)

Benchmark results:

julia> @btime fp($xp);
  37.098 μs (11 allocations: 184 bytes)

julia> @btime fj($xp);
  49.526 μs (30 allocations: 2.02 KiB)
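The gap between fp and fj here is wrapper overhead rather than a data transfer: from_dlpack hands back a view over the existing buffer instead of copying it. A NumPy-only sketch of that aliasing behavior (again an analogue I'm adding for illustration, not the package's code; NumPy ≥ 1.22):

```python
import numpy as np

# "Producer-side" result, standing in for the JAX output yp.
yp = np.square(np.arange(1, 11, dtype=np.float32))

# Wrapping via DLPack yields a view over the same buffer, not a copy.
yj = np.from_dlpack(yp)

yp[0] = -1.0
print(yj[0])  # -1.0: yj aliases yp's memory, so no data was moved
```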
