Skip to content

Commit

Permalink
Now compatible with JAX 0.4.34.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Oct 21, 2024
1 parent a8b7787 commit 316000f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion optimistix/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.custom_derivatives
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax as lx
Expand Down Expand Up @@ -107,6 +108,11 @@ def _for_jvp(_diff):
_, jvp_diff = jax.jvp(_for_jvp, (diff,), (t_inputs,))

t_root = (-(lx.linear_solve(operator, jvp_diff, linear_solver).value ** ω)).ω
t_residual = tree_full_like(residual, 0)
if hasattr(jax.custom_derivatives, "zero_from_primal"):
t_residual = jax.custom_derivatives.zero_from_primal( # pyright: ignore[reportGeneralTypeIssues]
residual, symbolic_zeros=True
)
else:
t_residual = tree_full_like(residual, 0)

return (root, residual), (t_root, t_residual)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "optimistix"
version = "0.0.8"
version = "0.0.9"
description = "Nonlinear optimisation in JAX and Equinox."
readme = "README.md"
requires-python ="~=3.9"
Expand Down

0 comments on commit 316000f

Please sign in to comment.