-
Hi all, I've been using JAX for a while now, and I'm enjoying it very much! I'm working on a project where the size of my input tensors is progressively increased. I recently watched this talk on dynamic shape support in PyTorch 2.0: https://youtu.be/rn-kJQ-7JmQ. They appear to use SymPy to propagate symbolic tensor shapes across the program, and they show improvements over static shapes and the recompilation they inevitably trigger. Obviously JAX is a different beast, with the strict set of requirements that XLA demands, but things change; and similar to #14700 and #14634, I wonder if such a feature (or anything remotely close) will be possible soon. I suspect the help you can provide is problem-specific, but right now, what are some of the options JAX offers to overcome this recurring demand for dynamic shapes? Thanks :)
-
This is a hard question to answer in general. Can you give a self-contained example of the kind of operation you have in mind?
-
Hi @jakevdp, thanks for your reply. Below is a minimal working example that captures my problem. Essentially, what are some of the techniques people use to get around recompilation?

import jax

@jax.jit
def project(data, basis):  # This function does a lot more in my case, and jit-compiles for much longer
    return data @ basis

key = jax.random.PRNGKey(42)
data = jax.random.normal(shape=(100, 1000), key=key)

for i in range(1, 11):
    # each new basis width is a new input shape, triggering a fresh trace and compilation
    basis = jax.random.normal(shape=(1000, i), key=key)
    reduced_data = project(data, basis)

print(project._cache_size())  # 10: one cached compilation per basis shape
-
For this kind of dynamic shapes, there is some experimental work in shape polymorphic compilation (see the list of PRs and issues here: https://github.com/google/jax/issues?q=shape_poly+), generally available in the context of jax2tf-based execution. @gnecula may have more to say about whether that's ready for general use.
Aside from that, the main supported strategy for this kind of thing would involve padding your inputs so that all calls to the function have the same shape. That may or may not be a viable option, depending on the details of your use-case.
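In this simple example, it might look like this (a sketch assuming a known upper bound on the basis width, called MAX_BASIS here):

import jax
import jax.numpy as jnp

MAX_BASIS = 10  # assumed upper bound on the basis width

@jax.jit
def project(data, basis):
    return data @ basis

key = jax.random.PRNGKey(42)
data = jax.random.normal(shape=(100, 1000), key=key)

for i in range(1, 11):
    basis = jax.random.normal(shape=(1000, i), key=key)
    # pad with zero columns so every call sees the same (1000, MAX_BASIS) shape
    padded = jnp.pad(basis, ((0, 0), (0, MAX_BASIS - i)))
    reduced_data = project(data, padded)[:, :i]  # keep only the meaningful columns

print(project._cache_size())  # 1: a single compilation serves every iteration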
-
Shape polymorphism today only handles dynamic shapes for the lowering to HLO, producing HLO with dynamic shapes. It does this by propagating symbolic shapes through the program, but we do not have a way to avoid recompilation, because we do not have a capable-enough compiler that can handle dynamic shapes.
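For concreteness, a minimal sketch of what that lowering looks like through jax2tf (assuming TensorFlow is installed; "b" names a symbolic dimension in the polymorphic_shapes spec):

from jax.experimental import jax2tf

def project(data, basis):
    return data @ basis

# the data shape stays static; the basis width "b" is symbolic, so a single
# lowering covers every width -- but executing it still needs a backend that
# can handle the resulting dynamic-shape graph
tf_project = jax2tf.convert(project, polymorphic_shapes=["(100, 1000)", "(1000, b)"])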
-
Thanks both @jakevdp and @gnecula. I think that even though recompilation is still a concern, shape polymorphism is what I'm after. I'll look at the suggested issues, and hope some experimental documentation on this lands soon. The padding suggestion also looks good, although it isn't entirely viable in my full use-case: the maximum size of the basis (which I don't fully control, and which would force padding elsewhere, including in a neural net) often gets prohibitively large. However, I think making that maximum size dynamic could improve things. Thanks again; I'll keep an eye on the issues for developments.
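For instance, one way to make the padded size "dynamic" without unbounded recompilation is to bucket the padded width (a sketch; next_bucket is a hypothetical helper, not something suggested above):

def next_bucket(i):
    # round the width up to the next power of two: widths 1..10 fall into the
    # buckets {1, 2, 4, 8, 16}, so at most five shapes ever get compiled
    return 1 << (i - 1).bit_length()

# pad the basis to next_bucket(i) columns instead of a fixed MAX_BASIS,
# trading a logarithmic number of recompilations for much smaller padding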
-
@gnecula PyTorch 2.1 and Dynamo now have dynamic shape support, leveraging CPython's frame evaluation API to capture graphs and retire the legacy tracing approach. They can also capture graphs containing branches, for example:
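(An illustrative sketch of such a shape-dependent branch, not the exact snippet from the talk:)

import torch

@torch.compile(dynamic=True)
def f(x):
    # Dynamo treats x.shape[0] as a symbolic size, installs a guard on the
    # comparison, and compiles one graph per branch outcome
    if x.shape[0] > 4:
        return x * 2
    return x - 1

f(torch.randn(8))   # compiles the "> 4" variant
f(torch.randn(16))  # reuses it: no recompile for the new size
f(torch.randn(2))   # compiles the "<= 4" variant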
Is there any plan or discussion in JAX along these lines?