Speed vs Distrax #7
Hi @adam-hartshorne, thanks for your report. I profiled a little on CPU (M1 Pro chip) and got 3.54 ms (fenbux) vs. 3.57 ms (distrax). Then I dug into the jaxpr of your code with `make_jaxpr(jit_febux_test)(mean, sd, y_k)` and `make_jaxpr(jit_distrax_test)(mean, sd, y_k)`.

I think the problem is that the implementation of …

```python
import timeit
from jax import jit, make_jaxpr
from jax import random as jr
from fenbux import logpdf
from fenbux.univariate import Normal
import distrax

# Random inputs: 1,000,000 x 2 arrays for the mean, scale, and evaluation points.
key = jr.PRNGKey(0)
x_key, y_key, z_key = jr.split(key, 3)
mean = jr.normal(x_key, (1000000, 2))
sd = jr.normal(y_key, (1000000, 2))
y_k = jr.normal(z_key, (1000000, 2))


def febux_test(mean, sd, y_k):
    # fenbux: functional API, logpdf(distribution, value)
    return logpdf(Normal(mean=mean, sd=sd), y_k).sum()


def distrax_test(mean, sd, y_k):
    # distrax: log_prob is a method on the distribution object
    return distrax.Normal(loc=mean, scale=sd).log_prob(y_k).sum()


jit_febux_test = jit(febux_test)
jit_distrax_test = jit(distrax_test)

# IPython/Jupyter magics; block_until_ready() waits for JAX's asynchronous dispatch.
%timeit -r 10 jit_febux_test(mean, sd, y_k).block_until_ready()
%timeit -r 10 jit_distrax_test(mean, sd, y_k).block_until_ready()
```

Also keep in mind that fenbux treats its distribution parameters as pytrees, so in fenbux's jaxpr you can see an extra function … and …

I'll open a PR to modify …

And, to compare speed with other libraries such as …

Finally, thanks again for reminding me that some functions are not optimized enough to make …
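The sketch below shows the kind of jaxpr inspection described in the reply. It uses only JAX, so fenbux and distrax are not imported. `NormalParams`, `normal_logpdf`, `logpdf_from_params`, and `top_level_primitives` are hypothetical names chosen for illustration.

The sketch demonstrates two things that change how a jaxpr looks without changing the math:

- Parameters held in a pytree are flattened into separate jaxpr inputs.
- A helper wrapped in its own `jax.jit` appears as a single nested call equation.

Whether this accounts for the extra function seen in fenbux's jaxpr is an assumption, not something confirmed in this thread.

```python
import math
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import make_jaxpr


class NormalParams(NamedTuple):
    """Hypothetical parameter container; NamedTuples are pytrees in JAX."""
    mean: jax.Array
    sd: jax.Array


def normal_logpdf(mean, sd, x):
    # Elementwise log density of N(mean, sd**2), summed to a scalar.
    z = (x - mean) / sd
    return (-0.5 * z * z - jnp.log(sd) - 0.5 * math.log(2 * math.pi)).sum()


# Helper wrapped in its own jit: when traced from another function it shows up
# in the outer jaxpr as one nested call equation instead of inlined ops.
_jitted_logpdf = jax.jit(normal_logpdf)


def logpdf_from_params(params: NormalParams, x):
    # The pytree is flattened at trace time; its leaves become jaxpr inputs.
    return _jitted_logpdf(params.mean, params.sd, x)


def top_level_primitives(closed_jaxpr):
    """Names of the primitives at the outermost level of a ClosedJaxpr."""
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


key_m, key_s, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
mean = jax.random.normal(key_m, (4, 2))
sd = jnp.abs(jax.random.normal(key_s, (4, 2))) + 0.1  # keep scales positive
x = jax.random.normal(key_x, (4, 2))

flat = make_jaxpr(normal_logpdf)(mean, sd, x)
nested = make_jaxpr(logpdf_from_params)(NormalParams(mean, sd), x)

print(top_level_primitives(flat))    # elementwise ops followed by a reduction
print(top_level_primitives(nested))  # a single nested call ("pjit" or "jit", by JAX version)
print(len(nested.jaxpr.invars))      # 3: the two pytree leaves plus x
```

A nested call like this mainly changes what the jaxpr looks like. Checking whether it changes runtime still requires timing the compiled function.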
#8 optimize the implement of
I have run a minimal example (MVE) comparing against Distrax (https://github.com/google-deepmind/distrax), and your library seems to be slower. I am running this with JAX 0.4.23, CUDA 12.2, and Python 3.10 on a GeForce RTX 4090.
It might be worth looking into why.
```
Febux Test Time: 0.10123697502422146
Distrax Test Time: 0.08472020699991845
```
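The thread does not say how the times above were measured. When comparing jitted functions, the first call includes tracing and XLA compilation, and JAX dispatches work asynchronously. The sketch below is a plain-Python timing harness (no IPython `%timeit`) that does three things:

- warms up once so compilation is excluded,
- calls `block_until_ready()` inside each timed call,
- reports the best per-call time.

As stand-ins for the two libraries it compares a hand-written Normal log density with `jax.scipy.stats.norm.logpdf`. `handwritten_logpdf_sum`, `scipy_logpdf_sum`, and `bench` are hypothetical helpers. Unlike the original MVE, the scales are made strictly positive.

```python
import math
import timeit

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def handwritten_logpdf_sum(mean, sd, x):
    # Summed log density of N(mean, sd**2), written out by hand.
    z = (x - mean) / sd
    return (-0.5 * z * z - jnp.log(sd) - 0.5 * math.log(2 * math.pi)).sum()


def scipy_logpdf_sum(mean, sd, x):
    # Same quantity via jax.scipy.stats.norm.logpdf(x, loc, scale).
    return norm.logpdf(x, loc=mean, scale=sd).sum()


def bench(fn, *args, repeat=10, number=10):
    """Best per-call wall time in seconds, excluding JIT compilation."""
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # warm-up: trace and compile once
    times = timeit.repeat(
        lambda: jitted(*args).block_until_ready(),  # wait for async dispatch
        repeat=repeat,
        number=number,
    )
    return min(times) / number


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (1_000_000, 2)
mean = jax.random.normal(k1, shape)
sd = jnp.abs(jax.random.normal(k2, shape)) + 0.1  # keep scales positive
y_k = jax.random.normal(k3, shape)

for name, fn in [("handwritten", handwritten_logpdf_sum), ("jax.scipy", scipy_logpdf_sum)]:
    print(f"{name}: {bench(fn, mean, sd, y_k) * 1e3:.3f} ms per call")
```

Taking the minimum over repeats reduces noise from other processes. Running the same harness on both CPU and GPU would show whether a gap like the one reported here depends on the backend.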