Releases: jax-ml/jax
JAX v0.4.35
- Breaking Changes
  - `jax.numpy.isscalar` now returns `True` for any array-like object with zero dimensions. Previously it only returned `True` for zero-dimensional array-like objects with a weak dtype.
  - `jax.experimental.host_callback` has been deprecated since March 2024 (JAX version 0.4.26) and has now been removed. See #20385 for a discussion of alternatives; a migration sketch follows this list.
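As one of the alternatives discussed in #20385, a minimal sketch of replacing a `host_callback`-style tap with `jax.experimental.io_callback` might look like this (the `log_on_host` function is a hypothetical example):

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def log_on_host(x):
    # Runs on the host, outside the compiled computation.
    print("host saw:", x)
    return x

@jax.jit
def f(x):
    # io_callback executes log_on_host on the host; the ShapeDtypeStruct
    # describes the result the callback is expected to return.
    x = io_callback(log_on_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return x * 2

f(jnp.arange(4.0))
```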
- Changes:
  - `jax.lax.FftType` was introduced as a public name for the enum of FFT operations. The semi-public API `jax.lib.xla_client.FftType` has been deprecated. See the sketch after this list.
  - TPU: JAX now installs TPU support from the `libtpu` package rather than `libtpu-nightly`. For the next few releases JAX will pin an empty version of `libtpu-nightly` as well as `libtpu` to ease the transition; that dependency will be removed in Q1 2025.
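A minimal sketch using the new public name with `jax.lax.fft` (the 1-D complex input is an illustrative assumption):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones(8, dtype=jnp.complex64)

# Use the public jax.lax.FftType instead of the deprecated
# jax.lib.xla_client.FftType.
y = lax.fft(x, fft_type=lax.FftType.FFT, fft_lengths=(8,))
```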
- Deprecations:
  - The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated. No JAX APIs consume this type, so there is no replacement.
  - The default behavior of `jax.pure_callback` and `jax.extend.ffi.ffi_call` under `vmap` has been deprecated, as has the `vectorized` parameter to those functions. The `vmap_method` parameter should be used instead for better-defined behavior; see the sketch after this list, and the discussion in #23881 for more details.
  - The semi-public API `jax.lib.xla_client.register_custom_call_target` has been deprecated. Use the JAX FFI instead.
  - The semi-public APIs `jax.lib.xla_client.dtype_to_etype`, `jax.lib.xla_client.ops`, `jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`, `jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and `jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO instead.
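A minimal sketch of passing an explicit `vmap_method` to `jax.pure_callback` (the `host_sin` callback and the choice of `"expand_dims"` are illustrative assumptions; see #23881 for the available methods):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Host-side computation invoked through the callback.
    return np.sin(x)

def f(x):
    return jax.pure_callback(
        host_sin,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
        vmap_method="expand_dims",  # replaces the deprecated vectorized=True
    )

jax.vmap(f)(jnp.arange(4.0))
```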
JAX v0.4.34
- New Functionality
  - This release includes wheels for Python 3.13. Free-threading mode is not yet supported.
  - `jax.errors.JaxRuntimeError` has been added as a public alias for the formerly private `XlaRuntimeError` type. A usage sketch follows this list.
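A minimal sketch of catching the new public alias (the `safe_call` wrapper is a hypothetical example):

```python
import jax

def safe_call(fn, *args):
    try:
        return fn(*args)
    except jax.errors.JaxRuntimeError as err:
        # Previously this required the private XlaRuntimeError type.
        print(f"runtime failure: {err}")
        raise
```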
- Breaking changes
  - The `jax_pmap_no_rank_reduction` flag is set to `True` by default (see the sketch after this list).
    - `array[0]` on a pmap result now introduces a reshape (use `array[0:1]` instead).
    - The per-shard shape (accessible via `jax_array.addressable_shards` or `jax_array.addressable_data(0)`) now has a leading `(1, ...)`. Update code that directly accesses shards accordingly. The rank of the per-shard shape now matches that of the global shape, which is the same behavior as `jit`. This avoids costly reshapes when passing results from `pmap` into `jit`.
  - `jax.experimental.host_callback` has been deprecated since March 2024 (JAX version 0.4.26). The default value of the `--jax_host_callback_legacy` configuration option is now set to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API. If this breaks your code, for a very limited time, you can set `--jax_host_callback_legacy` to `False`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs. See #20385 for a discussion.
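A minimal sketch of the pmap indexing change (assuming at least one local device):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
out = jax.pmap(lambda x: x * 2)(jnp.arange(n))

# New default (jax_pmap_no_rank_reduction=True): integer indexing of a
# pmap result introduces a reshape; slicing keeps the leading axis.
first = out[0:1]                 # shape (1,), no reshape

# Per-shard data now carries a leading (1, ...), matching the global rank.
shard = out.addressable_data(0)  # shape (1,) for this example
```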
- Deprecations
  - In `jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments with `ndim != 1` are now deprecated, and in the future will result in an error (see the sketch after this list).
  - Internal pretty-printing tools `jax.core.pp_*` have been removed, after being deprecated in JAX v0.4.30.
  - `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
  - `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use `jax.errors.JaxRuntimeError` instead.
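A minimal sketch of the still-supported usage (the 2-D case is shown only as the pattern that is now deprecated):

```python
import jax.numpy as jnp

# Supported: 1-D array-like input.
print(jnp.trim_zeros(jnp.array([0, 0, 1, 2, 0])))  # [1 2]

# Deprecated (will become an error): input with ndim != 1.
# jnp.trim_zeros(jnp.zeros((2, 2)))
```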
- Deletion:
  - `jax.xla_computation` is deleted. It has been 3 months since its deprecation in the 0.4.30 JAX release. Please use the AOT APIs to get the same functionality as `jax.xla_computation` (a combined sketch follows this list):
    - `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
    - You can also use the `.out_info` property of `jax.stages.Lowered` to get the output information (like tree structure, shape, and dtype).
    - For cross-backend lowering, you can replace `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
  - `jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. The argument was only used by `xmap`, which was removed in 0.4.31.
  - `jax.tree.map(f, None, non-None)`, which previously emitted a `DeprecationWarning`, now raises an error. `None` is only a tree-prefix of itself. To preserve the current behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
  - `jax.sharding.XLACompatibleSharding` has been removed. Please use `jax.sharding.Sharding` instead.
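Combining the replacements above into one sketch (the function `fn` is an arbitrary example):

```python
import jax
import jax.numpy as jnp

def fn(x):
    return jnp.sin(x) ** 2

x = jnp.arange(4.0)

# Replaces jax.xla_computation(fn)(x):
lowered = jax.jit(fn).lower(x)
hlo = lowered.compiler_ir('hlo')

# Output tree structure, shapes, and dtypes, without compiling:
print(lowered.out_info)

# Replaces jax.xla_computation(fn, backend='tpu')(x):
tpu_hlo = (
    jax.jit(fn)
    .trace(x)
    .lower(lowering_platforms=('tpu',))
    .compiler_ir('hlo')
)
```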
- Bug fixes
  - Fixed a bug where `jax.numpy.cumsum` would produce incorrect outputs if a non-boolean input was provided and `dtype=bool` was specified (see the sketch after this list).
  - Edited the implementation of `jax.numpy.ldexp` to get a correct gradient.
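A minimal sketch of the fixed `cumsum` case (the expected output follows NumPy, where boolean accumulation saturates like a running logical OR):

```python
import jax.numpy as jnp

x = jnp.array([2, 0, 3])  # non-boolean input
# Previously this combination could produce incorrect outputs.
print(jnp.cumsum(x, dtype=bool))  # expected: [ True  True  True]
```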
JAX release v0.4.33
This is a patch release on top of jax 0.4.32 that fixes two bugs found in that release.

A TPU-only data corruption bug was found in the version of libtpu pinned by JAX 0.4.32, which manifested only if multiple TPU slices were present in the same job, for example, if training on multiple v5e slices. This release fixes that issue by pinning a fixed version of `libtpu-nightly`.

This release also fixes an inaccurate result for F64 tanh on CPU (#23590).
Jaxlib release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU that occurs when multiple TPU slices are present in the same job.
JAX release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU that occurs when multiple TPU slices are present in the same job.
Jaxlib release v0.4.31
JAX release v0.4.31
Jaxlib release v0.4.30
JAX release v0.4.30
Jaxlib release v0.4.29
- Bug fixes
  - Fixed a bug where XLA sharded some concatenation operations incorrectly, which manifested as an incorrect output for cumulative reductions (#21403).
  - Fixed a bug where XLA:CPU miscompiled certain matmul fusions (openxla/xla#13301).
  - Fixed a compiler crash on GPU (#21396).
- Deprecations
  - `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will raise an error in a future version of JAX. `None` is only a tree-prefix of itself. To preserve the current behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. A sketch follows this list.
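A minimal sketch of the workaround (the dictionaries `a` and `b` are illustrative):

```python
import jax

a = {"b": 2.0, "w": None}
b = {"b": 3.0, "w": 1.0}

# Warns here (and raises in later releases): None in `a` is not a
# tree-prefix of the corresponding leaf in `b`.
# jax.tree.map(lambda x, y: x + y, a, b)

# Treat None as a leaf to preserve the old behavior:
out = jax.tree.map(
    lambda x, y: None if x is None else x + y,
    a, b,
    is_leaf=lambda x: x is None,
)
print(out)  # {'b': 5.0, 'w': None}
```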