Feature request
See the technique described here and here. The essence is to use a JAX scan instead of a Python loop to iterate over layers that share the same structure.
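To make the request concrete, here is a minimal sketch of the pattern using flax.linen.scan. The Block and ScannedStack module names, the feature sizes, and the number of layers are all illustrative and not taken from the Transformers codebase:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, _):
        # A single "layer": a residual MLP block; every layer has this structure.
        h = nn.Dense(self.features)(x)
        x = x + nn.relu(h)
        return x, None  # (carry, per-step output)


class ScannedStack(nn.Module):
    features: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        # nn.scan stacks each layer's parameters along a leading "layer" axis
        # and scans the layer body, so it is traced once instead of num_layers times.
        ScanBlock = nn.scan(
            Block,
            variable_axes={"params": 0},   # give params a leading layer axis
            split_rngs={"params": True},   # independent init RNG per layer
            length=self.num_layers,
        )
        x, _ = ScanBlock(self.features)(x, None)
        return x


model = ScannedStack(features=128, num_layers=25)
params = model.init(jax.random.PRNGKey(0), jnp.ones((4, 128)))
out = jax.jit(model.apply)(params, jnp.ones((4, 128)))
```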
Motivation
The scan-over-layers technique lets JAX "see" that the computational structure of each iteration is the same. This can dramatically reduce compile time as well as the system memory occupied by the JAX compilation cache (e.g. with 25 layers in a model, the naive approach ends up with roughly 25 times as much JIT-compiled code, since each layer produces duplicative output code). My handwritten T5-like model uses about 1/50th of the system memory of the transformers Flax T5 models of similar size. With the current Flax implementation it's easy to hit system OOM errors if you end up with multiple versions of the model compiled for different sequence lengths. A naive-loop sketch after this paragraph illustrates the contrast.
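For contrast, a hypothetical naive stack (reusing the Block module from the sketch above) loops in Python, so every layer is unrolled into the traced graph and the compiled program grows roughly linearly with the number of layers:

```python
class NaiveStack(nn.Module):
    features: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        # Each iteration adds a distinct Block to the traced computation,
        # so compile time and compiled-code size scale with num_layers.
        for _ in range(self.num_layers):
            x, _ = Block(self.features)(x, None)
        return x
```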
Your contribution
It's possible I could submit a PR for this at some point in the future, but I can't be certain.
Hey @colehaus! Sorry for the late reply here. We've currently decided not to implement scan for the Flax models in Transformers. You can see a brief reason for this here: #24587 (comment)
Happy to re-open the conversation if you feel strongly about this! There was a WIP PR that shows how this could be done generally for Transformers models here: #18341
But currently I tend to view scan as a specific feature that can be built on top of the Transformers library by advanced users who require it.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.