🚀 google/jax - Release Notes

JAX v0.5.3 (2025-03-19)

* New Features

  * Added an `allow_negative_indices` option to `jax.lax.dynamic_slice`,
    `jax.lax.dynamic_update_slice` and related functions. The default is
    `True`, matching the current behavior. If set to `False`, JAX does not need
    to emit code clamping negative indices, which reduces code size.
  * Added a `replace` option to `jax.random.categorical` to enable sampling
    without replacement.

JAX v0.5.2 (2025-03-05)

Patch release of 0.5.1

* Bug fixes
  * Fixed TPU metric logging and `tpu-info`, which were broken in 0.5.1.

JAX v0.5.1 (2025-02-24)

* New Features
  * Added an experimental `jax.experimental.custom_dce.custom_dce`
    decorator to support customizing the behavior of opaque functions under
    JAX-level dead code elimination (DCE). See `#25956` for more
    details.
  * Added low-level reduction APIs in `jax.lax`: `jax.lax.reduce_sum`,
    `jax.lax.reduce_prod`, `jax.lax.reduce_max`, `jax.lax.reduce_min`,
    `jax.lax.reduce_and`, `jax.lax.reduce_or`, and `jax.lax.reduce_xor`.
  * `jax.lax.linalg.qr` and `jax.scipy.linalg.qr` now support
    column-pivoting on CPU and GPU. See #20282 and
    #25955 for more details.

* Changes
  * `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
    env vars. Previously, they could only be specified via `jax.config` or flags.
  * `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning
    multi-process CPU communication works out-of-the-box.
  * The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package.
    This package may safely be removed if it is present on your machine; JAX now
    uses `libtpu` instead.

* Deprecations
  * The internal function `linear_util.wrap_init` and the constructor
    `core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For
    a limited time, a `DeprecationWarning` is printed if
    `jax.extend.linear_util.wrap_init` is used without debugging info.
    A downstream effect of this is that several other internal functions need
    debug info. This change does not affect public APIs.
    See https://github.com/jax-ml/jax/issues/26480 for more detail.

* Bug fixes
  * TPU runtime startup and shutdown time should be significantly improved on
    TPU v5e and newer (from around 17s to around 8s). If not already set, you may
    need to enable transparent hugepages in your VM image
    (`sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'`).
    We hope to improve this further in future releases.
  * The persistent compilation cache no longer writes an access time file if
    `JAX_COMPILATION_CACHE_MAX_SIZE` is unset or set to -1, i.e. if the LRU
    eviction policy isn't enabled. This should improve performance when using
    the cache with large-scale network storage.

JAX v0.5.0 (2025-01-17)

As of this release, JAX now uses [effort-based versioning](https://jax.readthedocs.io/en/latest/jep/25516-effver.html).
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.

* Breaking changes
  * Enable `jax_threefry_partitionable` by default (see
    [the update note](https://github.com/jax-ml/jax/discussions/18480)).

  * This release drops support for Mac x86 wheels. Mac ARM of course remains
    supported. For a recent discussion, see https://github.com/jax-ml/jax/discussions/22936.

    Two key factors motivated this decision:
    * The Mac x86 build (only) has a number of test failures and crashes. We
      would prefer to ship no release than a broken release.
    * Mac x86 hardware is end-of-life and cannot be easily obtained for
      developers at this point. So it is difficult for us to fix this kind of
      problem even if we wanted to.

    We are open to re-adding support for Mac x86 if the community is willing
    to help support that platform: in particular, we would need the JAX test
    suite to pass cleanly on Mac x86 before we could ship releases again.

* Changes:
  * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
    supported version until June 2025.
  * The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
    supported version until June 2025.
  * `jax.numpy.einsum` now defaults to `optimize='auto'` rather than
    `optimize='optimal'`. This avoids exponentially-scaling trace-time in
    the case of many arguments (`#25214`).
  * `jax.numpy.linalg.solve` no longer supports batched 1D arguments
    on the right hand side. To recover the previous behavior in these cases,
    use `solve(a, b[..., None]).squeeze(-1)`.
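
To illustrate the `jax.numpy.linalg.solve` workaround above, here is a minimal sketch with made-up inputs: a batch of two 3x3 systems and a batch of two right-hand-side vectors. Adding a trailing axis turns each vector into a column matrix, and `squeeze(-1)` removes that axis again afterwards:

```python
import jax.numpy as jnp

# Batch of two 3x3 matrices, shape (2, 3, 3).
a = jnp.stack([2.0 * jnp.eye(3), 4.0 * jnp.eye(3)])
# Batch of two right-hand-side vectors, shape (2, 3).
b = jnp.ones((2, 3))

# Treat each vector as a (3, 1) column, solve, then drop the extra axis.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)  # shape (2, 3)
```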

* New Features
  * `jax.numpy.fft.fftn`, `jax.numpy.fft.rfftn`,
    `jax.numpy.fft.ifftn`, and `jax.numpy.fft.irfftn` now support
    transforms in more than 3 dimensions, which was previously the limit. See
    `#25606` for more details.
  * Support added for user defined state in the FFI via the new
    `jax.ffi.register_ffi_type_id` function.
  * The AOT lowering `.as_text()` method now supports the `debug_info` option
    to include debugging information, e.g., source location, in the output.

* Deprecations
  * From `jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
    are now deprecated, having been replaced by symbols of the same name
    in `jax.core`.
  * `jax.scipy.special.lpmn` and `jax.scipy.special.lpmn_values`
    are deprecated, following their deprecation in SciPy v1.15.0. There are
    no plans to replace these deprecated functions with new APIs.
  * The `jax.extend.ffi` submodule was moved to `jax.ffi`, and the
    previous import path is deprecated.

* Deletions
  * `jax_enable_memories` flag has been deleted and the behavior of that flag
    is on by default.
  * From `jax.lib.xla_client`, the previously-deprecated `Device` and
    `XlaRuntimeError` symbols have been removed; instead use `jax.Device`
    and `jax.errors.JaxRuntimeError` respectively.
  * The `jax.experimental.array_api` module has been removed after being
    deprecated in JAX v0.4.32. Since that release, `jax.numpy` supports
    the array API directly.

JAX v0.4.38 (2024-12-17)

* Changes:
  * `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
    as shortcuts of the corresponding `tree_util` functions.

* Deprecations
  * A number of APIs in the internal `jax.core` namespace have been deprecated.
    Most were no-ops, were little-used, or can be replaced by APIs of the same
    name in `jax.extend.core`; see the documentation for `jax.extend`
    for information on the compatibility guarantees of these semi-public extensions.
  * Several previously-deprecated APIs have been removed, including:
    * from `jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
      `non_negative_dim`.
    * from `jax.lib.xla_bridge`: `xla_client` and `default_backend`.
    * from `jax.lib.xla_client`: `_xla` and `bfloat16`.
    * from `jax.numpy`: `round_`.

* New Features
  * `jax.export.export` can be used for device-polymorphic export with
    shardings constructed with `jax.sharding.AbstractMesh`.
    See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
  * Added `jax.lax.split`. This is a primitive version of
    `jax.numpy.split`, added because it yields a more compact
    transpose during automatic differentiation.

JAX v0.4.37 (2024-12-10)

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

* Bug fixes
  * Fixed a bug where `jit` would error if an argument was named `f` (#25329).
  * Fixed a bug that caused an `index out of range` error in
    `jax.lax.while_loop` if the user registered a pytree node class with
    different aux data for `flatten` and `flatten_with_path`.
  * Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.

JAX v0.4.36 (2024-12-05)

* Breaking Changes
  * This release lands "stackless", an internal change to JAX's tracing
    machinery. We made trace dispatch purely a function of context rather than a
    function of both context and data. This let us delete a lot of machinery for
    managing data-dependent tracing: levels, sublevels, `post_process_call`,
    `new_base_main`, `custom_bind`, and so on. The change should only affect
    users that use JAX internals.

    If you do use JAX internals then you may need to
    update your code (see
    https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
    for clues about how to do this). There might also be version skew
    issues with JAX libraries that do this. If you find this change breaks your
    non-JAX-internals-using code then try the
    `config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
    you need help updating your code then please file a bug.
  * `jax.experimental.jax2tf.convert` with `native_serialization=False`
    or with `enable_xla=False` has been deprecated since July 2024, with
    JAX version 0.4.31. Support for these use cases has now been removed.
    `jax2tf` with native serialization is still supported.
  * In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed
    after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`,
    `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`.
  * The deprecated module `jax.experimental.export` has been removed. It was replaced
    by `jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
    for information on migrating to the new API.
  * The `initial` argument to `jax.nn.softmax` and `jax.nn.log_softmax`
    has been removed, after being deprecated in v0.4.27.
  * Calling `np.asarray` on typed PRNG keys (i.e. keys produced by `jax.random.key`)
    now raises an error. Previously, this returned a scalar object array.
  * The following deprecated methods and functions in `jax.export` have
    been removed:
      * `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect
        already.
      * `jax.export.Exported.lowering_platforms`: use `platforms`.
      * `jax.export.Exported.mlir_module_serialization_version`:
        use `calling_convention_version`.
      * `jax.export.Exported.uses_shape_polymorphism`:
         use `uses_global_constants`.
      * the `lowering_platforms` kwarg for `jax.export.export`: use
        `platforms` instead.
  * The kwargs `symbolic_scope` and `symbolic_constraints` from
    `jax.export.symbolic_args_specs` have been removed. They were
    deprecated in June 2024. Use `scope` and `constraints` instead.
  * Hashing of tracers, which has been deprecated since version 0.4.30, now
    results in a `TypeError`.
  * Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
    replaces previous build.py usage. Run `python build/build.py --help` for
    more details. Brief overview of the new subcommand options:
    * `build`: Builds JAX wheel packages. For example, `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`
    * `requirements_update`: Updates requirements_lock.txt files.
  * `jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
    inputs. To recover the previous behavior, you can call `jax.numpy.ravel`
    on the function inputs.
  * `jax.scipy.special.gamma` and `jax.scipy.special.gammasgn` now
    return NaN for negative integer inputs, to match the behavior of SciPy from
    https://github.com/scipy/scipy/pull/21827.
  * `jax.clear_backends` was removed after being deprecated in v0.4.26.
  * We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
    calls for which we guarantee export stability. This is because this custom call
    relies on Triton IR, which is not guaranteed to be stable. If you need
    to export code that uses this custom call, you can use the `disabled_checks`
    parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls).

* New Features
  * `jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
    passing compilation options to XLA. For the moment it's undocumented and
    may be in flux.
  * `jax.tree_util.register_dataclass` now allows metadata fields to be
    declared inline via `dataclasses.field`. See the function documentation
    for examples.
  * Added `jax.numpy.put_along_axis`.
  * `jax.lax.linalg.eig` and the related `jax.numpy` functions
    (`jax.numpy.linalg.eig` and `jax.numpy.linalg.eigvals`) are now
    supported on GPU. See #24663 for more details.
  * Added two new configuration flags, `jax_exec_time_optimization_effort` and `jax_memory_fitting_effort`, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively.  Valid values are between -1.0 and 1.0, default is 0.0.

* Bug fixes
  * Fixed a bug where the GPU implementations of LU and QR decomposition would
    result in an indexing overflow for batch sizes close to int32 max. See
    #24843 for more details.

* Deprecations
  * `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
    use `jax.Array` instead.
  * `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
    instead.

JAX v0.4.35 (2024-10-22)

* Breaking Changes
  * `jax.numpy.isscalar` now returns True for any array-like object with
    zero dimensions. Previously it only returned True for zero-dimensional
    array-like objects with a weak dtype.
  * `jax.experimental.host_callback` has been deprecated since March 2024, with
    JAX version 0.4.26. It has now been removed.
    See `#20385` for a discussion of alternatives.

* Changes:
  * `jax.lax.FftType` was introduced as a public name for the enum of FFT
    operations. The semi-public API `jax.lib.xla_client.FftType` has been
    deprecated.
  * TPU: JAX now installs TPU support from the `libtpu` package rather than
    `libtpu-nightly`. For the next few releases JAX will pin an empty version of
    `libtpu-nightly` as well as `libtpu` to ease the transition; that dependency
    will be removed in Q1 2025.

* Deprecations:
  * The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated.
    No JAX APIs consume this type, so there is no replacement.
  * The default behavior of `jax.pure_callback` and
    `jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has
    the `vectorized` parameter to those functions. The `vmap_method` parameter
    should be used instead for better defined behavior. See the discussion in
    `#23881` for more details.
  * The semi-public API `jax.lib.xla_client.register_custom_call_target` has
    been deprecated. Use the JAX FFI instead.
  * The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
    `jax.lib.xla_client.ops`, 
    `jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`,
    `jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and
    `jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO
    instead.

JAX v0.4.34 (2024-10-04)

* New Functionality
  * This release includes wheels for Python 3.13. Free-threading mode is not yet
    supported.
  * `jax.errors.JaxRuntimeError` has been added as a public alias for the
    formerly private `XlaRuntimeError` type.

* Breaking changes
  * `jax_pmap_no_rank_reduction` flag is set to `True` by default.
    * `array[0]` on a pmap result now introduces a reshape (use `array[0:1]`
      instead).
    * The per-shard shape (accessible via `jax_array.addressable_shards` or
      `jax_array.addressable_data(0)`) now has a leading `(1, ...)`. Update code
      that directly accesses shards accordingly. The rank of the per-shard-shape
      now matches that of the global shape which is the same behavior as jit.
      This avoids costly reshapes when passing results from pmap into jit.
  * `jax.experimental.host_callback` has been deprecated since March 2024, with
    JAX version 0.4.26. The default value of the `--jax_host_callback_legacy`
    configuration option is now `False`, which means that if your code uses
    `jax.experimental.host_callback` APIs, those API calls are implemented in
    terms of the new `jax.experimental.io_callback` API. If this breaks your
    code, for a very limited time, you can set `--jax_host_callback_legacy` to
    `True`. That configuration option will be removed soon, so you should
    instead transition to the new JAX callback APIs. See #20385 for a discussion.

* Deprecations
  * In `jax.numpy.trim_zeros`, non-arraylike arguments or arraylike
    arguments with `ndim != 1` are now deprecated, and in the future will result
    in an error.
  * Internal pretty-printing tools `jax.core.pp_*` have been removed, after
    being deprecated in JAX v0.4.30.
  * `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
  * `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
    `jax.errors.JaxRuntimeError` instead.

* Deletion:
  * `jax.xla_computation` has been deleted, 3 months after its deprecation
    in the JAX 0.4.30 release.
    Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
    * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
      `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
    * You can also use the `.out_info` property of `jax.stages.Lowered` to get the
      output information (like tree structure, shape and dtype).
    * For cross-backend lowering, you can replace
      `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
      `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
  * `jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
    The argument was only used by `xmap` which was removed in 0.4.31.
  * `jax.tree.map(f, None, non-None)`, which previously emitted a
    `DeprecationWarning`, now raises an error. `None`
    is only a tree-prefix of itself. To preserve the current behavior, you can
    ask `jax.tree.map` to treat `None` as a leaf value by writing:
    `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
  * `jax.sharding.XLACompatibleSharding` has been removed. Please use
    `jax.sharding.Sharding`.
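
As a rough sketch of the AOT path that replaces `jax.xla_computation` (the function `fn` and its input are made-up examples), lowering returns a `jax.stages.Lowered` object that can be inspected without running the computation. This sketch prints the lowered module with `.as_text()`; the `.compiler_ir('hlo')` call named in the notes above is made on the same object:

```python
import jax
import jax.numpy as jnp


def fn(x):  # hypothetical function to lower
    return jnp.sin(x) * 2.0


x = jnp.arange(3.0)

# Trace and lower ahead of time; nothing is executed on the device here.
lowered = jax.jit(fn).lower(x)

# Textual form of the lowered module.
ir_text = lowered.as_text()

# Output metadata (tree structure, shape, dtype) for the single output.
out_info = lowered.out_info
```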

* Bug fixes
  * Fixed a bug where `jax.numpy.cumsum` would produce incorrect outputs
    if a non-boolean input was provided and `dtype=bool` was specified.
  * Fixed the implementation of `jax.numpy.ldexp` so that it produces the
    correct gradient.

JAX release v0.4.33 (2024-09-16)

This is a patch release on top of jax 0.4.32 that fixes two bugs found in that
release.

A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.

This release fixes that issue by pinning a fixed version of `libtpu-nightly`.

This release also fixes an inaccurate result for F64 tanh on CPU (#23590).

Jaxlib release v0.4.32 (2024-09-11)

WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job

JAX release v0.4.32 (2024-09-11)

WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job

Jaxlib release v0.4.31 (2024-07-30)

No notes available

JAX release v0.4.31 (2024-07-30)

No notes available

Jaxlib release v0.4.30 (2024-06-18)

No notes available

Jax release v0.4.30 (2024-06-18)

No notes available

Jaxlib release v0.4.29 (2024-06-10)

* Bug fixes
  * Fixed a bug where XLA sharded some concatenation operations incorrectly,
    which manifested as an incorrect output for cumulative reductions (#21403).
  * Fixed a bug where XLA:CPU miscompiled certain matmul fusions
    (https://github.com/openxla/xla/pull/13301).
  * Fixed a compiler crash on GPU (https://github.com/google/jax/issues/21396).

* Deprecations
  * `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
    raise an error in a future version of jax. `None` is only a tree-prefix of
    itself. To preserve the current behavior, you can ask `jax.tree.map` to
    treat `None` as a leaf value by writing:
    `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
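
The `is_leaf` pattern above can be sketched as follows, with made-up trees `a` and `b` and a hypothetical function `f`. Entries that are `None` in the first tree stay `None` in the result, and all other entries are combined with `f`:

```python
import jax


def f(x, y):  # hypothetical combining function
    return x + y


a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": 3.0}

# Treat None as a leaf so that it can be paired with a non-None value in b.
result = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)
```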

JAX v0.4.29 (2024-06-10)

* Changes
  * We anticipate that this will be the last release of JAX and jaxlib
    supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
    plugin jaxlib (e.g. `pip install jax[cuda12]`).
  * JAX now requires ml_dtypes version 0.4.0 or newer.
  * Removed backwards-compatibility support for old usage of the
    `jax.experimental.export` API. It is no longer possible to use
    `from jax.experimental.export import export`; instead, you should use
    `from jax.experimental import export`.
    The removed functionality has been deprecated since 0.4.24.

* Deprecations
  * `jax.sharding.XLACompatibleSharding` is deprecated. Please use
    `jax.sharding.Sharding`.
  * `jax.experimental.Exported.in_shardings` has been renamed as
    `jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`.
    The old names will be removed after 3 months.
  * Removed a number of previously-deprecated APIs:
    * from `jax.core`: `non_negative_dim`, `DimSize`, `Shape`
    * from `jax.lax`: `tie_in`
    * from `jax.nn`: `normalize`
    * from `jax.interpreters.xla`: `backend_specific_translations`,
      `translations`, `register_translation`, `xla_destructure`,
      `TranslationRule`, `TranslationContext`, `XlaOp`.
  * The `tol` argument of `jax.numpy.linalg.matrix_rank` is being
    deprecated and will soon be removed. Use `rtol` instead.
  * The `rcond` argument of `jax.numpy.linalg.pinv` is being
    deprecated and will soon be removed. Use `rtol` instead.
  * The deprecated `jax.config` submodule has been removed. To configure JAX
    use `import jax` and then reference the config object via `jax.config`.
  * `jax.random` APIs no longer accept batched keys, where previously
    some did unintentionally. Going forward, we recommend explicit use of
    `jax.vmap` in such cases.
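
A minimal sketch of the recommended `jax.vmap` pattern for batched keys (the batch size and sample shape are arbitrary illustration values):

```python
import jax

# A batch of 4 independent keys.
keys = jax.random.split(jax.random.key(0), 4)

# Map the sampler over the leading key axis instead of passing the batch directly.
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)
```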

* New Functionality
  * Added `jax.experimental.Exported.in_shardings_jax` to construct
    shardings that can be used with the JAX APIs from the HloShardings
    that are stored in the `Exported` objects.

jaxlib v0.4.28 (2024-05-09)

* Bug fixes
  * Fixes a memory corruption bug in the type name of Array and JIT Python
    objects in Python 3.10 or earlier.
  * Fixed a warning `'+ptx84' is not a recognized feature for this target`
    under CUDA 12.4.
  * Fixed a slow compilation problem on CPU.

* Changes
  * The Windows build is now built with Clang instead of MSVC.

JAX v0.4.28 (2024-05-09)

* Bug fixes
  * Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).

* Deprecations & removals
  * The `kind` argument to `jax.numpy.sort` and `jax.numpy.argsort`
    is now removed. Use `stable=True` or `stable=False` instead.
  * Removed `get_compute_capability` from the `jax.experimental.pallas.gpu`
    module. Use the `compute_capability` attribute of a GPU device, returned
    by `jax.devices` or `jax.local_devices`, instead.
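
A short sketch of the `stable` argument that replaces `kind` in `jax.numpy.sort` and `jax.numpy.argsort` (the input is an arbitrary example with ties). With `stable=True`, equal elements keep their original relative order:

```python
import jax.numpy as jnp

x = jnp.array([2, 1, 2, 1])

# Ties are resolved by original position: the 1 at index 1 precedes the 1 at index 3.
order = jnp.argsort(x, stable=True)
sorted_x = jnp.sort(x, stable=True)
```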

* Changes
  * The minimum jaxlib version of this release is 0.4.27.

Jaxlib release v0.4.27 (2024-05-07)

No notes available

Jax release v0.4.27 (2024-05-07)

No notes available

Jaxlib release v0.4.26 (2024-04-03)

No notes available

Jax release v0.4.26 (2024-04-03)

No notes available

Jaxlib release v0.4.25 (2024-02-26)

No notes available

JAX release v0.4.25 (2024-02-26)

No notes available

jaxlib release v0.4.24 (2024-02-06)

jaxlib release v0.4.24

JAX release v0.4.24 (2024-02-06)

JAX release v0.4.24

Jaxlib release v0.4.23 (2023-12-14)

No notes available

Jax release v0.4.23 (2023-12-14)

No notes available