Persistent Caching of Jitted Functions on GPU for Brax Envs and Autodiff #394

Open
bebark opened this issue Sep 12, 2023 · 0 comments

bebark commented Sep 12, 2023

Hi,

I am trying to use XLA's persistent compilation cache on GPU to speed up my Brax code. Based on related issues in the JAX tracker, this appears to be supported, and I have confirmed that it works for most functions on my side.
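
For context, here is a minimal sketch of how the persistent cache is typically switched on. It assumes a recent JAX release that exposes the jax_compilation_cache_dir and jax_persistent_cache_min_compile_time_secs config options; the mechanism has changed between JAX versions. The ./cache directory just mirrors the one the script creates, and f is a hypothetical toy function:

```python
import jax
import jax.numpy as jnp

# Point JAX's persistent compilation cache at a local directory
# (assumes a JAX version that supports this config option).
jax.config.update("jax_compilation_cache_dir", "./cache")
# Cache every compiled program, including ones that compile quickly;
# by default only programs above a compile-time threshold are written.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
    # Hypothetical toy function; its compiled executable goes to ./cache.
    return jnp.sin(x) ** 2

y = f(jnp.arange(3.0))
```

On a second run of the same script, the compiled executable for f should be loaded from ./cache instead of being recompiled.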

Unfortunately, for my use case (computing the Jacobian/Hessian of env.step with respect to the observation), my code exits prematurely, without any error, when I call the persistently cached Jacobian/Hessian on subsequent runs. This happens regardless of the env and backend. I have included a minimal example that reproduces the problem below. To reproduce the issue:

  1. Run minimal.py: the entire program executes, and ./cache is created and populated.
  2. Run minimal.py again: the second print (line 63) does not execute, and the program quits prematurely.
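
The general shape of the failing calls can be sketched roughly as follows. This is not minimal.py. Here, step_wrapper is a hypothetical stand-in for the wrapper around env.step, and no Brax code is involved. The sketch assumes that the persistent cache has been enabled before these functions are first called. According to the report, the failure appears only when the script is run a second time and these jitted functions are loaded from the cache.

```python
import jax
import jax.numpy as jnp

def step_wrapper(obs):
    # Hypothetical stand-in for a wrapper around env.step: maps an
    # observation vector to the next observation. A real Brax wrapper would
    # rebuild the env state from obs and call env.step(state, action).
    return jnp.tanh(obs) + 0.1 * obs ** 2

# Jacobian of the step w.r.t. the observation, reverse mode (reported to fail
# when loaded from the persistent cache).
jac_rev = jax.jit(jax.jacrev(step_wrapper))
# Hessian of the step w.r.t. the observation (also reported to fail).
hess = jax.jit(jax.hessian(step_wrapper))

obs = jnp.array([0.0, 0.5, -1.0])
J = jac_rev(obs)  # shape (3, 3)
H = hess(obs)     # shape (3, 3, 3)
```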

Main takeaways:

  1. Persistently caching jacrev (the script default) applied to my step function wrapper fails.
  2. Persistently caching the Hessian of my step function wrapper fails as well.
  3. Persistently caching jacfwd applied to my step function wrapper works without issue.
  4. None of the above cases has any issue when I don't jit, or when I jit without persistent caching.
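
One observation that may be relevant: jax.hessian(f) is implemented as jacfwd(jacrev(f)), so the Hessian failure could share a cause with the jacrev failure. The following sketch shows a possible interim workaround. It uses forward mode only, building the Hessian as jacfwd(jacfwd(f)), and checks the results against the un-jitted reverse-mode versions. Whether this variant avoids the crash with the persistent cache is untested here. As before, step_wrapper is a hypothetical stand-in, not Brax code.

```python
import jax
import jax.numpy as jnp

def step_wrapper(obs):
    # Hypothetical stand-in for the env.step wrapper (not Brax code).
    return jnp.tanh(obs) + 0.1 * obs ** 2

# Forward-mode Jacobian: the variant reported to work with the persistent cache.
jac_fwd = jax.jit(jax.jacfwd(step_wrapper))
# jax.hessian(f) is jacfwd(jacrev(f)); forward-over-forward avoids reverse mode.
hess_fwd = jax.jit(jax.jacfwd(jax.jacfwd(step_wrapper)))

obs = jnp.array([0.0, 0.5, -1.0])
# Sanity check against the un-jitted reverse-mode results.
J_ref = jax.jacrev(step_wrapper)(obs)
H_ref = jax.hessian(step_wrapper)(obs)
jac_ok = jnp.allclose(jac_fwd(obs), J_ref, atol=1e-5)
hess_ok = jnp.allclose(hess_fwd(obs), H_ref, atol=1e-5)
```

Forward-over-forward is usually more expensive than forward-over-reverse for large inputs. For a low-dimensional observation, however, the cost difference should be small.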

Thanks for the help!

minimal.txt
