
How to bypass the need for dynamic shapes? #17191

Closed. Answered by jakevdp
ddrous asked this question in General

For this kind of dynamic shape, there is some experimental work on shape-polymorphic compilation (see the list of PRs and issues here: https://github.com/google/jax/issues?q=shape_poly+), currently available mainly in the context of jax2tf-based execution. @gnecula may have more to say about whether it is ready for general use.
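
For context, `jax.jit` specializes each compiled program to the exact shapes and dtypes of its inputs, so every new input shape triggers a fresh trace and compilation. That is the cost shape polymorphism tries to avoid. The sketch below uses illustrative names that are not from the thread (`sum_of_squares`, `trace_count`) and counts traces with a Python side effect, which runs only while JAX is tracing:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how many times jax.jit traces sum_of_squares

@jax.jit
def sum_of_squares(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return jnp.sum(x ** 2)

# Each distinct input shape produces a new trace (and a new compilation).
for n in range(1, 4):
    sum_of_squares(jnp.ones((n, 10)))

# Calling again with an already-seen shape reuses the cached compilation.
sum_of_squares(jnp.ones((2, 10)))
```

After this runs, `trace_count` should be 3: one trace per distinct shape, with the repeated `(2, 10)` call served from the cache.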

Aside from that, the main supported strategy for this kind of problem is to pad your inputs so that every call to the function sees the same shape. That may or may not be viable, depending on the details of your use case. In this simple example, it might look like this:

```python
for i in range(1, 11):
  basis = jax.random.normal(shape=(1000, 10), key=key)
  red…
```
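
The snippet above is truncated in this page capture. As a separate, hypothetical sketch of the same padding idea (the helper `pad_to_max`, the function `masked_mean`, and the bound `MAX_LEN` are assumptions for illustration, not the answer's code), variable-length inputs can be zero-padded on the host to a fixed length and paired with a mask, so `jax.jit` compiles only once:

```python
import numpy as np
import jax
import jax.numpy as jnp

MAX_LEN = 10  # assumed upper bound on the input length

trace_count = 0  # incremented only when jax.jit (re)traces masked_mean

def pad_to_max(x, max_len=MAX_LEN):
    """Zero-pad a 1-D array on the host so every call sees shape (max_len,)."""
    x = np.asarray(x, dtype=np.float32)
    return np.pad(x, (0, max_len - x.shape[0]))

@jax.jit
def masked_mean(x_padded, n_valid):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    # True for the first n_valid entries, False for the padding.
    mask = jnp.arange(x_padded.shape[0]) < n_valid
    # Padding is zeroed out by the mask, then excluded from the count.
    return jnp.sum(jnp.where(mask, x_padded, 0.0)) / n_valid

key = jax.random.PRNGKey(0)
data = np.asarray(jax.random.normal(key, shape=(MAX_LEN,)))

# Ten inputs of lengths 1..10, all padded to the same shape (MAX_LEN,).
results = [masked_mean(pad_to_max(data[:n]), n) for n in range(1, MAX_LEN + 1)]
```

Because `n_valid` is passed as an ordinary (traced) argument rather than marked static, changing its value does not trigger recompilation, so `trace_count` stays at 1. The trade-off is that the padded entries are still computed and then masked away.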

Replies: 5 comments, 3 replies (from @jakevdp, @gnecula, and @nullhook); answer selected by ddrous. 5 participants.