Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable pytree step_size, offsets input #5

Merged
merged 4 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 87 additions & 73 deletions finitediffx/_src/fgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

from __future__ import annotations

import dataclasses as dc
import functools as ft
from typing import Callable, NamedTuple, TypeVar, Union
from typing import Any, Callable, Sequence, TypeVar, Union

import jax
import jax.numpy as jnp
Expand All @@ -24,21 +25,32 @@

from finitediffx._src.utils import _generate_central_offsets, generate_finitediff_coeffs

__all__ = ("fgrad", "Offset", "define_fdjvp", "value_and_fgrad")

P = ParamSpec("P")
T = TypeVar("T")
constant_treedef = jtu.tree_structure(0)
PyTree = Any


__all__ = ("fgrad", "Offset", "define_fdjvp", "value_and_fgrad")
@dc.dataclass(frozen=True)
class Offset:
    """Convenience class for finite difference offsets used inside `fgrad`

Args:
accuracy: The accuracy of the finite difference scheme. Must be >=2.

class Offset(NamedTuple):
    """Convenience class for finite difference offsets used inside `fgrad`"""
Example:
>>> import finitediffx as fdx
>>> fdx.fgrad(lambda x: x**2, offsets=fdx.Offset(accuracy=2))(1.0)
Array(2., dtype=float32)
"""

accuracy: int


OffsetType = Union[jax.Array, Offset]
StepsizeType = Union[jax.Array, float]
OffsetType = Union[jax.Array, Offset, PyTree]
StepsizeType = Union[jax.Array, float, PyTree]


def _evaluate_func_at_shifted_steps_along_argnum(
Expand Down Expand Up @@ -68,69 +80,65 @@ def wrapper(*args, **kwargs):


def resolve_step_size(
step_size: StepsizeType | tuple[StepsizeType, ...] | None,
length: int | None,
step_size: StepsizeType | Sequence[StepsizeType] | None,
treedef: jtu.PyTreeDef,
derivative: int,
) -> tuple[StepsizeType, ...] | StepsizeType:
) -> Sequence[StepsizeType] | StepsizeType:
# return non-tuple values if length is None
length = treedef.num_leaves

if isinstance(step_size, (jax.Array, float)):
return ((step_size,) * length) if length else step_size
return (step_size,) * length

if step_size is None:
default = (2) ** (-23 / (2 * derivative))
return ((default,) * length) if length else default

if isinstance(step_size, tuple) and length:
assert len(step_size) == length, f"step_size must be a tuple of length {length}"
step_size = list(step_size)
for i, s in enumerate(step_size):
if s is None:
step_size[i] = (2) ** (-23 / (2 * derivative))
elif not isinstance(s, (jax.Array, float)):
raise TypeError(f"{type(s)} not in {(jax.Array, float)}")
return tuple(step_size)
return (default,) * length

step_size_leaves, step_size_treedef = jtu.tree_flatten(step_size)

if step_size_treedef == treedef:
# step_size is a pytree with the same structure as the input
return step_size_leaves

raise TypeError(
f"`step_size` must be of type:\n"
f"- `jax.Array`\n"
f"- `float`\n"
f"- tuple of `jax.Array` or `float` for tuple argnums.\n"
f"- pytree with the same structure as the desired arg.\n"
f"but got {type(step_size)=}"
)


def resolve_offsets(
offsets: tuple[OffsetType | None, ...] | OffsetType | None,
length: int,
offsets: Sequence[OffsetType | None] | OffsetType | None,
treedef: jax.tree_util.PyTreeDef,
derivative: int,
) -> tuple[OffsetType, ...] | OffsetType:
# single value
length = treedef.num_leaves

if isinstance(offsets, Offset):
if offsets.accuracy < 2:
raise ValueError(f"offset accuracy must be >=2, got {offsets.accuracy}")
offsets = jnp.array(_generate_central_offsets(derivative, offsets.accuracy))
return ((offsets,) * length) if length else offsets
return (offsets,) * length

if isinstance(offsets, jax.Array):
return ((offsets,) * length) if length else offsets

if isinstance(offsets, tuple) and length:
assert len(offsets) == length, f"`offsets` must be a tuple of length {length}"
offsets = list(offsets)
for i, o in enumerate(offsets):
if isinstance(o, Offset):
if o.accuracy < 2:
raise ValueError(f"offset accuracy must be >=2, got {o.accuracy}")
o = jnp.array(_generate_central_offsets(derivative, o.accuracy))
offsets[i] = o
elif not isinstance(o, (Offset, jax.Array)):
raise TypeError(f"{type(o)} not an Offset")

return tuple(offsets)
return (offsets,) * length

offsets_leaves, offsets_treedef = jtu.tree_flatten(offsets)

if offsets_treedef == treedef:
# offsets is a pytree with the same structure as the input
return offsets_leaves

raise TypeError(
f"`offsets` must be of type:\n"
f"- `Offset`\n"
f"- `jax.Array`\n"
f"- tuple of `Offset` or `jax.Array` for tuple argnums.\n"
f"- pytree with the same structure as the desired arg.\n"
f"but got {type(offsets)=}"
)

