Jax decompiler #13398
Labels: enhancement (new feature or request); P3 (no schedule): we have no plan to work on this and, if it is unassigned, we would be happy to review a PR.
A JAX decompiler would take a jaxpr and produce more readable Python code, even if some information about the original function, such as variable names, is lost (as with obfuscated code). Decompilers are an important tool for reverse engineering.
Here is an illustration of the usefulness of a decompiler.
(step a) f(x) is my function:
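The definition of f is missing from the text above. Judging from the jaxpr in step c (exp, add 1.0, log), it appears to be the softplus function log(1 + exp(x)); the sketch below is written under that assumption and may differ from the original code.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    # Assumed definition, inferred from the jaxpr in step c: softplus.
    return jnp.log(1.0 + jnp.exp(x))
```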
(step b) The derivative does not give the right answer:
print(grad(f)(100.)) # nan <- expected 1
(step c) The jaxpr of the derivative (with respect to x) is:
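The command that printed the jaxpr is not shown. A plausible way to obtain it is `jax.make_jaxpr` applied to `grad(f)`; the definition of f below is an assumption inferred from the jaxpr itself, and the exact text printed can vary between JAX versions.

```python
import jax.numpy as jnp
from jax import grad, make_jaxpr

def f(x):
    # Assumed definition (softplus), inferred from the jaxpr.
    return jnp.log(1.0 + jnp.exp(x))

# Trace grad(f) at a sample input and print the resulting jaxpr.
jaxpr_text = str(make_jaxpr(grad(f))(100.0))
print(jaxpr_text)
```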
output:
{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0 b
    _:f32[] = log c
    d:f32[] = div 1.0 c
    e:f32[] = mul d b
  in (e,) }
(step d) The Python equivalent (manually written) is:
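The hand-written code for this step is not included above. A line-by-line transcription of the jaxpr might look like the following sketch; the function name `dfdx` is hypothetical, and the variable names simply mirror the jaxpr. It also makes the failure visible: in float32, exp(100) overflows to inf, so d becomes 1/inf = 0 and e becomes 0 * inf = nan.

```python
import jax.numpy as jnp

def dfdx(a):
    # Each line mirrors one equation of the jaxpr from step c.
    b = jnp.exp(a)    # b = exp a        (inf for a = 100 in float32)
    c = 1.0 + b       # c = add 1.0 b
    _ = jnp.log(c)    # unused forward value, dropped by the jaxpr
    d = 1.0 / c       # d = div 1.0 c    (0 when c is inf)
    e = d * b         # e = mul d b      (0 * inf = nan)
    return e
```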
(step e) The problem is now easy to see. I refactored the code a little and improved its numerical stability:
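The refactored code is not shown here. One common way to stabilize this expression, offered as an assumption rather than as the original author's version, is to rewrite exp(x) / (1 + exp(x)) as 1 / (1 + exp(-x)), so that a large positive x no longer produces inf in the numerator. The name `dfdx_stable` is hypothetical.

```python
import jax.numpy as jnp

def dfdx_stable(x):
    # Algebraically equal to exp(x) / (1 + exp(x)), i.e. the logistic sigmoid.
    # For large positive x, exp(-x) underflows to 0 and the result is 1.
    # For large negative x, exp(-x) overflows to inf and the result is 0.
    return 1.0 / (1.0 + jnp.exp(-x))
```

This is the logistic sigmoid, which JAX also exposes as `jax.nn.sigmoid`; the forward function has a library counterpart in `jax.nn.softplus`.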
A decompiler would automate step d.
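To make the request concrete, here is a minimal sketch of what such a decompiler could look like for flat jaxprs. It is not an existing JAX API: `decompile` and `TEMPLATES` are hypothetical names, only a handful of primitives are covered, literals are treated as floats, and nested jaxprs or captured constants are not handled. The definition of f is the softplus assumed from the jaxpr in step c.

```python
import jax.numpy as jnp
from jax import grad, make_jaxpr

# Hypothetical table: jaxpr primitive name -> Python expression template.
TEMPLATES = {
    "add": "{0} + {1}",
    "sub": "{0} - {1}",
    "mul": "{0} * {1}",
    "div": "{0} / {1}",
    "neg": "-{0}",
    "exp": "jnp.exp({0})",
    "log": "jnp.log({0})",
}

def decompile(closed_jaxpr, name="decompiled"):
    """Emit Python source for a flat jaxpr (sketch, limited primitive set)."""
    jaxpr = closed_jaxpr.jaxpr
    names = {}

    def ref(atom):
        # Literals carry their value in `.val`; variables get generated names.
        if hasattr(atom, "val"):
            return repr(float(atom.val))
        return names.setdefault(atom, f"v{len(names)}")

    args = ", ".join(ref(v) for v in jaxpr.invars)
    lines = [f"def {name}({args}):"]
    for eqn in jaxpr.eqns:
        template = TEMPLATES.get(eqn.primitive.name)
        if template is None:
            raise NotImplementedError(f"primitive {eqn.primitive.name!r}")
        rhs = template.format(*(ref(v) for v in eqn.invars))
        lhs = ", ".join(ref(v) for v in eqn.outvars)
        lines.append(f"    {lhs} = {rhs}")
    outs = ", ".join(ref(v) for v in jaxpr.outvars)
    lines.append(f"    return {outs}")
    return "\n".join(lines)

def f(x):
    # Assumed definition (softplus), inferred from the jaxpr in step c.
    return jnp.log(1.0 + jnp.exp(x))

source = decompile(make_jaxpr(grad(f))(100.0), name="dfdx")
print(source)
```

The generated source can be inspected or run with `exec` to reproduce the nan. A real tool would also need to handle primitive parameters, dtypes, shapes, and control-flow primitives, which this sketch deliberately leaves out.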