Each iteration of lax.scan requires a kernel launch on GPU backends. Can this be resolved?
#22611
Unanswered
carlosgmartin asked this question in General
Copied to here, since the XLA repo might be a better place to discuss this.
My understanding is that, currently, each iteration of `lax.scan` requires a kernel launch on GPU backends. This causes an appreciable performance penalty. For context, consider the following comments:
July 2020:
March 2021:
July 2021:
May 2023:
September 2023:
February 2024:
My question is this: Is this a fundamental limitation of JAX, XLA, and/or GPU hardware? Can it be resolved? The first two comments above suggest it's possible. If so, is this currently being discussed or worked on somewhere?
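For concreteness, here is a minimal sketch of the pattern in question. The `step` function and the two wrapper names are hypothetical, chosen only for illustration. The second variant uses the existing `unroll` argument of `lax.scan`, which reduces the number of loop iterations by running several steps per iteration. Whether that actually reduces the number of kernel launches depends on how XLA fuses the unrolled body, so treat it as a possible mitigation and not as a fix.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # Hypothetical, deliberately tiny loop body: a running sum.
    # A body this small makes any per-iteration overhead dominate.
    new_carry = carry + x
    return new_carry, new_carry


@jax.jit
def cumsum_scan(xs):
    # Plain scan: one loop iteration per element of xs.
    return lax.scan(step, jnp.zeros((), xs.dtype), xs)


@jax.jit
def cumsum_scan_unrolled(xs):
    # Same computation, with 8 steps unrolled into each loop iteration.
    # Assumption: fewer iterations may mean fewer launches, but this is
    # up to the compiler and is not guaranteed.
    return lax.scan(step, jnp.zeros((), xs.dtype), xs, unroll=8)


xs = jnp.arange(16, dtype=jnp.float32)
final, partial = cumsum_scan(xs)
final_u, partial_u = cumsum_scan_unrolled(xs)
```

Both variants compute the same result; timing them on a GPU (after a warm-up call, and with `block_until_ready()` on the outputs) is one way to gauge how much of the cost is per-iteration overhead.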