Hi @pabloferz!

Following your kind invitation, here's a prototype of what I would like to achieve in DifferentiationInterfaceJAX.jl: call a function defined in Python (`fp`) on Julia arrays (`xj`) with minimal overhead. I'm curious whether there is a faster way to do this.
```julia
using BenchmarkTools
using DLPack: share, from_dlpack
using PythonCall

jax = pyimport("jax")
jnp = pyimport("jax.numpy")

# Python (JAX) function to be called on Julia arrays
fp(xp) = jnp.sum(jnp.square(xp))

# Variant 1: share the Julia array with JAX on every call
function fj(xj)
    xp = share(xj, jax.dlpack.from_dlpack)
    return fp(xp)
end

# Variant 2: copy into a pre-shared scratch buffer instead of sharing each time
function fjp!(xp_scratch, xj_scratch, xj)
    # assume xp_scratch and xj_scratch are aliased
    copyto!(xj_scratch, xj)
    return fp(xp_scratch)
end

xj = Float32.(1:10^5)
xp = share(xj, jax.dlpack.from_dlpack)
xj_scratch = Vector{Float32}(undef, 10^5)
xp_scratch = share(xj_scratch, jax.dlpack.from_dlpack)
```
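For completeness, here is a minimal sketch of how the two variants might be timed with BenchmarkTools; the exact calls behind the benchmark results are not shown in the post, so these are an assumption:

```julia
# Hypothetical benchmark calls (assumed, not part of the original snippet):
# compare sharing the array on every call vs. copying into a pre-shared buffer.
@btime fj($xj)                              # shares xj with JAX on each call
@btime fjp!($xp_scratch, $xj_scratch, $xj)  # copies xj into the aliased scratch pair
```

The `fjp!` variant trades one `copyto!` per call for not rebuilding the DLPack wrapper each time.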
Benchmark results: