Switched from reverse mode to forward mode where possible.
This commit switches some functions that were unnecessarily using reverse-mode autodiff over to forward-mode autodiff. In particular, this is to fix #51 (comment).

Whilst I'm here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the test I've added fails. I've poked at this a bit but have not yet been able to resolve it. It seems something is still awry!
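As a rough illustration of the forward-mode trick this commit relies on (not the committed code itself; `residual_fn`, `y` and `y_diff` below are made-up example inputs): for a real-valued least-squares problem, the predicted reduction `grad . y_diff` can be computed with a single forward-mode JVP, instead of first materialising the gradient with a reverse-mode VJP.

import jax
import jax.numpy as jnp

def residual_fn(y):
    return jnp.stack([y[0] ** 2 + y[1], jnp.sin(y[0]) - y[1]])

y = jnp.array([0.7, -1.3])
y_diff = jnp.array([0.1, 0.2])

# Reverse mode: materialise grad = jac^T residual with a VJP, then dot with y_diff.
residual, vjp_fn = jax.vjp(residual_fn, y)
(grad,) = vjp_fn(residual)
reverse_value = jnp.dot(grad, y_diff)

# Forward mode: a single JVP gives jac @ y_diff, then dot with the residual.
_, jac_y_diff = jax.jvp(residual_fn, (y,), (y_diff,))
forward_value = jnp.dot(jac_y_diff, residual)

assert jnp.allclose(reverse_value, forward_value)  # same number, no reverse pass needed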
patrick-kidger committed Aug 17, 2024
1 parent dcfcbc9 commit 4fc86e0
Showing 11 changed files with 203 additions and 38 deletions.
8 changes: 8 additions & 0 deletions optimistix/_iterate.py
@@ -1,4 +1,5 @@
import abc
import warnings
from collections.abc import Callable
from typing import Any, Generic, Optional, TYPE_CHECKING

@@ -323,6 +324,13 @@ def iterative_solve(
An [`optimistix.Solution`][] object.
"""

if any(jnp.iscomplexobj(x) for x in jtu.tree_leaves((y0, f_struct))):
warnings.warn(
"Complex support in Optimistix is a work in progress, and may still return "
"incorrect results. You may prefer to split your problem into real and "
"imaginary parts, so that Optimistix sees only real numbers."
)

f_struct = jtu.tree_map(eqxi.Static, f_struct)
aux_struct = jtu.tree_map(eqxi.Static, aux_struct)
inputs = fn, solver, y0, args, options, max_steps, f_struct, aux_struct, tags
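A minimal sketch of the workaround suggested in the warning above, assuming a toy complex root-finding problem (the function, initial value and choice of `optx.Newton` below are made up for illustration): split `y` and the residual into real and imaginary parts, so that Optimistix only ever sees real numbers.

import jax.numpy as jnp
import optimistix as optx

def complex_fn(y, args):
    del args
    return y**2 - (1.0 + 1.0j)  # toy complex root-finding problem

def split_fn(y_split, args):
    # Reassemble the complex value, evaluate, then split the residual again.
    y = y_split["real"] + 1j * y_split["imag"]
    out = complex_fn(y, args)
    return {"real": out.real, "imag": out.imag}

y0 = jnp.array(1.0 + 0.5j)
y0_split = {"real": y0.real, "imag": y0.imag}
sol = optx.root_find(split_fn, optx.Newton(rtol=1e-8, atol=1e-8), y0_split, args=None)
root = sol.value["real"] + 1j * sol.value["imag"]  # approximately sqrt(1 + 1j)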
56 changes: 41 additions & 15 deletions optimistix/_search.py
@@ -26,6 +26,8 @@
from typing import ClassVar, Generic, Type, TypeVar

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax as lx
from jaxtyping import Array, Bool, Scalar

@@ -35,7 +37,7 @@
SearchState,
Y,
)
from ._misc import sum_squares
from ._misc import sum_squares, tree_dot
from ._solution import RESULTS


@@ -89,6 +91,9 @@ class EvalGrad(FunctionInfo, Generic[Y], strict=True):
def as_min(self):
return self.f

def compute_grad_dot(self, y: Y):
return tree_dot(self.grad, y)


# NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
class EvalGradHessian(FunctionInfo, Generic[Y], strict=True):
@@ -104,6 +109,9 @@ class EvalGradHessian(FunctionInfo, Generic[Y], strict=True):
def as_min(self):
return self.f

def compute_grad_dot(self, y: Y):
return tree_dot(self.grad, y)


# NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
class EvalGradHessianInv(FunctionInfo, Generic[Y], strict=True):
@@ -118,6 +126,9 @@ class EvalGradHessianInv(FunctionInfo, Generic[Y], strict=True):
def as_min(self):
return self.f

def compute_grad_dot(self, y: Y):
return tree_dot(self.grad, y)


# NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
class Residual(FunctionInfo, Generic[Out], strict=True):
@@ -134,28 +145,43 @@ def as_min(self):
# NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
class ResidualJac(FunctionInfo, Generic[Y, Out], strict=True):
"""Records the Jacobian `d(fn)/dy` as a linear operator. Used for least squares
problems, for which `fn` returns residuals. Has `.residual` and `.jac` and `.grad`
attributes, where `residual = fn(y)`, `jac = d(fn)/dy` and
`grad = jac^T residual`.
Takes just `residual` and `jac` as `__init__`-time arguments, from which `grad` is
computed.
problems, for which `fn` returns residuals. Has `.residual` and `.jac` attributes,
where `residual = fn(y)`, `jac = d(fn)/dy`.
"""

residual: Out
jac: lx.AbstractLinearOperator
grad: Y

def __init__(self, residual: Out, jac: lx.AbstractLinearOperator):
self.residual = residual
self.jac = jac
# The gradient is used ubiquitously, so compute it once here, so that it can be
# used without recomputation in both the descent and search.
self.grad = jac.transpose().mv(residual)

def as_min(self):
return 0.5 * sum_squares(self.residual)

def compute_grad(self):
conj_residual = jtu.tree_map(jnp.conj, self.residual)
return self.jac.transpose().mv(conj_residual)

def compute_grad_dot(self, y: Y):
# If `self.jac` is a `lx.JacobianLinearOperator` (or a
# `lx.FunctionLinearOperator` wrapping the result of `jax.linearize`), then
# `min = 0.5 * residual^2`, so `grad = jac^T residual`, i.e. the gradient of
# this. So what we want to compute is `residual^T jac y`. Doing the
# reduction in this order means we hit forward-mode rather than reverse-mode
# autodiff.
#
# For the complex case: in this case we actually have
# `min = 0.5 * residual residual^bar`
# which implies
# `grad = jac^T residual^bar`
# and thus that we want
# `grad^T^bar y = residual^T jac^bar y = (jac y^bar)^T^bar residual`.
# Notes:
# (a) the `grad` derivation is not super obvious. Note that
# `grad(z -> 0.5 z z^bar)` is `z^bar` in JAX (yes, twice the Wirtinger
# derivative!). JAX uses a non-Wirtinger derivative for nonholomorphic functions:
# https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#complex-numbers-and-differentiation
# (b) our convention is that the first term of a dot product gets the conjugate:
# https://github.com/patrick-kidger/diffrax/pull/454#issuecomment-2210296643
return tree_dot(self.jac.mv(jtu.tree_map(jnp.conj, y)), self.residual)


Eval.__qualname__ = "FunctionInfo.Eval"
EvalGrad.__qualname__ = "FunctionInfo.EvalGrad"
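A small numerical sanity check of the identities in the `ResidualJac.compute_grad_dot` comment above (illustrative only, not part of the commit; the values of `z`, `jac`, `residual` and `y` are made up):

import jax
import jax.numpy as jnp

# Note (a): JAX's grad of the real-valued map z -> 0.5 * z * conj(z) is conj(z).
# (`jax.grad` needs a real-valued output, so take .real; the imaginary part is zero anyway.)
z = jnp.array(1.0 + 2.0j)
g = jax.grad(lambda z: (0.5 * z * jnp.conj(z)).real)(z)
assert jnp.allclose(g, jnp.conj(z))

# The final identity, with dense stand-ins and the conjugate-first dot product of
# note (b): dot(grad, y) == dot(jac @ conj(y), residual), where grad = jac^T conj(residual).
jac = jnp.array([[1.0 + 1.0j, 2.0], [0.5j, 1.0 - 0.5j]])
residual = jnp.array([0.3 - 0.2j, -0.1 + 0.4j])
y = jnp.array([0.7 + 0.1j, -0.2 - 0.3j])
grad = jac.T @ jnp.conj(residual)
lhs = jnp.sum(jnp.conj(grad) * y)
rhs = jnp.sum(jnp.conj(jac @ jnp.conj(y)) * residual)
assert jnp.allclose(lhs, rhs)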
11 changes: 3 additions & 8 deletions optimistix/_solver/backtracking.py
@@ -7,9 +7,6 @@
from jaxtyping import Array, Bool, Scalar, ScalarLike

from .._custom_types import Y
from .._misc import (
tree_dot,
)
from .._search import AbstractSearch, FunctionInfo
from .._solution import RESULTS

@@ -55,7 +52,7 @@ def __post_init__(self):
)

def init(self, y: Y, f_info_struct: _FnInfo) -> _BacktrackingState:
del f_info_struct
del y, f_info_struct
return _BacktrackingState(step_size=jnp.array(self.step_init))

def step(
@@ -67,7 +64,7 @@ def step(
f_eval_info: _FnEvalInfo,
state: _BacktrackingState,
) -> tuple[Scalar, Bool[Array, ""], RESULTS, _BacktrackingState]:
if isinstance(
if not isinstance(
f_info,
(
FunctionInfo.EvalGrad,
@@ -76,16 +73,14 @@
FunctionInfo.ResidualJac,
),
):
grad = f_info.grad
else:
raise ValueError(
"Cannot use `BacktrackingArmijo` with this solver. This is because "
"`BacktrackingArmijo` requires gradients of the target function, but "
"this solver does not evaluate such gradients."
)

y_diff = (y_eval**ω - y**ω).ω
predicted_reduction = tree_dot(grad, y_diff)
predicted_reduction = f_info.compute_grad_dot(y_diff)
# Terminate when the Armijo condition is satisfied. That is, `fn(y_eval)`
# must do better than its linear approximation:
# `fn(y_eval) < fn(y) + grad•y_diff`
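For reference, a minimal eager-mode sketch of the Armijo rule described in the comment above (not the `BacktrackingArmijo` implementation; `f`, the `slope`/`decrease` constants and the quartic test function are made up): keep shrinking the step until the function value beats a fraction of its linear approximation.

import jax
import jax.numpy as jnp

def f(y):
    return jnp.sum((y - 1.0) ** 2) + jnp.sum(y**4)

def armijo_backtrack(f, y, grad, descent_dir, slope=0.1, decrease=0.5, step=1.0):
    # Accept the first step size for which
    #     f(y + step * descent_dir) <= f(y) + slope * step * (grad . descent_dir)
    predicted = jnp.dot(grad, descent_dir)
    while f(y + step * descent_dir) > f(y) + slope * step * predicted:
        step = decrease * step
    return step

y = jnp.array([2.0, -1.0])
grad = jax.grad(f)(y)
step = armijo_backtrack(f, y, grad, -grad)  # steepest-descent direction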
12 changes: 7 additions & 5 deletions optimistix/_solver/dogleg.py
@@ -72,13 +72,15 @@ def query(
f_info: Union[FunctionInfo.EvalGradHessian, FunctionInfo.ResidualJac],
state: _DoglegDescentState,
) -> _DoglegDescentState:
del state
del y, state
# Compute `denom = grad^T Hess grad.`
if isinstance(f_info, FunctionInfo.EvalGradHessian):
denom = tree_dot(f_info.grad, f_info.hessian.mv(f_info.grad))
grad = f_info.grad
denom = tree_dot(f_info.grad, f_info.hessian.mv(grad))
elif isinstance(f_info, FunctionInfo.ResidualJac):
# Use Gauss--Newton approximation `Hess ~ J^T J`
denom = sum_squares(f_info.jac.mv(f_info.grad))
grad = f_info.compute_grad()
denom = sum_squares(f_info.jac.mv(grad))
else:
raise ValueError(
"`DoglegDescent` can only be used with least-squares solvers, or "
@@ -88,7 +90,7 @@
denom_nonzero = denom > jnp.finfo(denom.dtype).eps
safe_denom = jnp.where(denom_nonzero, denom, 1)
# Compute `grad^T grad / (grad^T Hess grad)`
scaling = jnp.where(denom_nonzero, sum_squares(f_info.grad) / safe_denom, 0.0)
scaling = jnp.where(denom_nonzero, sum_squares(grad) / safe_denom, 0.0)
scaling = cast(Array, scaling)

# Downhill towards the bottom of the quadratic basin.
@@ -97,7 +99,7 @@
newton_norm = self.trust_region_norm(newton_sol)

# Downhill steepest descent.
cauchy = (-scaling * f_info.grad**ω).ω
cauchy = (-scaling * grad**ω).ω
cauchy_norm = self.trust_region_norm(cauchy)

return _DoglegDescentState(
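An illustrative sketch of the Cauchy-point scaling computed above, using the Gauss--Newton approximation `Hess ~ J^T J` so that `grad^T Hess grad = ||jac grad||^2` (not the Optimistix code; the dense `jac` and `residual` below are made up):

import jax.numpy as jnp

def cauchy_scaling(jac, grad):
    # scaling = grad^T grad / (grad^T (J^T J) grad), guarded against a zero denominator.
    denom = jnp.sum(jnp.square(jac @ grad))
    denom_nonzero = denom > jnp.finfo(denom.dtype).eps
    safe_denom = jnp.where(denom_nonzero, denom, 1)
    return jnp.where(denom_nonzero, jnp.sum(jnp.square(grad)) / safe_denom, 0.0)

jac = jnp.array([[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]])
residual = jnp.array([0.5, -0.2, 0.1])
grad = jac.T @ residual
cauchy = -cauchy_scaling(jac, grad) * grad  # the steepest-descent leg of the dogleg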
3 changes: 2 additions & 1 deletion optimistix/_solver/gradient_methods.py
@@ -58,10 +58,11 @@ def query(
FunctionInfo.EvalGrad,
FunctionInfo.EvalGradHessian,
FunctionInfo.EvalGradHessianInv,
FunctionInfo.ResidualJac,
),
):
grad = f_info.grad
elif isinstance(f_info, FunctionInfo.ResidualJac):
grad = f_info.compute_grad()
else:
raise ValueError(
"Cannot use `SteepestDescent` with this solver. This is because "
2 changes: 1 addition & 1 deletion optimistix/_solver/levenberg_marquardt.py
@@ -54,7 +54,7 @@ def damped_newton_step(

pred = step_size > jnp.finfo(step_size.dtype).eps
safe_step_size = jnp.where(pred, step_size, 1)
lm_param = jnp.where(pred, 1 / safe_step_size, jnp.finfo(step_size).max)
lm_param = jnp.where(pred, 1 / safe_step_size, jnp.finfo(step_size.dtype).max)
lm_param = cast(Array, lm_param)
if isinstance(f_info, FunctionInfo.EvalGradHessian):
operator = f_info.hessian + lm_param * lx.IdentityLinearOperator(
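A minimal sketch of the damped Newton step being assembled above, assuming a small dense Hessian (not the Optimistix code path; `hessian`, `grad` and `step_size` are made-up example inputs). The Levenberg--Marquardt parameter is the reciprocal of the step size, guarded exactly as in the fixed line.

import jax.numpy as jnp
import lineax as lx

hessian = lx.MatrixLinearOperator(jnp.array([[2.0, 0.3], [0.3, 1.0]]))
grad = jnp.array([1.0, -0.5])
step_size = jnp.array(0.25)

pred = step_size > jnp.finfo(step_size.dtype).eps
safe_step_size = jnp.where(pred, step_size, 1)
lm_param = jnp.where(pred, 1 / safe_step_size, jnp.finfo(step_size.dtype).max)

# Solve (Hess + lm_param * I) x = grad for the damped Newton direction.
operator = hessian + lm_param * lx.IdentityLinearOperator(hessian.in_structure())
newton_step = lx.linear_solve(operator, grad).value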
16 changes: 10 additions & 6 deletions optimistix/_solver/nonlinear_cg.py
@@ -119,15 +119,19 @@ def query(
],
state: _NonlinearCGDescentState,
) -> _NonlinearCGDescentState:
if not isinstance(
del y
if isinstance(
f_info,
(
FunctionInfo.EvalGrad,
FunctionInfo.EvalGradHessian,
FunctionInfo.EvalGradHessianInv,
FunctionInfo.ResidualJac,
),
):
grad = f_info.grad
elif isinstance(f_info, FunctionInfo.ResidualJac):
grad = f_info.compute_grad()
else:
raise ValueError(
"Cannot use `NonlinearCGDescent` with this solver. This is because "
"`NonlinearCGDescent` requires gradients of the target function, but "
@@ -140,16 +144,16 @@
# Furthermore, the same mechanism handles convergence: once
# `state.{grad, y_diff} = 0`, i.e. our previous step hit a local minima, then
# on this next step we'll again just use gradient descent, and stop.
beta = self.method(f_info.grad, state.grad, state.y_diff)
neg_grad = (-(f_info.grad**ω)).ω
beta = self.method(grad, state.grad, state.y_diff)
neg_grad = (-(grad**ω)).ω
nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω
# Check if this is a descent direction. Use gradient descent if it isn't.
y_diff = tree_where(
tree_dot(f_info.grad, nonlinear_cg_direction) < 0,
tree_dot(grad, nonlinear_cg_direction) < 0,
nonlinear_cg_direction,
neg_grad,
)
return _NonlinearCGDescentState(y_diff=y_diff, grad=f_info.grad)
return _NonlinearCGDescentState(y_diff=y_diff, grad=grad)

def step(
self, step_size: Scalar, state: _NonlinearCGDescentState
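An illustrative sketch of the direction update described in the comments above, with a Polak--Ribiere-style choice of `beta` standing in for `self.method` (not the Optimistix implementation; `grad`, `prev_grad` and `prev_y_diff` are made-up dense inputs):

import jax.numpy as jnp

def polak_ribiere(grad, prev_grad, prev_y_diff):
    del prev_y_diff
    denom = jnp.sum(prev_grad**2)
    safe_denom = jnp.where(denom > 0, denom, 1)
    beta = jnp.sum(grad * (grad - prev_grad)) / safe_denom
    return jnp.where(denom > 0, jnp.maximum(beta, 0), 0.0)

def nonlinear_cg_direction(grad, prev_grad, prev_y_diff):
    beta = polak_ribiere(grad, prev_grad, prev_y_diff)
    direction = -grad + beta * prev_y_diff
    # Fall back to plain steepest descent if this is not a descent direction.
    return jnp.where(jnp.dot(grad, direction) < 0, direction, -grad)

grad = jnp.array([0.8, -0.3])
prev_grad = jnp.array([1.0, 0.1])
prev_y_diff = jnp.array([-0.5, -0.05])
y_diff = nonlinear_cg_direction(grad, prev_grad, prev_y_diff)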
2 changes: 1 addition & 1 deletion optimistix/_solver/trust_region.py
@@ -273,7 +273,7 @@ def predict_reduction(
FunctionInfo.ResidualJac,
),
):
return tree_dot(f_info.grad, y_diff)
return f_info.compute_grad_dot(y_diff)
else:
raise ValueError(
"Cannot use `LinearTrustRegion` with this solver. This is because "
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/optimistix" }
dependencies = ["jax>=0.4.28", "jaxtyping>=0.2.23", "lineax>=0.0.4", "equinox>=0.11.1", "typing_extensions>=4.5.0"]
dependencies = ["jax>=0.4.28", "jaxtyping>=0.2.23", "lineax>=0.0.5", "equinox>=0.11.1", "typing_extensions>=4.5.0"]

[build-system]
requires = ["hatchling"]
