Replies: 3 comments 2 replies
-
Thanks for the question! Are you running on GPU? The reason to use …
-
Hi Matthew, thanks for your reply! I'm running the test in a notebook which runs on a TPU VM.
-
If your loop body … To keep it simple, you need at least two buffers to complete the computation; the size of each will be (512, 1024, 1024) (see the size check below). At each iteration, …
Unfortunately, for some reason, 3) was done via a …
The first …
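For reference, a back-of-the-envelope size check for those buffers (my arithmetic, not part of the original reply; it assumes float32 elements):

```python
# Assumed dtype: float32 (4 bytes per element).
buffer_bytes = 512 * 1024 * 1024 * 4
print(buffer_bytes / 2**30)      # 2.0 -> one (512, 1024, 1024) buffer is 2 GiB
print(2 * buffer_bytes / 2**30)  # 4.0 -> two live buffers are 4 GiB
```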
-
Hi all, from the scan documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) I expected some performance improvement from using scan instead of a plain Python loop. Interestingly, though, I found that the T5X model doesn't use scan (https://github.com/google-research/t5x/blob/36e5f02f87669e3c38a9699001a4a154b514a115/t5x/examples/decoder_only/network.py#LL197C9-L197C9), so I ran an experiment on scan and the result was surprising: the scan version is slower than the plain loop, and memory consumption is the same for both. Can someone help me figure out what I'm missing, or whether my understanding is wrong? Is it because the scale of the test is too small, or because sharding isn't used?
The code I'm using is like the following:
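(The code block that followed did not survive in this excerpt. Below is a minimal sketch of the kind of scan-vs-loop comparison described; it is a reconstruction, not the original code. The step function, shapes, and the use of `timeit`/`block_until_ready` are assumptions chosen so the stacked output is (512, 1024, 1024), roughly 2 GiB in float32.)

```python
import timeit

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical setup: run a simple step 512 times over a (1024, 1024) float32
# array and stack the per-step outputs into a (512, 1024, 1024) result.
N = 512
x = jnp.ones((1024, 1024), dtype=jnp.float32)

def step(carry, _):
    # Assumed loop body; purely illustrative.
    new = carry * 2.0 + 1.0
    return new, new  # (new carry, per-step output)

@jax.jit
def run_scan(x):
    _, ys = lax.scan(step, x, xs=None, length=N)
    return ys

@jax.jit
def run_loop(x):
    ys = []
    for _ in range(N):
        x, y = step(x, None)
        ys.append(y)
    return jnp.stack(ys)

# Compile once, then time only the execution.
run_scan(x).block_until_ready()
run_loop(x).block_until_ready()
print("scan:", timeit.timeit(lambda: run_scan(x).block_until_ready(), number=10))
print("loop:", timeit.timeit(lambda: run_loop(x).block_until_ready(), number=10))
```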
The memory profile shows that scan and the loop use the same amount of memory (both 2 GB when using timeit, 4 GB when not).
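As a side note, one way to capture such a device memory profile (a sketch; the original post does not say how its profile was taken) is `jax.profiler.save_device_memory_profile`:

```python
import jax

# Writes a pprof-compatible snapshot of live device allocations;
# call it right after the computation you want to inspect.
jax.profiler.save_device_memory_profile("memory.prof")
```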