Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Now compatible with JAX 0.4.34. #87

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion optimistix/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.custom_derivatives
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax as lx
Expand Down Expand Up @@ -107,6 +108,11 @@ def _for_jvp(_diff):
_, jvp_diff = jax.jvp(_for_jvp, (diff,), (t_inputs,))

t_root = (-(lx.linear_solve(operator, jvp_diff, linear_solver).value ** ω)).ω
t_residual = tree_full_like(residual, 0)
if hasattr(jax.custom_derivatives, "zero_from_primal"):
t_residual = jax.custom_derivatives.zero_from_primal( # pyright: ignore[reportGeneralTypeIssues]
residual, symbolic_zeros=True
)
else:
t_residual = tree_full_like(residual, 0)

return (root, residual), (t_root, t_residual)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "optimistix"
version = "0.0.8"
version = "0.0.9"
description = "Nonlinear optimisation in JAX and Equinox."
readme = "README.md"
requires-python ="~=3.9"
Expand Down
Loading