JAX v0.5.0

@hawkinsp released this 17 Jan 18:27

Starting with this release, JAX uses effort-based versioning.
Because this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signal this.

  • Breaking changes

    • Enable jax_threefry_partitionable by default (see
      the update note).

    • This release drops support for Mac x86 wheels. Mac ARM of course remains
      supported. For a recent discussion, see #22936.

      Two key factors motivated this decision:

      • The Mac x86 build (only) has a number of test failures and crashes. We
        would rather ship no release than a broken one.
      • Mac x86 hardware is end-of-life and developers can no longer obtain it
        easily, so it is difficult for us to fix this kind of problem even if
        we wanted to.

      We are open to re-adding support for Mac x86 if the community is willing
      to help support that platform: in particular, we would need the JAX test
      suite to pass cleanly on Mac x86 before we could ship releases again.
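For the jax_threefry_partitionable change listed above, a minimal sketch of
how code that depends on the previous random streams might toggle the flag.
It assumes the flag can still be set through jax.config.update in this
release; the key seed and array shape are arbitrary illustration values.

```python
import jax

# Assumption: the flag remains configurable, so the pre-0.5.0 behavior
# can be requested explicitly while code is being migrated.
jax.config.update("jax_threefry_partitionable", False)
key = jax.random.key(0)
legacy_draw = jax.random.uniform(key, (4,))

# Switch back to the new default.
jax.config.update("jax_threefry_partitionable", True)
default_draw = jax.random.uniform(key, (4,))

# The two settings may produce different values for the same key, so
# hard-coded expected random numbers in tests may need regenerating.
print(legacy_draw)
print(default_draw)
```

Setting the flag once at program start, before any random numbers are drawn,
is likely the least surprising way to use it.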

  • Changes

    • The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
      supported version until June 2025.
    • The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
      supported version until June 2025.
    • jax.numpy.einsum now defaults to optimize='auto' rather than
      optimize='optimal'. This avoids exponentially-scaling trace-time in
      the case of many arguments (#25214).
    • jax.numpy.linalg.solve no longer supports batched 1D arguments
      on the right-hand side. To recover the previous behavior in these cases,
      use solve(a, b[..., None]).squeeze(-1).
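The einsum and solve changes above can be illustrated with a short sketch.
The array shapes and values here are made up for demonstration; only the
optimize argument and the b[..., None] / squeeze(-1) pattern come from the
notes.

```python
import jax.numpy as jnp
from jax.numpy.linalg import solve

# einsum: request the previous contraction-path strategy explicitly.
m1 = jnp.ones((2, 3))
m2 = jnp.ones((3, 4))
m3 = jnp.ones((4, 5))
chained = jnp.einsum("ij,jk,kl->il", m1, m2, m3, optimize="optimal")

# solve: a batch of three 2x2 systems with one right-hand-side vector each.
a = jnp.stack([jnp.eye(2) * (i + 1) for i in range(3)])  # shape (3, 2, 2)
b = jnp.ones((3, 2))                                     # shape (3, 2)

# Add a trailing axis so each vector becomes a column matrix, solve,
# then drop the axis again to get back a batch of vectors.
x = solve(a, b[..., None]).squeeze(-1)                   # shape (3, 2)
```

Making the trailing axis explicit removes the ambiguity between a batch of
vectors and a single matrix of right-hand sides.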
  • New Features

    • jax.numpy.fft.fftn, jax.numpy.fft.rfftn,
      jax.numpy.fft.ifftn, and jax.numpy.fft.irfftn now support
      transforms in more than 3 dimensions, which was previously the limit. See
      #25606 for more details.
    • Support added for user-defined state in the FFI via the new
      jax.ffi.register_ffi_type_id function.
    • The AOT lowering .as_text() method now supports the debug_info option
      to include debugging information, e.g., source location, in the output.
  • Deprecations

    • From jax.interpreters.xla, abstractify and pytype_aval_mappings
      are now deprecated, having been replaced by symbols of the same name
      in jax.core.
    • jax.scipy.special.lpmn and jax.scipy.special.lpmn_values
      are deprecated, following their deprecation in SciPy v1.15.0. There are
      no plans to replace these deprecated functions with new APIs.
    • The jax.extend.ffi submodule was moved to jax.ffi, and the
      previous import path is deprecated.
  • Deletions

    • The jax_enable_memories flag has been deleted, and the behavior it
      controlled is now on by default.
    • From jax.lib.xla_client, the previously-deprecated Device and
      XlaRuntimeError symbols have been removed; instead use jax.Device
      and jax.errors.JaxRuntimeError respectively.
    • The jax.experimental.array_api module has been removed after being
      deprecated in JAX v0.4.32. Since that release, jax.numpy supports
      the array API directly.