Add "online newton step" optimizer
eserie committed Aug 25, 2022
1 parent 0fa805b commit fa028f5
Showing 5 changed files with 173 additions and 0 deletions.
2 changes: 2 additions & 0 deletions optax/__init__.py
@@ -26,6 +26,7 @@
from optax._src.alias import lars
from optax._src.alias import MaskOrFn
from optax._src.alias import noisy_sgd
from optax._src.alias import online_newton_step
from optax._src.alias import optimistic_gradient_descent
from optax._src.alias import radam
from optax._src.alias import rmsprop
@@ -251,6 +252,7 @@
"MultiTransformState",
"noisy_sgd",
"NonNegativeParamsState",
"online_newton_step",
"OptState",
"Params",
"pathwise_jacobians",
23 changes: 23 additions & 0 deletions optax/_src/alias.py
@@ -431,6 +431,29 @@ def noisy_sgd(
)


def online_newton_step(learning_rate: ScalarOrSchedule,
                       eps: float) -> base.GradientTransformation:
  """Online Newton Step optimizer.

  See the description of the ONS algorithm in Fig. 2, p. 176 of the reference
  below.

  References:
    Hazan, E., Agarwal, A. and Kale, S., 2007. Logarithmic regret algorithms
    for online convex optimization. Machine Learning, 69(2-3), pp. 169-192:
    https://link.springer.com/content/pdf/10.1007/s10994-007-5016-8.pdf

  Args:
    learning_rate: A fixed global scaling factor or a schedule.
    eps: Scale of the identity matrix used to initialise the Hessian
      accumulator, `A_0 = eps * I`; keeps the accumulator invertible.

  Returns:
    The corresponding `GradientTransformation`.
  """
return combine.chain(
transform.scale_by_online_newton_step(eps=eps),
_scale_by_learning_rate(learning_rate),
)


def optimistic_gradient_descent(
learning_rate: ScalarOrSchedule,
alpha: ScalarOrSchedule = 1.0,
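For readers of the diff, the chain above amounts to the following update rule. This is a sketch in my own notation, derived from the code in this commit (the generalized projection step from the paper is not part of this implementation):

\[
\begin{aligned}
A_0 &= \epsilon I, \qquad A_t = A_{t-1} + g_t g_t^{\top}, \\
\theta_{t+1} &= \theta_t - \eta\, A_t^{-1} g_t,
\end{aligned}
\]

where, for each parameter leaf, g_t is the flattened gradient, \eta the learning rate (or schedule), and A_t^{-1} is maintained directly via the Sherman-Morrison update in transform.py rather than recomputed by matrix inversion.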
19 changes: 19 additions & 0 deletions optax/_src/alias_test.py
@@ -142,6 +142,25 @@ def test_explicit_dtype(self, dtype):
adam_state, _, _ = tx.init(jnp.array([0.0, 0.0]))
self.assertEqual(expected_dtype, adam_state.mu.dtype)

def test_online_newton_step_decreases_loss(self):
"""We check that the online newton step update decrease a loss function.
"""

def loss(w):
return -(w * x).sum() + (w ** 2).sum()

w = jnp.array([-3.0, 0.0, 2.0])
x = jnp.array([2.0, -3.0, 5.0])

opt = alias.online_newton_step(learning_rate=1.0e-3, eps=1.)
grads = jax.grad(loss)(w)

state = opt.init(w)
    updates, state = opt.update(grads, state)
    new_w = update.apply_updates(w, updates)

self.assertLess(loss(new_w), loss(w))


if __name__ == '__main__':
absltest.main()
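As a usage note (my addition, not part of the diff): the single-step test above extends naturally to a small training loop. A minimal sketch, assuming this commit is installed; `optax.apply_updates` is the existing public helper, and the learning-rate and eps values are illustrative only.

import jax
import jax.numpy as jnp
import optax

def loss(w):
  x = jnp.array([2.0, -3.0, 5.0])
  return -(w * x).sum() + (w ** 2).sum()

w = jnp.array([-3.0, 0.0, 2.0])
opt = optax.online_newton_step(learning_rate=1e-3, eps=1.0)
state = opt.init(w)

for _ in range(10):
  grads = jax.grad(loss)(w)
  updates, state = opt.update(grads, state)
  w = optax.apply_updates(w, updates)
  # loss(w) should decrease at each step on this convex quadratic.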
51 changes: 51 additions & 0 deletions optax/_src/transform.py
@@ -953,6 +953,57 @@ def update_fn(updates, state, params=None):
return base.GradientTransformation(init_fn, update_fn)


class ScaleByOnlineNewtonStepState(NamedTuple):
"""State holding the inverse of the sum of gradient outer products to date."""
hessian_inv: base.Updates


def sherman_morrison(a_inv, u):
  """Sherman-Morrison update of a matrix inverse after a rank-one update.

  Given `a_inv` = A^-1 for an invertible matrix A and a vector u, returns
  (A + u u^T)^-1. Used by the "Online Newton Step" transformation to maintain
  the inverse of the accumulated Hessian approximation without re-inverting it.
  """
  den = 1.0 + u.T @ a_inv @ u
  return a_inv - a_inv @ jnp.outer(u, u) @ a_inv / den


def scale_by_online_newton_step(eps: float) -> base.GradientTransformation:
  """Rescales updates by the inverse of an accumulated Hessian approximation.

  See the description of the ONS algorithm in Fig. 2, p. 176 of the reference
  below.

  References:
    Hazan, E., Agarwal, A. and Kale, S., 2007. Logarithmic regret algorithms
    for online convex optimization. Machine Learning, 69(2-3), pp. 169-192:
    https://link.springer.com/content/pdf/10.1007/s10994-007-5016-8.pdf

  Args:
    eps: Scale of the identity matrix used to initialise the Hessian
      accumulator, `A_0 = eps * I`; keeps the accumulator invertible.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    # One (n, n) accumulator inverse per leaf, initialised to (eps * I)^-1.
    hessian_inv = jax.tree_map(
        lambda t: jnp.eye(t.size, dtype=t.dtype) / eps, params
    )
    return ScaleByOnlineNewtonStepState(hessian_inv=hessian_inv)

  def update_fn(updates, state, params=None):
    del params
    # Work on flattened gradients so each leaf's accumulator is a 2-D matrix.
    flattened_updates = jax.tree_map(lambda x: x.flatten(), updates)
    # Rank-one update of each leaf's inverse Hessian approximation.
    hessian_inv = jax.tree_map(
        sherman_morrison, state.hessian_inv, flattened_updates)
    # Precondition the gradients and restore the original leaf shapes.
    flattened_updates = jax.tree_map(
        lambda hinv, g: hinv @ g, hessian_inv, flattened_updates)
    updates = jax.tree_map(lambda flat_u, u: flat_u.reshape(u.shape),
                           flattened_updates, updates)
    return updates, ScaleByOnlineNewtonStepState(hessian_inv=hessian_inv)

return base.GradientTransformation(init_fn, update_fn)


def scale_by_optimistic_gradient(
alpha: float = 1.0,
beta: float = 1.0) -> base.GradientTransformation:
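A brief aside (my addition, not part of the commit): `sherman_morrison` above applies the standard rank-one inverse-update identity

\[
(A + u u^{\top})^{-1} = A^{-1} - \frac{A^{-1} u u^{\top} A^{-1}}{1 + u^{\top} A^{-1} u},
\]

so maintaining each leaf's `hessian_inv` costs O(n^2) per step in the flattened parameter size n, instead of the O(n^3) of re-inverting the accumulator at every update.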
78 changes: 78 additions & 0 deletions optax/_src/transform_test.py
@@ -28,6 +28,7 @@
from optax._src import combine
from optax._src import transform
from optax._src import update
from optax._src.transform import scale_by_online_newton_step

STEPS = 50
LR = 1e-2
@@ -269,6 +270,83 @@ def test_add_noise_has_correct_variance_scaling(self):

chex.assert_tree_all_close(updates_i, updates_i_rescaled, rtol=1e-4)

def test_sherman_morrison(self):
rng = jax.random.PRNGKey(42)
rng, sub_rng = jax.random.split(rng)
a = jax.random.normal(sub_rng, (3, 3))

rng, sub_rng = jax.random.split(rng)
u = jax.random.normal(sub_rng, (3,))

ref = jnp.linalg.inv(a + jnp.outer(u, u))
a_inv = jnp.linalg.inv(a)
a_inv = transform.sherman_morrison(a_inv, u)
np.testing.assert_allclose(a_inv, ref, rtol=1e-5)

def test_scale_by_online_newton_step(self):
eps = 5.
step = scale_by_online_newton_step(eps)
rng = jax.random.PRNGKey(42)

rng, sub_rng = jax.random.split(rng)
w = jax.random.normal(sub_rng, (3,))

params = {'weights': w}

def fun(params, x):
return 0.5 * (params['weights'] @ x)**2

rng, sub_rng = jax.random.split(rng)
x = jax.random.normal(sub_rng, (3,))

grad = jax.grad(fun)(params, x)

state = step.init(params)
updates, state = step.update(grad, state)

    u = (w @ x) * x  # Gradient of `fun` with respect to the weights.

    # Check that step.update applied the Sherman-Morrison update.
    ref = transform.sherman_morrison(jnp.eye(3) / eps, u) @ u
np.testing.assert_allclose(ref, updates['weights'], rtol=1e-6)

    # Cross-check against an explicitly inverted accumulator.
ref = jnp.linalg.inv(jnp.eye(3) * eps + jnp.outer(u, u)) @ u
np.testing.assert_allclose(ref, updates['weights'], rtol=1e-6)

  def test_scale_by_online_newton_step_with_multidimensional_weights(self):
eps = 5.
step = scale_by_online_newton_step(eps)
rng = jax.random.PRNGKey(42)

rng, sub_rng = jax.random.split(rng)
w = jax.random.normal(sub_rng, (3, 2))

params = {'weights': w}

def fun(params, x):
return 0.5 * jnp.linalg.norm(params['weights'] * x)**2

rng, sub_rng = jax.random.split(rng)
x = jax.random.normal(sub_rng, (3, 2))

grad = jax.grad(fun)(params, x)

state = step.init(params)
updates, state = step.update(grad, state)

    u = grad['weights'].flatten()  # = ((w * x) * x).flatten()

    # Check that step.update applied the Sherman-Morrison update.
    ref = transform.sherman_morrison(jnp.eye(3 * 2) / eps, u) @ u
    ref = ref.reshape((3, 2))
np.testing.assert_allclose(ref, updates['weights'], rtol=1e-6)

    # Cross-check against an explicitly inverted accumulator.
ref = jnp.linalg.inv(jnp.eye(3*2) * eps + jnp.outer(u, u)) @ u
ref = ref.reshape((3, 2))
np.testing.assert_allclose(ref, updates['weights'], rtol=1e-6)

def test_scale_by_optimistic_gradient(self):

def f(params: jnp.ndarray) -> jnp.ndarray:
