I am trying to use persistent compilation caching in XLA on GPU to speed up the execution of my Brax code. Based on related issues in the JAX tracker, this appears to be supported, and I have confirmed that it works without problems for most of my functions.
Unfortunately, for my use case (computing the Jacobian/Hessian of env.step with respect to obs), the program exits prematurely, without any error, when it calls the persistently cached Jacobian/Hessian on a subsequent run. This happens regardless of the env and backend. I have included a minimal example that reproduces the behavior. To reproduce the issue:
1. Run minimal.py. The entire program executes, and ./cache is created and populated.
2. Run minimal.py again. The second print (line 63) does not execute, and the program quits prematurely.
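The attached script is not reproduced here, but the sketch below shows the same pattern using only public JAX APIs, as a rough stand-in: it enables the persistent compilation cache, jits `jax.jacrev` of a step function, and prints before and after the call. `toy_step` is a hypothetical placeholder, not Brax's `env.step`, and the `jax.config` option names assume a recent JAX release (older releases configured the cache through `jax.experimental.compilation_cache` instead). Running it twice mirrors the two steps above: the first run populates ./cache, and the second run should load the compiled executable from it.

```python
import jax
import jax.numpy as jnp

# Enable JAX's persistent compilation cache (option names from recent JAX releases).
jax.config.update("jax_compilation_cache_dir", "./cache")
# Cache every compiled program, regardless of its size or compile time.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


def toy_step(obs):
    # Hypothetical stand-in for a wrapper around env.step: maps obs -> next obs.
    return jnp.tanh(obs) * jnp.sum(obs ** 2)


# Reverse-mode Jacobian of the step wrapper w.r.t. obs, compiled with jit.
jac_step = jax.jit(jax.jacrev(toy_step))

if __name__ == "__main__":
    obs = jnp.arange(4.0)
    print("before cached call")
    jac = jac_step(obs)
    jac.block_until_ready()  # force execution so a crash would surface here
    print("after cached call", jac.shape)
```

If the second run stops between the two prints, that matches the behavior described in this report.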
Main takeaways:
- Persistent caching of jacrev (the script default) applied to my step-function wrapper fails.
- Persistent caching of the Hessian of my step-function wrapper also fails.
- Persistent caching of jacfwd applied to my step-function wrapper works without issue.
- None of the above cases has problems when I don't jit, or when I jit without persistent caching.
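Because `jax.hessian(f)` is documented as `jacfwd(jacrev(f))`, the Hessian failure may share a cause with the `jacrev` failure rather than being a separate problem. One possible workaround, untested against this bug, is to build both derivatives from forward mode only, since that is the variant that survives the persistent cache here. The sketch below reuses a hypothetical `toy_step` placeholder and checks that the forward-only versions agree numerically with the reverse-mode ones. For a roughly square Jacobian such as obs → next obs, forward and reverse mode usually have comparable cost.

```python
import jax
import jax.numpy as jnp


def toy_step(obs):
    # Hypothetical stand-in for the step-function wrapper (obs -> next obs).
    return jnp.tanh(obs) * jnp.sum(obs ** 2)


# Jacobian and Hessian built purely from forward-mode differentiation.
jac_fwd = jax.jit(jax.jacfwd(toy_step))
hess_fwd = jax.jit(jax.jacfwd(jax.jacfwd(toy_step)))

# Reference versions that involve reverse mode (jax.hessian is jacfwd(jacrev(f))).
jac_rev = jax.jit(jax.jacrev(toy_step))
hess_ref = jax.jit(jax.hessian(toy_step))

if __name__ == "__main__":
    obs = jnp.linspace(-1.0, 1.0, 4)
    print(jnp.allclose(jac_fwd(obs), jac_rev(obs), atol=1e-5))
    print(jnp.allclose(hess_fwd(obs), hess_ref(obs), atol=1e-5))
    print(hess_fwd(obs).shape)  # (4, 4, 4): output dim x obs dim x obs dim
```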
Thanks for the help!
Attached: minimal.txt