
Jax decompiler #13398

Open
PierrickPochelu opened this issue Nov 24, 2022 · 4 comments
Labels
enhancement New feature or request P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR

Comments

@PierrickPochelu

PierrickPochelu commented Nov 24, 2022

A JAX decompiler would take jaxpr code and produce more readable Python code, even if some information about the original function is lost, as with obfuscated code (for example, the original variable names). Decompilers are important tools for reverse engineering.

Here is an illustration of why a decompiler would be useful.

(step a) f(x) is my function:

from jax import grad, numpy as jnp
f = lambda x: jnp.log(1 + jnp.exp(x))

(step b) The derivative does not give the right answer:
print(grad(f)(100.)) # nan <- expected 1

(step c) The jaxpr of the derivative (with respect to x) is:

from jax import make_jaxpr
make_jaxpr(grad(f))(100.)

output:
{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0 b
    _:f32[] = log c
    d:f32[] = div 1.0 c
    e:f32[] = mul d b
  in (e,) }

(step d) The Python equivalent (manually written) is:

df=lambda a: (1 / (1 + jnp.exp(a))) * jnp.exp(a)
print(df(100.)) # nan <- expected 1
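Evaluating the jaxpr from step c one equation at a time in float32 (JAX's default floating-point dtype) shows where the nan comes from. This is a small illustrative sketch; the variable names simply mirror the jaxpr.

```python
import jax.numpy as jnp

a = jnp.array(100.0, dtype=jnp.float32)
b = jnp.exp(a)  # exp(100) is about 2.7e43, above float32's max (about 3.4e38): inf
c = 1 + b       # inf
d = 1 / c       # 1 / inf == 0.0
e = d * b       # 0.0 * inf is nan, which is what grad(f)(100.) returns
print(b, c, d, e)
```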

(step e) Now the problem is easy to see: in float32, exp(100.) overflows to inf, so 1 / (1 + inf) is 0 and the product 0 * inf is nan. I refactored the code slightly to improve its numerical stability:

df = lambda x: 1.0 if x > 10 else jnp.exp(x) / (1 + jnp.exp(x))
print(df(100.))  # 1.0 <- expected answer
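The `x > 10` test above runs as Python control flow, so it works on concrete inputs but would fail under `jax.jit` (the comparison yields a traced value), and it returns exactly 1 for every x > 10. A possible alternative, sketched here under the assumption that the goal is simply a stable softplus and its derivative: write f with `jnp.logaddexp(0.0, x)`, which equals log(1 + exp(x)) but is evaluated without overflowing, and use `jax.nn.sigmoid` as the derivative, since d/dx log(1 + e^x) = e^x / (1 + e^x). The names `f_stable` and `df_stable` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Softplus written to avoid overflow: log(1 + exp(x)) == logaddexp(0, x).
f_stable = lambda x: jnp.logaddexp(0.0, x)

# Its derivative is the logistic sigmoid; no Python `if`, so it can be jitted.
df_stable = jax.jit(jax.nn.sigmoid)

print(jax.grad(f_stable)(100.0))
print(df_stable(100.0))
```

`jax.nn.softplus` computes the same function as `f_stable` and could be used directly instead.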

A decompiler would automate step d.
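A minimal sketch of automating step d, assuming only the jaxpr data structures returned by `jax.make_jaxpr` (`.jaxpr.invars`, `.eqns`, `.outvars`, and `eqn.primitive.name`). The `TEMPLATES` table, the `decompile` function, and the `v0, v1, ...` naming scheme are hypothetical illustrations, not the API of any existing project. Each jaxpr equation becomes one Python assignment; primitives outside the small table raise `NotImplementedError`, and captured constants (`constvars`) are not handled.

```python
import jax
import jax.numpy as jnp  # the generated source refers to `jnp`
from jax import lax

# Hypothetical table: jaxpr primitive name -> Python expression template.
# Only a handful of element-wise primitives are covered in this sketch.
TEMPLATES = {
    "neg": "-{0}",
    "sin": "jnp.sin({0})",
    "cos": "jnp.cos({0})",
    "exp": "jnp.exp({0})",
    "log": "jnp.log({0})",
    "add": "{0} + {1}",
    "sub": "{0} - {1}",
    "mul": "{0} * {1}",
    "div": "{0} / {1}",
}


def decompile(closed_jaxpr, fn_name="decompiled"):
    """Return Python source with one assignment per jaxpr equation."""
    jaxpr = closed_jaxpr.jaxpr
    if jaxpr.constvars:
        raise NotImplementedError("captured constants are not handled")
    names = {}

    def ref(v):
        if hasattr(v, "val"):  # a Literal, e.g. the 1.0 in `add 1.0 b`
            val = v.val
            return repr(val.item() if hasattr(val, "item") else val)
        if v not in names:  # a Var: the original name is lost, so invent one
            names[v] = f"v{len(names)}"
        return names[v]

    args = ", ".join(ref(v) for v in jaxpr.invars)
    lines = [f"def {fn_name}({args}):"]
    for eqn in jaxpr.eqns:
        template = TEMPLATES.get(eqn.primitive.name)
        if template is None:
            raise NotImplementedError(f"unsupported primitive: {eqn.primitive.name}")
        expr = template.format(*(ref(v) for v in eqn.invars))
        lines.append(f"    {ref(eqn.outvars[0])} = {expr}")
    outs = ", ".join(ref(v) for v in jaxpr.outvars)
    lines.append(f"    return {outs}")
    return "\n".join(lines)


def g(x):
    return lax.sin(x) * 2.0 + lax.exp(x)


source = decompile(jax.make_jaxpr(g)(1.0))
print(source)
```

Applied to `jax.make_jaxpr(grad(f))(100.)` from step c, this should yield a function equivalent to step d, with `v0, v1, ...` in place of `a, b, ...`; whether the unused `log` equation appears depends on the JAX version.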

PierrickPochelu added the enhancement label on Nov 24, 2022
jakevdp added the P3 (no schedule) label on Nov 28, 2022
@jakevdp
Collaborator

jakevdp commented Nov 28, 2022

Thanks for the suggestion! It's an interesting idea. There would be some complexity involved for a full solution; off the top of my head:

  • how to decide if (sub)expressions are inlined or assigned to variables?
  • how to deal with call primitives?
  • how to deal with primitives whose jaxpr representation is quite different from the typical Python API, e.g. scatter_p, gather_p, conv_general_dilated_p, etc.?

It could be a fun challenge though 😁
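One possible answer to the first question in the list above, sketched under the assumption that the function is pure and uses only a few primitives: count how many times each variable is read, inline values read exactly once into the expression that consumes them, give values read more than once their own variable, and drop values that are never read (like `_` in step c). `decompile_inlined`, `is_literal`, and the template table are hypothetical names.

```python
from collections import Counter

import jax
import jax.numpy as jnp  # the generated source refers to `jnp`
from jax import lax

# Hypothetical templates; parentheses keep inlined expressions unambiguous.
TEMPLATES = {
    "sin": "jnp.sin({0})",
    "exp": "jnp.exp({0})",
    "add": "({0} + {1})",
    "mul": "({0} * {1})",
}


def is_literal(v):
    return hasattr(v, "val")  # jaxpr Literals carry a `val`; Vars do not


def decompile_inlined(closed_jaxpr, fn_name="decompiled"):
    """Inline values read exactly once; assign values read several times."""
    jaxpr = closed_jaxpr.jaxpr
    # Count reads of each variable, by equations and by the function outputs.
    uses = Counter(v for eqn in jaxpr.eqns for v in eqn.invars if not is_literal(v))
    uses.update(v for v in jaxpr.outvars if not is_literal(v))
    names, inlined = {}, {}

    def ref(v):
        if is_literal(v):
            return repr(v.val.item() if hasattr(v.val, "item") else v.val)
        if v in inlined:
            return inlined[v]
        if v not in names:
            names[v] = f"v{len(names)}"
        return names[v]

    lines = [f"def {fn_name}({', '.join(ref(v) for v in jaxpr.invars)}):"]
    for eqn in jaxpr.eqns:
        out = eqn.outvars[0]
        expr = TEMPLATES[eqn.primitive.name].format(*(ref(v) for v in eqn.invars))
        if uses[out] == 0:
            continue  # never read, like `_ = log c` in step c: drop it
        if uses[out] == 1:
            inlined[out] = expr  # read once: substitute it at the use site
        else:
            lines.append(f"    {ref(out)} = {expr}")  # shared: give it a name
    lines.append(f"    return {', '.join(ref(v) for v in jaxpr.outvars)}")
    return "\n".join(lines)


def h(x):
    e = lax.exp(x)
    return lax.sin(x) * 2.0 + e * e  # `e` is read twice


print(decompile_inlined(jax.make_jaxpr(h)(1.0)))
```

Use counting is only one heuristic: it keeps shared subexpressions from being recomputed, but it can still produce long one-line expressions for deep single-use chains.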

@PierrickPochelu
Author

I will release a first version that works on simple code samples in the next few days.

@jakevdp
Collaborator

jakevdp commented Dec 1, 2022

Cool - if you haven't seen it already, this might be a useful resource for crawling/interpreting jaxprs: https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
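For the second question (call primitives), a crawler in the spirit of that notebook has to recurse into sub-jaxprs stored in an equation's `params`; for example, a `jax.jit` call appears as a single call equation (named `pjit` in many JAX versions) whose `jaxpr` parameter holds the inner ClosedJaxpr. A minimal sketch, assuming nested jaxprs can be recognized by an `eqns` attribute, either directly or after unwrapping `.jaxpr`; `sub_jaxprs` and `walk` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def sub_jaxprs(params):
    """Yield jaxprs nested in an equation's params (e.g. for jit or cond)."""
    for value in params.values():
        items = value if isinstance(value, (list, tuple)) else (value,)
        for item in items:
            inner = getattr(item, "jaxpr", item)  # unwrap a ClosedJaxpr
            if hasattr(inner, "eqns"):
                yield inner


def walk(jaxpr, depth=0):
    """Print each primitive, indenting the contents of call primitives."""
    names = []
    for eqn in jaxpr.eqns:
        print("  " * depth + eqn.primitive.name)
        names.append(eqn.primitive.name)
        for inner in sub_jaxprs(eqn.params):
            names.extend(walk(inner, depth + 1))
    return names


def outer(x):
    return jax.jit(jnp.sin)(x) + 1.0  # `sin` lives inside the jit call


names = walk(jax.make_jaxpr(outer)(1.0).jaxpr)
```

A decompiler could use the same recursion to emit each sub-jaxpr as a nested function, or to inline it at the call site.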

@PierrickPochelu
Author

I implemented a first version: https://github.com/PierrickPochelu/JaxDecompiler

It currently supports 23 common jaxpr operators, such as "add", "mul", "neg", "cos", "sin", ..., and it partially supports pmap.
