Speed vs Distrax #7

Open
adam-hartshorne opened this issue Jan 16, 2024 · 2 comments

@adam-hartshorne

I ran the following MVE against Distrax (https://github.com/google-deepmind/distrax), and your library doesn't seem to be as fast. I am running this with jax 0.4.23, CUDA 12.2, and Python 3.10 on a GeForce 4090.

It might be worth looking into why.

import timeit

def setup_code():
    return '''
from jax import jit
from jax import random as jr
from fenbux import logpdf
from fenbux.univariate import Normal
import distrax

key = jr.PRNGKey(0)
x_key, y_key, z_key = jr.split(key, 3)

mean = jr.normal(x_key, (1000000, 2))
sd = jr.normal(y_key, (1000000, 2))
y_k = jr.normal(z_key, (1000000, 2))

def febux_test(mean, sd, y_k):
    return logpdf(Normal(mean=mean, sd=sd), y_k).sum()

def distrax_test(mean, sd, y_k):
    return distrax.Normal(loc=mean, scale=sd).log_prob(y_k).sum()

jit_febux_test = jit(febux_test)
jit_distrax_test = jit(distrax_test)
'''


# Timing febux_test
febux_time = timeit.timeit('jit_febux_test(mean, sd, y_k).block_until_ready()',
                           setup=setup_code(), number=1000)

# Timing distrax_test
distrax_time = timeit.timeit('jit_distrax_test(mean, sd, y_k).block_until_ready()',
                             setup=setup_code(), number=1000)

print("Febux Test Time:", febux_time)
print("Distrax Test Time:", distrax_time)

Febux Test Time: 0.10123697502422146
Distrax Test Time: 0.08472020699991845
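
A note for reading these numbers: `timeit.timeit` returns the total time for all `number` runs, and the first call to a jitted function also includes tracing and compilation. A common pattern is to run a warm-up call before starting the clock. The helper below is a minimal, library-agnostic sketch of that pattern; `time_per_call` and its parameters are hypothetical names, not part of fenbux, Distrax, or JAX.

```python
import time


def time_per_call(fn, number=1000, warmup=1):
    """Return the mean wall-clock seconds per call of fn().

    The warm-up calls run before the clock starts, so one-off costs
    such as JIT compilation are not counted.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(number):
        fn()
    return (time.perf_counter() - start) / number


# Hypothetical usage with the functions from the snippet above:
#   time_per_call(lambda: jit_febux_test(mean, sd, y_k).block_until_ready())
#   time_per_call(lambda: jit_distrax_test(mean, sd, y_k).block_until_ready())
```

Keeping `block_until_ready()` inside the timed callable matters because JAX dispatches work asynchronously; without it the loop would mostly measure dispatch time.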

@JiaYaobo
Owner

JiaYaobo commented Jan 17, 2024

Hi @adam-hartshorne, thanks for your report. I profiled briefly on CPU (M1 Pro chip) and got 3.54 ms (fenbux) vs 3.57 ms (distrax). I then dug into the jaxprs for your code:

make_jaxpr(jit_febux_test)(mean, sd, y_k), make_jaxpr(jit_distrax_test)(mean, sd, y_k)
[Screenshot 2024-01-17 13 36 00] [Screenshot 2024-01-17 13 36 19]
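
For anyone who wants to try this kind of jaxpr comparison without either library installed, here is a rough sketch. It traces two algebraically equivalent ways of writing the normal log-density and prints their jaxprs. Both `logpdf_a` and `logpdf_b` are hypothetical functions written for illustration; neither is taken from fenbux or Distrax.

```python
import math

import jax
import jax.numpy as jnp


def logpdf_a(mean, sd, x):
    # log N(x; mean, sd) written with the variance inside the log.
    return -0.5 * jnp.log(2.0 * math.pi * sd**2) - (x - mean) ** 2 / (2.0 * sd**2)


def logpdf_b(mean, sd, x):
    # Same quantity: standardize first, then subtract log(sd) + 0.5*log(2*pi).
    z = (x - mean) / sd
    return -0.5 * z**2 - (jnp.log(sd) + 0.5 * math.log(2.0 * math.pi))


mean = jnp.zeros((4, 2))
sd = jnp.full((4, 2), 1.5)  # positive scale so the log is defined
x = jnp.ones((4, 2))

jaxpr_a = jax.make_jaxpr(logpdf_a)(mean, sd, x)
jaxpr_b = jax.make_jaxpr(logpdf_b)(mean, sd, x)

print(jaxpr_a)
print(jaxpr_b)
# Top-level equation counts give only a rough idea of relative complexity.
print(len(jaxpr_a.jaxpr.eqns), len(jaxpr_b.jaxpr.eqns))
```

The equation count is a crude proxy: XLA may fuse or simplify operations after tracing, so a longer jaxpr does not necessarily mean a slower compiled program.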

I think the problem is that the implementation of normal_logpdf in the dist_math.normal module is not well optimized, which results in a more complex jaxpr. I adjusted the normal_logpdf code to match distrax and profiled the program on a Titan RTX as below:

import timeit
from jax import jit, make_jaxpr
from jax import random as jr
from fenbux import logpdf
from fenbux.univariate import Normal
import distrax

key = jr.PRNGKey(0)
x_key, y_key, z_key = jr.split(key, 3)

mean = jr.normal(x_key, (1000000, 2))
sd = jr.normal(y_key, (1000000, 2))
y_k = jr.normal(z_key, (1000000, 2))

def febux_test(mean, sd, y_k):
    return logpdf(Normal(mean=mean, sd=sd), y_k).sum()

def distrax_test(mean, sd, y_k):
    return distrax.Normal(loc=mean, scale=sd).log_prob(y_k).sum()

jit_febux_test = jit(febux_test)
jit_distrax_test = jit(distrax_test)

%timeit -r 10 jit_febux_test(mean, sd, y_k).block_until_ready()
%timeit -r 10 jit_distrax_test(mean, sd, y_k).block_until_ready()

The results are:
[Screenshot 2024-01-17 14 51 42]

Keep in mind that fenbux treats its distribution parameters as pytrees, so in fenbux's jaxpr you can see that an extra function, tree_map_dist_at, is called every time. Apart from this extra function, fenbux's jaxpr now matches distrax's, as shown below:

[Screenshot 2024-01-17 14 54 46]

and

[Screenshot 2024-01-17 14 55 09]
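
For readers unfamiliar with the term, a pytree in JAX is a nested container (tuple, dict, NamedTuple, and so on) whose leaves are arrays. The sketch below illustrates the general idea with a hypothetical `NormalParams` container. It is not fenbux's implementation and does not reproduce tree_map_dist_at.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class NormalParams(NamedTuple):
    # Hypothetical parameter container; NamedTuples are pytrees in JAX.
    mean: jax.Array
    sd: jax.Array


params = NormalParams(mean=jnp.zeros(3), sd=jnp.ones(3))

# tree_map applies a function to every leaf and rebuilds the container.
doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2.0, params)


# A jitted function can take the whole container as one argument.
@jax.jit
def standardize(p, x):
    return (x - p.mean) / p.sd


print(jax.tree_util.tree_leaves(params))
print(doubled)
print(standardize(params, jnp.arange(3.0)))
```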

I'll open a PR to modify normal_logpdf function soon.

As for comparing speed with other libraries such as tensorflow-probability or distrax: fenbux is expected to always be faster if you simply jit the distribution methods, as in jit(dist.log_prob) (as I wrote in the README). If you compare them wrapped in a function exactly as you did, with plain arrays/tensors as parameters, fenbux will match the speed of these libraries given the same jaxpr-level optimization!

Finally, thanks again for pointing out that some functions are not optimized enough to produce the simplest jaxpr. I'll dig into these in the next version of fenbux and profile performance on GPU. Does that make sense?

@JiaYaobo
Owner

#8 optimize the implement of normal_logpdf
