Value and fgrad #3

Merged · 2 commits · Jun 7, 2023
3 changes: 2 additions & 1 deletion finitediffx/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._src.fgrad import Offset, define_fdjvp, fgrad
from ._src.fgrad import Offset, define_fdjvp, fgrad, value_and_fgrad
from ._src.finite_diff import (
curl,
difference,
@@ -33,6 +33,7 @@
"laplacian",
"hessian",
"fgrad",
"value_and_fgrad",
"Offset",
"define_fdjvp",
"generate_finitediff_coeffs",
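As a quick sanity check of the re-export above, the new symbol becomes importable from the package root (a minimal sketch, not part of the diff):

```python
# value_and_fgrad is now exposed alongside the existing top-level API
from finitediffx import Offset, fgrad, value_and_fgrad
```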
132 changes: 124 additions & 8 deletions finitediffx/_src/fgrad.py
@@ -27,7 +27,7 @@
T = TypeVar("T")


__all__ = ("fgrad", "Offset", "define_fdjvp")
__all__ = ("fgrad", "Offset", "define_fdjvp", "value_and_fgrad")


class Offset(NamedTuple):
@@ -139,6 +139,7 @@ def fgrad(
step_size: StepsizeType | None = None,
offsets: OffsetType = Offset(accuracy=3),
derivative: int = 1,
has_aux: bool = False,
) -> Callable:
"""Finite difference derivative of a function with respect to one of its arguments.
similar to jax.grad but with finite difference approximation
@@ -154,9 +155,13 @@
- `jax.Array` defining location of function evaluation points.
- `Offset` with accuracy field to automatically generate offsets.
derivative: derivative order. Defaults to 1.
has_aux: whether the function returns an auxiliary output. Defaults to False.
If True, the derivative function returns a tuple of the form
(derivative, aux); otherwise it returns only the derivative.

Returns:
Callable: derivative of the function
Derivative of the function if `has_aux` is False, otherwise a tuple of
the form: (derivative, aux)

Example:
>>> import finitediffx as fdx
@@ -175,37 +180,132 @@
>>> df(2.0, 3.0)
Array(6., dtype=float32)

"""
value_and_fgrad_func = value_and_fgrad(
func=func,
argnums=argnums,
step_size=step_size,
offsets=offsets,
derivative=derivative,
has_aux=has_aux,
)

if has_aux:

@ft.wraps(func)
def aux_wrapper(*a, **k):
(_, aux), g = value_and_fgrad_func(*a, **k)
return g, aux

else:

@ft.wraps(func)
def wrapper(*a, **k):
_, g = value_and_fgrad_func(*a, **k)
return g

return aux_wrapper if has_aux else wrapper


def value_and_fgrad(
func: Callable,
*,
argnums: int | tuple[int, ...] = 0,
step_size: StepsizeType | None = None,
offsets: OffsetType = Offset(accuracy=3),
derivative: int = 1,
has_aux: bool = False,
) -> Callable:
"""Finite difference derivative of a function with respect to one of its arguments.
similar to jax.grad but with finite difference approximation

Args:
func: function to differentiate
argnums: argument number to differentiate. Defaults to 0.
If a tuple is passed, the function is differentiated with respect to
all the arguments in the tuple.
step_size: step size for the finite difference stencil. If `None`, the step size
is set to `(2) ** (-23 / (2 * derivative))`
offsets: offsets for the finite difference stencil. Accepted types are:
- `jax.Array` defining location of function evaluation points.
- `Offset` with accuracy field to automatically generate offsets.
derivative: derivative order. Defaults to 1.
has_aux: whether the function returns an auxiliary output. Defaults to False.
If True, the returned function returns a tuple of the form
((value, aux), derivative); otherwise (value, derivative).

Returns:
A tuple of the form (value, derivative) if `has_aux` is False, otherwise
a tuple of the form ((value, aux), derivative).

Example:
>>> import finitediffx as fdx
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x, y):
... return x**2 + y**2
>>> df=fdx.value_and_fgrad(
... func=f,
... argnums=1, # differentiate with respect to y
... offsets=fdx.Offset(accuracy=2) # use 2nd order accurate finite difference
... )
>>> df(2.0, 3.0)
(13.0, Array(6., dtype=float32))

"""
func.__doc__ = (
f"Finite difference derivative of {getattr(func, '__name__', func)}"
f" w.r.t {argnums=}\n\n{func.__doc__}"
)
if not isinstance(has_aux, bool):
raise TypeError(f"{type(has_aux)} not a bool")

func_ = (lambda *a, **k: func(*a, **k)[0]) if has_aux else func

if isinstance(argnums, int):
# fgrad(func, argnums=0)
kwargs = dict(length=None, derivative=derivative)
step_size = _canonicalize_step_size(step_size, **kwargs)
offsets = _canonicalize_offsets(offsets, **kwargs)

dfunc = _evaluate_func_at_shifted_steps_along_argnum(
func=func,
func=func_,
coeffs=generate_finitediff_coeffs(offsets, derivative),
offsets=offsets,
step_size=step_size,
derivative=derivative,
argnum=argnums,
)

return ft.wraps(func)(lambda *a, **k: sum(dfunc(*a, **k)))
if has_aux is True:

@ft.wraps(func)
def aux_wrapper(*a, **k):
# unpacking raises an error if the function does not
# return a two-item tuple
value, aux = func(*a, **k)
# the value could be checked to be a scalar, but that would be
# an unnecessary restriction for the finite-difference case
return (value, aux), sum(dfunc(*a, **k))

else:

@ft.wraps(func)
def wrapper(*a, **k):
return func(*a, **k), sum(dfunc(*a, **k))

return aux_wrapper if has_aux else wrapper

if isinstance(argnums, tuple):
# fgrad(func, argnums=(0,1))
# return a tuple of derivatives evaluated at each argnum
# this behavior is similar to jax.grad(func, argnums=(...))
kwargs = dict(length=len(argnums), derivative=derivative)

dfuncs = [
dfuncs = (
_evaluate_func_at_shifted_steps_along_argnum(
func=func,
func=func_,
coeffs=generate_finitediff_coeffs(offsets_i, derivative),
offsets=offsets_i,
step_size=step_size_i,
@@ -217,8 +317,24 @@ def fgrad(
_canonicalize_step_size(step_size, **kwargs),
argnums,
)
]
return ft.wraps(func)(lambda *a, **k: tuple(sum(df(*a, **k)) for df in dfuncs))
)

if has_aux:

@ft.wraps(func)
def aux_wrapper(*a, **k):
# destructure the tuple to ensure the function
# returned a two-item tuple
value, aux = func(*a, **k)
return (value, aux), tuple(sum(df(*a, **k)) for df in dfuncs)

else:

@ft.wraps(func)
def wrapper(*a, **k):
return func(*a, **k), tuple(sum(df(*a, **k)) for df in dfuncs)

return aux_wrapper if has_aux else wrapper

raise ValueError(f"argnums must be an int or a tuple of ints, got {argnums}")

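For orientation, a short usage sketch of the API introduced above, following the docstrings and tests in this PR (the exact gradient of x**2 + y**2 with respect to y at (2.0, 3.0) is 6, and the finite-difference result matches it up to the stencil accuracy):

```python
import finitediffx as fdx

def f(x, y):
    # main value plus an auxiliary payload
    return x**2 + y**2, {"note": "aux data"}

# value_and_fgrad with has_aux=True returns ((value, aux), derivative)
(value, aux), dy = fdx.value_and_fgrad(f, argnums=1, has_aux=True)(2.0, 3.0)
# value -> 13.0, aux -> {'note': 'aux data'}, dy -> approximately 6.0

# fgrad with has_aux=True returns (derivative, aux)
dy, aux = fdx.fgrad(f, argnums=1, has_aux=True)(2.0, 3.0)
# dy -> approximately 6.0, aux -> {'note': 'aux data'}
```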
28 changes: 27 additions & 1 deletion tests/test_fgrad.py
@@ -8,7 +8,13 @@
import pytest
from jax.experimental import enable_x64

from finitediffx import Offset, define_fdjvp, fgrad, generate_finitediff_coeffs
from finitediffx import (
Offset,
define_fdjvp,
fgrad,
generate_finitediff_coeffs,
value_and_fgrad,
)


def test_generate_finitediff_coeffs():
@@ -190,6 +196,16 @@ def test_fgrad_argnum():
all_correct(f1r, f2r)


def test_has_aux():
def func(x):
return x**2, "value"

v, a = fgrad(func, has_aux=True)(1.0)

assert v == 2.0
assert a == "value"


def test_fgrad_error():
with pytest.raises(ValueError):
fgrad(lambda x: x, argnums=-1)
Expand Down Expand Up @@ -235,3 +251,13 @@ def numpy_func(x, y):
jnp.array(4.0),
atol=1e-3,
)


def test_value_and_fgrad():
def func(x):
return x**2, "value"

assert value_and_fgrad(func, has_aux=True)(1.0) == ((1.0, "value"), 2.0)

with pytest.raises(TypeError):
value_and_fgrad(func, has_aux="")(1.0)
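The new tests cover the single-argnum path; as a complement, a sketch of the tuple-argnums branch with has_aux=True, which per the diff returns the per-argument derivatives as a tuple alongside (value, aux) (not part of this PR's test suite):

```python
import finitediffx as fdx

def g(x, y):
    return x**2 + y**3, "aux"

# differentiate with respect to both x and y
(value, aux), (dx, dy) = fdx.value_and_fgrad(g, argnums=(0, 1), has_aux=True)(2.0, 3.0)
# value -> 31.0, aux -> 'aux', dx -> approximately 4.0, dy -> approximately 27.0
```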