Add "online newton step" optimizer
eserie committed Feb 13, 2022
1 parent 28f9627 commit c8b69ee
Showing 5 changed files with 153 additions and 1 deletion.
2 changes: 2 additions & 0 deletions optax/__init__.py
@@ -25,6 +25,7 @@
from optax._src.alias import lamb
from optax._src.alias import lars
from optax._src.alias import noisy_sgd
from optax._src.alias import online_newton_step
from optax._src.alias import radam
from optax._src.alias import rmsprop
from optax._src.alias import sgd
@@ -232,6 +233,7 @@
"MultiTransformState",
"noisy_sgd",
"NonNegativeParamsState",
"online_newton_step",
"OptState",
"Params",
"pathwise_jacobians",
25 changes: 24 additions & 1 deletion optax/_src/alias.py
@@ -25,7 +25,7 @@
from optax._src import privacy
from optax._src import transform
from optax._src import wrappers

from optax._src.transform import DEFAULT_ONLINE_NEWTON_STEP_EPS

ScalarOrSchedule = Union[float, base.Schedule]
MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]
@@ -668,3 +668,26 @@ def dpsgd(
if momentum is not None else base.identity()),
_scale_by_learning_rate(learning_rate)
)


def online_newton_step(
    learning_rate: ScalarOrSchedule,
    eps: float = DEFAULT_ONLINE_NEWTON_STEP_EPS) -> base.GradientTransformation:
  # pylint: disable=line-too-long
  """The Online Newton Step (ONS) optimizer.

  References:
    Hazan, E., Agarwal, A. and Kale, S., 2007. Logarithmic regret algorithms for online convex optimization. Machine Learning, 69(2-3), pp.169-192 : https://link.springer.com/content/pdf/10.1007/s10994-007-5016-8.pdf

  Args:
    learning_rate: a fixed global scaling factor.
    eps: initial value of the Hessian approximation accumulator; the
      accumulator starts at `eps` times the identity matrix, so its inverse
      starts at `1 / eps` times the identity.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_online_newton_step(eps=eps),
      _scale_by_learning_rate(learning_rate),
  )
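For context, a minimal usage sketch of the new alias, mirroring the test added in alias_test.py below. It assumes the optimizer is exposed as `optax.online_newton_step` after the `__init__.py` change above and that updates are applied with `optax.apply_updates`; the quadratic loss and data are illustrative only.

import jax
import jax.numpy as jnp
import optax

def loss(w, x):
  # A simple convex loss, used only for illustration.
  return -(w * x).sum() + (w ** 2).sum()

w = jnp.ones(3)
x = jnp.ones(3)

opt = optax.online_newton_step(learning_rate=1e-3)
state = opt.init(w)

value, grads = jax.value_and_grad(loss)(w, x)
updates, state = opt.update(grads, state)
w = optax.apply_updates(w, updates)  # loss(w, x) should now be below `value`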
19 changes: 19 additions & 0 deletions optax/_src/alias_test.py
@@ -137,5 +137,24 @@ def test_explicit_dtype(self, dtype):
self.assertEqual(expected_dtype, adam_state.mu.dtype)


  def test_newton(self):

    def loss(w, x):
      return -(w * x).sum() + (w ** 2).sum()

    w = jnp.ones(3)
    x = jnp.ones(3)

    opt = alias.online_newton_step(1.0e-3)
    initial_loss, grads = jax.value_and_grad(loss)(w, x)

    state = opt.init(w)
    updates, state = opt.update(grads, state)
    w = update.apply_updates(w, updates)

    # A single Online Newton Step should decrease this convex loss.
    self.assertLess(loss(w, x), initial_loss)



if __name__ == '__main__':
absltest.main()
60 changes: 60 additions & 0 deletions optax/_src/transform.py
@@ -27,6 +27,7 @@
from optax._src import utils
from optax._src import wrappers

DEFAULT_ONLINE_NEWTON_STEP_EPS = 1.
# pylint:disable=no-value-for-parameter

_abs_sq = numerics.abs_sq
@@ -906,6 +907,65 @@ def update_fn(updates, state, params=None):
return base.GradientTransformation(init_fn, update_fn)


class ScaleByOnlineNewtonStepState(NamedTuple):
"""State holding the inverse of the sum of gradient outer products to date."""

hessian_inv: base.Updates


def sherman_morrison(a_inv, u, v):
  """Sherman-Morrison formula.

  Computes the inverse of the sum of an invertible matrix and the outer
  product of two vectors, given the inverse of that matrix. The formula is
  used in the "Online Newton Step" gradient update method.
  """
  den = 1.0 + (v.T @ a_inv @ u)
  a_inv -= a_inv @ jnp.outer(u, v) @ a_inv / den
  return a_inv


def scale_by_online_newton_step(
    eps: float = DEFAULT_ONLINE_NEWTON_STEP_EPS) -> base.GradientTransformation:
  # pylint: disable=line-too-long
  """Rescale updates by the inverse of an approximate Hessian (Online Newton Step).

  The approximation is the regularised sum of gradient outer products, updated
  incrementally with the Sherman-Morrison formula (see the description of ONS
  in Fig. 2, p. 176 of the article below).

  References:
    Hazan, E., Agarwal, A. and Kale, S., 2007. Logarithmic regret algorithms for online convex optimization. Machine Learning, 69(2-3), pp.169-192 : https://link.springer.com/content/pdf/10.1007/s10994-007-5016-8.pdf

  Args:
    eps: initialisation of the Hessian approximation; the accumulator starts
      at `eps` times the identity matrix (its inverse at `1 / eps` times the
      identity).

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    # Initialise the accumulator inverse to (1 / eps) times the identity,
    # one matrix per flattened parameter leaf.
    hessian_inv = jax.tree_map(
        lambda t: jnp.eye(len(t.flatten()), dtype=t.dtype) / eps, params
    )
    return ScaleByOnlineNewtonStepState(hessian_inv=hessian_inv)

  def update_fn(updates, state, params=None):
    del params

    class Tuple(tuple):
      """Tuple subclass treated as a pytree leaf, so that shapes survive
      tree_map and can be used in the final reshape."""

    # Flatten each leaf to a vector, remembering its original shape.
    shapes = jax.tree_map(lambda x: Tuple(x.shape), updates)
    updates = jax.tree_map(lambda x: x.flatten(), updates)
    # Rank-one update of the inverse Hessian approximation with the gradient.
    hessian_inv = jax.tree_multimap(
        lambda hinv, u: sherman_morrison(hinv, u, u), state.hessian_inv, updates
    )
    # Precondition the gradient and restore the original shapes.
    updates = jax.tree_multimap(lambda hinv, g: hinv @ g, hessian_inv, updates)
    updates = jax.tree_multimap(lambda u, shape: u.reshape(shape), updates,
                                shapes)

    return updates, ScaleByOnlineNewtonStepState(hessian_inv=hessian_inv)

  return base.GradientTransformation(init_fn, update_fn)


# TODO(b/183800387): remove legacy aliases.
# These legacy aliases are here for checkpoint compatibility
# To be removed once checkpoints have updated.
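For reference, `sherman_morrison` above implements the standard Sherman-Morrison identity, and `scale_by_online_newton_step` applies it with u = v = g_t to maintain the inverse of the ONS accumulator (`hessian_inv` starts at (1/eps) * I); in LaTeX:

$$(A + u v^\top)^{-1} = A^{-1} - \frac{A^{-1} u\, v^\top A^{-1}}{1 + v^\top A^{-1} u}$$

$$A_0 = \varepsilon I, \qquad A_t = A_{t-1} + g_t g_t^\top, \qquad \text{update}_t = A_t^{-1} g_t$$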
48 changes: 48 additions & 0 deletions optax/_src/transform_test.py
@@ -28,6 +28,7 @@
from optax._src import combine
from optax._src import transform
from optax._src import update
from optax._src.transform import scale_by_online_newton_step

STEPS = 50
LR = 1e-2
@@ -234,6 +235,53 @@ def test_add_noise_has_correct_variance_scaling(self):

chex.assert_tree_all_close(updates_i, updates_i_rescaled, rtol=1e-4)

def test_sherman_morrison(self):
rng = jax.random.PRNGKey(42)
rng, sub_rng = jax.random.split(rng)
a = jax.random.normal(sub_rng, (3, 3))

rng, sub_rng = jax.random.split(rng)
u = jax.random.normal(sub_rng, (3,))

rng, sub_rng = jax.random.split(rng)
v = jax.random.normal(sub_rng, (3,))

ref = jnp.linalg.inv(a + jnp.outer(u, v))
a_inv = jnp.linalg.inv(a)
a_inv = transform.sherman_morrison(a_inv, u, v)
assert jnp.allclose(a_inv, ref)

def test_scale_by_online_newton_step(self):
eps = 5.
step = scale_by_online_newton_step(eps)
rng = jax.random.PRNGKey(42)

rng, sub_rng = jax.random.split(rng)
w = jax.random.normal(sub_rng, (3,))

params = {"weights" : w}

def fun(params, x):
return 0.5 * (params["weights"] @ x)**2

rng, sub_rng = jax.random.split(rng)
x = jax.random.normal(sub_rng, (3,))

grad = jax.grad(fun)(params, x)

state = step.init(params)
updates, state = step.update(grad, state)

u = (w @ x) * x

    # Check that step.update applied the Sherman-Morrison update.
ref = transform.sherman_morrison(jnp.eye(3) / eps, u, u) @ u
assert jnp.allclose(ref, updates["weights"])

# check Sherman-Morrison formula
ref = jnp.linalg.inv(jnp.eye(3) * eps + jnp.outer(u, u)) @ u
assert jnp.allclose(ref, updates["weights"])


if __name__ == '__main__':
absltest.main()
