Make JAX funcify more robust to constant inputs #807
ricardoV94 started this conversation in General
This is something we should consider doing in a more general way (e.g. in …). One question, and potential complication, that arises when trying to automate this: do we preserve the original …
When compiling models for the JAX backend, we sometimes see failures (i.e., concretization errors) caused by operations like `Reshape`, even when the shape input is a constant and should be known at compilation time. I haven't found an MWE that reproduces this issue, and I am not yet sure what causes it, hence this discussion.

This is the current `jax_funcify_Reshape`, in aesara/aesara/link/jax/dispatch.py, lines 701 to 706 in 2450186.
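The following JAX-only sketch is not a reproduction of the Aesara failure; it only illustrates the general mechanism behind this class of error. Once a shape value becomes a traced argument of a jitted function, `jnp.reshape` can no longer read it as static dimensions, even if the caller passed a constant. The `reshape` function below is a hypothetical stand-in for the kind of function a funcified `Reshape` returns.

```python
import jax
import jax.numpy as jnp
import numpy as np


def reshape(x, shape):
    # Hypothetical stand-in for the function returned by a funcified Reshape:
    # the target shape arrives as an ordinary runtime input.
    return jnp.reshape(x, shape)


x = jnp.arange(6.0)
shape = np.array([2, 3])  # a constant, known before any compilation

# Eager call: `shape` is a concrete NumPy array, so this succeeds.
eager_out = reshape(x, shape)
print(eager_out.shape)  # (2, 3)

# Under jit, every positional argument is abstracted into a tracer,
# including `shape`, so jnp.reshape cannot turn it into static dimensions.
try:
    jax.jit(reshape)(x, shape)
    jit_failed = False
except TypeError as err:  # ConcretizationTypeError is a TypeError subclass
    jit_failed = True
    print(type(err).__name__)
```

The value of `shape` is identical in both calls; only whether JAX sees it as concrete or traced changes.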
There are two hackish patterns I have seen/used to get around the failures:

1. … the `shape` input
2. `jax.ensure_compile_time_eval` (this was used by @ferrine recently)

In both cases, the idea is to tell JAX that the reshape is a safe operation that can be done at compilation time (or at least that it can be attempted).
I am not yet familiar enough with JAX to know what gotchas we might run into here. If someone is interested in exploring this issue, there seems to be considerable room to increase the coverage of JAX-able graphs that Aesara can generate by doing more work inside the `jax_funcify*` functions.

Related to #43, #68, #182, #684
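On doing more work inside the `jax_funcify*` functions: one possible direction is to check, at conversion time, whether the shape input is a constant and, if so, capture it as a static Python tuple in the returned closure. The sketch below is JAX-only; `funcify_reshape` and its `constant_shape` parameter are hypothetical stand-ins for what a dispatcher could derive from a node's constant inputs, not Aesara's actual API.

```python
import jax
import jax.numpy as jnp
import numpy as np


def funcify_reshape(constant_shape=None):
    """Hypothetical helper: build the JAX function for a Reshape node.

    `constant_shape` stands in for the value read from a constant shape
    input; it is None when the shape is only known at runtime.
    """
    if constant_shape is not None:
        # Capture the constant as a tuple of Python ints, which JAX
        # treats as static metadata rather than a traced value.
        static_shape = tuple(int(d) for d in np.asarray(constant_shape))

        def reshape(x, shape):
            # Keep the (x, shape) signature, but ignore the runtime shape
            # in favour of the compile-time constant.
            return jnp.reshape(x, static_shape)

    else:

        def reshape(x, shape):
            # Dynamic case: the shape is a regular, possibly traced, input.
            return jnp.reshape(x, shape)

    return reshape


x = jnp.arange(6.0)
shape = np.array([3, 2])

# The constant-aware version compiles even though `shape` is still passed
# (and traced) as an argument, because it is never used as a shape.
static_fn = jax.jit(funcify_reshape(constant_shape=shape))
out = static_fn(x, shape)
print(out.shape)  # (3, 2)
```

The trade-off is that the constant is baked into the closure, so the function has to be regenerated if that input ever stops being constant.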