
Support for Flax NNX API #564

Open
scott-yj-yang opened this issue Dec 4, 2024 · 1 comment

@scott-yj-yang

Hi Brax development team,

We are currently using the brax.training.ppo training scripts in our project and have been very happy with Brax's performance.

We noticed that the brax.training module uses the older flax.linen API for neural network definitions. Flax recently introduced the flax.nnx API, which offers a more Pythonic and streamlined approach to model development.

Are there any plans to transition the brax repository to use the nnx API in the future? This would help improve model flexibility and maintainability for projects that depend on brax.training.

Thank you for your time and consideration.

Scott

@btaba
Collaborator

btaba commented Dec 4, 2024

Hi @scott-yj-yang, there aren't any immediate plans to switch to nnx, as flax.linen isn't being deprecated, but we'd be happy to review clean, minimal PRs. nnx does generally look like a cleaner API.

@erikfrey any thoughts?
