orbax save/restore with 8 devices #1618
Replies: 4 comments
-
I see you're using … . Before saving, you can also try converting the model using … .
-
Thanks for your interest.
-
I don't think you need to fully update JAX (although your Orbax version is 0.6 while head is at 0.11, so it would be ideal to bring things more up to date). The more important consideration is using … . That said, we are looking into a better way to handle the case you've outlined. I'm hoping to provide seamless support, since this issue can arise even without using pmap.
-
OK, thanks. Nice that I've triggered a new development.
-
Hello,
I have set up a snippet on Colab here.
It's quite difficult to follow the Flax/Orbax changes for saving/restoring a (simple) model, even when following the "latest" documentation of these two packages. I have managed to cook something up, but I was wondering whether I'm doing the right thing using 8 TPUs on Colab. For instance, it seems that one can save only a single instance of the model among the 8 existing ones (i.e., the use of flax.jax_utils.unreplicate seems necessary). At restoration in the same environment, after … , one restores 8 versions using … , but these are 8 replicated copies of the same model instance.
This may be the intended behavior, but I'm wondering how one can resume from a first training session and continue training, given that one may want to use a single instance in the second training session. I hope I have been clear. Any comments on the Colab snippet are welcome.
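Regarding the resume question, here is a minimal sketch of the unreplicate, save, restore, replicate round trip. It assumes a pmap-style setup where every device holds an identical copy of the parameters. The helpers `unreplicate`, `replicate`, `save_params`, `restore_params`, and `sgd_step` are hypothetical and written only for illustration. `np.savez` stands in for the Orbax checkpointer so the example stays self-contained; it is not the Orbax API. The assumption is that the same idea carries over to whichever Orbax checkpointer you use: save the unreplicated pytree, then replicate it again after restoring.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def unreplicate(tree):
    """Keep the copy held by the first device (the idea behind flax.jax_utils.unreplicate)."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


def replicate(tree):
    """Stack one copy per local device along a new leading axis, as pmap expects."""
    n = jax.local_device_count()
    return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (n,) + jnp.shape(x)), tree)


def save_params(path, params):
    """Stand-in for an Orbax save (hypothetical helper): write the host-side leaves of one copy."""
    leaves = jax.tree_util.tree_leaves(jax.device_get(params))
    np.savez(path, *leaves)


def restore_params(path, like):
    """Stand-in for an Orbax restore (hypothetical helper): rebuild the pytree shaped like `like`."""
    treedef = jax.tree_util.tree_structure(like)
    with np.load(path) as data:
        leaves = [data[f"arr_{i}"] for i in range(len(data.files))]
    return jax.tree_util.tree_unflatten(treedef, leaves)


@jax.pmap
def sgd_step(params, grads):
    # Toy update run on every device; each device works on its own copy of params/grads.
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


# First session: train with one replica per device.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
grads = {"w": jnp.ones((3,)), "b": jnp.ones(())}
replicated = sgd_step(replicate(params), replicate(grads))

# Save only one copy. The replicas are identical here because every device got the same
# inputs; in real training this relies on averaging gradients across devices (e.g. jax.lax.pmean).
ckpt_path = os.path.join(tempfile.mkdtemp(), "params.npz")
save_params(ckpt_path, unreplicate(replicated))

# Second session: restore a single copy, then replicate it again to resume with pmap.
restored = restore_params(ckpt_path, params)
resumed = sgd_step(replicate(restored), replicate(grads))
```

If the second session runs without pmap (a single device), `restored` can be used directly without replicating it again.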