Expand All @@ -146,36 +154,31 @@ def _fgrad_along_argnum(
if not isinstance(argnum, int):
raise TypeError(f"argnum must be an integer, got {type(argnum)}")

def _leaves_fgrad(func: Callable, *, length: int):
kwargs = dict(length=length, derivative=derivative)
step_size_ = resolve_step_size(step_size, **kwargs)
offsets_ = resolve_offsets(offsets, **kwargs)
def wrapper(*args, **kwargs):
arg_leaves, arg_treedef = jtu.tree_flatten(args[argnum])
args_ = list(args)

dfuncs = [
def func_wrapper(*leaves):
args_[argnum] = jtu.tree_unflatten(arg_treedef, leaves)
return func(*args_, **kwargs)

spec = dict(treedef=arg_treedef, derivative=derivative)
step_size_ = resolve_step_size(step_size, **spec)
offsets_ = resolve_offsets(offsets, **spec)

flat_result = [
_evaluate_func_at_shifted_steps_along_argnum(
func=func,
func=func_wrapper,
coeffs=generate_finitediff_coeffs(oi, derivative),
offsets=oi,
step_size=si,
derivative=derivative,
argnum=i,
)
)(*arg_leaves)
for i, (oi, si) in enumerate(zip(offsets_, step_size_))
]

return lambda *a, **k: [df(*a, **k) for df in dfuncs]

def wrapper(*args, **kwargs):
arg_leaves, arg_treedef = jtu.tree_flatten(args[argnum])
args_ = list(args)

def func_wrapper(*leaves):
args_[argnum] = jtu.tree_unflatten(arg_treedef, leaves)
return func(*args_, **kwargs)

dfunc = _leaves_fgrad(func_wrapper, length=len(arg_leaves))

return jtu.tree_unflatten(arg_treedef, dfunc(*arg_leaves))
return jtu.tree_unflatten(arg_treedef, flat_result)

return wrapper

Expand All @@ -202,6 +205,8 @@ def value_and_fgrad(
offsets: offsets for the finite difference stencil. Accepted types are:
- `jax.Array` defining location of function evaluation points.
- `Offset` with accuracy field to automatically generate offsets.
- pytree of `jax.Array`/`Offset` to define offsets for each argument
of the same pytree structure as argument defined by `argnums`.
derivative: derivative order. Defaults to 1.
has_aux: whether the function returns an auxiliary output. Defaults to False.
If True, the derivative function will return a tuple of the form:
Expand Down Expand Up @@ -238,13 +243,11 @@ def value_and_fgrad(

if isinstance(argnums, int):
# fgrad(func, argnums=0)
kwargs = dict(length=None, derivative=derivative)

dfunc = _fgrad_along_argnum(
func=func_,
argnum=argnums,
step_size=resolve_step_size(step_size, **kwargs),
offsets=resolve_offsets(offsets, **kwargs),
step_size=step_size,
offsets=offsets,
derivative=derivative,
)

Expand All @@ -267,21 +270,30 @@ def wrapper(*a, **k):
# fgrad(func, argnums=(0,1))
# return a tuple of derivatives evaluated at each argnum
# this behavior is similar to jax.grad(func, argnums=(...))
kwargs = dict(length=len(argnums), derivative=derivative)
if isinstance(offsets, tuple):
if len(offsets) != len(argnums):
raise AssertionError("offsets must have the same length as argnums")
offsets_ = offsets
else:
offsets_ = (offsets,) * len(argnums)

if isinstance(step_size, tuple):
if len(step_size) != len(argnums):
raise AssertionError("step_size must have the same length as argnums")
step_size_ = step_size

else:
step_size_ = (step_size,) * len(argnums)

dfuncs = [
_fgrad_along_argnum(
func=func_,
argnum=argnum_i,
step_size=step_size_i,
offsets=offsets_i,
argnum=ai,
step_size=si,
offsets=oi,
derivative=derivative,
)
for offsets_i, step_size_i, argnum_i in zip(
resolve_offsets(offsets, **kwargs),
resolve_step_size(step_size, **kwargs),
argnums,
)
for oi, si, ai in zip(offsets_, step_size_, argnums)
]

if has_aux:
Expand Down Expand Up @@ -326,6 +338,8 @@ def fgrad(
offsets: offsets for the finite difference stencil. Accepted types are:
- `jax.Array` defining location of function evaluation points.
- `Offset` with accuracy field to automatically generate offsets.
- pytree of `jax.Array`/`Offset` to define offsets for each argument
of the same pytree structure as argument defined by `argnums`.
derivative: derivative order. Defaults to 1.
has_aux: whether the function returns an auxiliary output. Defaults to False.
If True, the derivative function will return a tuple of the form:
Expand Down Expand Up @@ -436,7 +450,7 @@ def define_fdjvp(

@func.defjvp
def _(primals, tangents):
kwargs = dict(length=len(primals), derivative=1)
kwargs = dict(treedef=jtu.tree_structure(primals), derivative=1)
step_size_ = resolve_step_size(step_size, **kwargs)
offsets_ = resolve_offsets(offsets, **kwargs)
primal_out = func(*primals)
Expand Down
4 changes: 3 additions & 1 deletion finitediffx/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def generate_finitediff_coeffs(
"""

if derivative >= len(offsets):
raise ValueError(f"{len(offsets)=} must be larger than {derivative=}.")
raise ValueError(
f"{offsets=} of {len(offsets)=} must be larger than {derivative=}."
)

return _generate_finitediff_coeffs(offsets, derivative)
Loading