diff --git a/optimistix/_ad.py b/optimistix/_ad.py index 38d6bcd..6b2f108 100644 --- a/optimistix/_ad.py +++ b/optimistix/_ad.py @@ -5,6 +5,7 @@ import equinox as eqx import equinox.internal as eqxi import jax +import jax.custom_derivatives import jax.numpy as jnp import jax.tree_util as jtu import lineax as lx @@ -107,6 +108,11 @@ def _for_jvp(_diff): _, jvp_diff = jax.jvp(_for_jvp, (diff,), (t_inputs,)) t_root = (-(lx.linear_solve(operator, jvp_diff, linear_solver).value ** ω)).ω - t_residual = tree_full_like(residual, 0) + if hasattr(jax.custom_derivatives, "zero_from_primal"): + t_residual = jax.custom_derivatives.zero_from_primal( # pyright: ignore[reportGeneralTypeIssues] + residual, symbolic_zeros=True + ) + else: + t_residual = tree_full_like(residual, 0) return (root, residual), (t_root, t_residual) diff --git a/pyproject.toml b/pyproject.toml index e22ee43..833ace3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "optimistix" -version = "0.0.8" +version = "0.0.9" description = "Nonlinear optimisation in JAX and Equinox." readme = "README.md" requires-python ="~=3.9"