What's Changed
- Expose `normalize_probabilities` as a suitable label normalization for `SoftmaxLoss`.
- Remove use of the `initial` argument to `jax.nn.softmax` and `jax.nn.log_softmax`.
- Drop Python 3.8 checks and add Python 3.11 checks.
- Change lambda weights to reduce boilerplate and add new options.
- Fix pytype and clean up types across codebase.
- Minor typo fixes in documentation.
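These notes don't show how `normalize_probabilities` is wired into `SoftmaxLoss`. As a rough illustration of why label normalization suits a listwise softmax loss, here is a framework-free Python sketch (plain `math`, not JAX). It assumes non-negative relevance labels and uses hypothetical helpers `normalize_labels`, `masked_log_softmax` and `softmax_loss`, which are not Rax's API. The masked log-softmax takes its maximum over the valid entries only, so masked entries need no fill value.

```python
import math
from typing import List, Optional, Sequence


def normalize_labels(labels: Sequence[float],
                     where: Optional[Sequence[bool]] = None) -> List[float]:
    """Hypothetical helper (not Rax's API): rescale valid labels to sum to 1."""
    if where is None:
        where = [True] * len(labels)
    n_valid = sum(1 for w in where if w)
    total = sum(y for y, w in zip(labels, where) if w)
    if n_valid == 0:
        return [0.0] * len(labels)
    if total == 0:
        # Assumption for this sketch: all-zero labels become a uniform target.
        return [1.0 / n_valid if w else 0.0 for w in where]
    return [y / total if w else 0.0 for y, w in zip(labels, where)]


def masked_log_softmax(scores: Sequence[float],
                       where: Sequence[bool]) -> List[float]:
    """Log-softmax over valid entries; masked entries get -inf."""
    valid = [s for s, w in zip(scores, where) if w]
    # Max over valid entries only keeps exp() stable without a fill value.
    m = max(valid)
    log_z = m + math.log(sum(math.exp(s - m) for s in valid))
    return [s - log_z if w else -math.inf for s, w in zip(scores, where)]


def softmax_loss(scores: Sequence[float], labels: Sequence[float],
                 where: Optional[Sequence[bool]] = None) -> float:
    """Hypothetical listwise softmax cross-entropy on normalized labels."""
    if where is None:
        where = [True] * len(scores)
    if not any(where):
        return 0.0
    target = normalize_labels(labels, where)
    log_p = masked_log_softmax(scores, where)
    # Cross-entropy between the normalized target and predicted distribution;
    # masked entries are skipped so 0 * -inf never occurs.
    return -sum(t * lp for t, lp, w in zip(target, log_p, where) if w)


scores = [2.0, 1.0, 0.0]
print(softmax_loss(scores, [1.0, 0.0, 0.0]))
print(softmax_loss(scores, [2.0, 0.0, 0.0]))  # same value: labels are rescaled first
```

Because the target is normalized, scaling every label by a constant leaves the loss unchanged. Without normalization, doubling the labels would double the loss.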
Full Changelog: v0.3.0...v0.4.0