This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit d6c3578

szhengac and Zheng authored

Add LANS optimizer (#18620)

* add lans optimizer
* fix
* fix

Co-authored-by: Zheng <shzheng@a483e789dd93.ant.amazon.com>

1 parent 8ee4600 · commit d6c3578

File tree

10 files changed: +1302 -20 lines


python/mxnet/ndarray/contrib.py

+78
@@ -680,3 +680,81 @@ def multi_mp_lamb_update(weights, grads, mean, var, weights32, step_count,
                                                   learning_rates=lrs,
                                                   wds=wds,
                                                   **kwargs)


def multi_lans_update(weights, grads, mean, var, step_count,
                      lrs, wds, out=None, num_tensors=0, **kwargs):
    """Given a list of gradients, update weights, mean and variance of multiple tensors
    following LANS Optimizer implementation.

    Parameters
    ----------
    weights : List of NDArrays containing the input weights of multiple tensors

    grads : List of NDArrays containing input gradients

    mean : List of NDArrays containing mean of multiple tensors to be updated

    var : List of NDArrays containing variance of multiple tensors to be updated

    step_count : List of scalars with the number of update step for each tensor

    lrs : List of learning rates (one for each tensor)

    wds : List of weight decays (one for each tensor)

    out : List of NDArrays where the updated weights will be stored

    num_tensors : Number of NDArrays/tensors in the list
    """
    if not num_tensors:
        num_tensors = len(weights)
    temp_list = _flatten_list(zip(weights, grads, mean, var))
    return ndarray._internal._multi_lans_update(*temp_list,
                                                out=out,
                                                num_tensors=num_tensors,
                                                step_count=step_count,
                                                learning_rates=lrs,
                                                wds=wds,
                                                **kwargs)


def multi_mp_lans_update(weights, grads, mean, var, weights32, step_count,
                         lrs, wds, out=None, num_tensors=0, **kwargs):
    """Given a list of gradients, update weights, mean and variance of multiple tensors
    following LANS Optimizer implementation, and using Mixed-Precision.

    Parameters
    ----------
    weights : List of NDArrays containing the input weights of multiple tensors

    grads : List of NDArrays containing input gradients

    mean : List of NDArrays containing mean of multiple tensors to be updated

    var : List of NDArrays containing variance of multiple tensors to be updated

    weights32 : Master copy of weights in FP32

    step_count : List of scalars with the number of update step for each tensor

    lrs : List of learning rates (one for each tensor)

    wds : List of weight decays (one for each tensor)

    out : List of NDArrays where the updated weights will be stored

    num_tensors : Number of NDArrays/tensors in the list
    """
    if not num_tensors:
        num_tensors = len(weights)
    temp_list = _flatten_list(zip(weights, grads, mean, var, weights32))
    return ndarray._internal._multi_mp_lans_update(*temp_list,
                                                   out=out,
                                                   num_tensors=num_tensors,
                                                   step_count=step_count,
                                                   learning_rates=lrs,
                                                   wds=wds,
                                                   **kwargs)
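For orientation, a minimal sketch of calling the new helper directly follows. It is not part of the commit: the shapes, hyperparameter values, and the choice to call the helper outside the LANS optimizer are illustrative assumptions, and it only works on a build and context where this commit's fused multi_lans_update kernel is available. The LANS optimizer in python/mxnet/optimizer/lans.py below is the intended caller.

# Illustrative only: made-up shapes and hyperparameters; assumes the fused
# multi_lans_update operator from this commit is available on this context.
import mxnet as mx
from mxnet.ndarray.contrib import multi_lans_update

weights = [mx.nd.random.uniform(shape=(4, 4)) for _ in range(2)]
grads = [mx.nd.random.uniform(shape=(4, 4)) for _ in range(2)]
mean = [mx.nd.zeros((4, 4)) for _ in range(2)]
var = [mx.nd.zeros((4, 4)) for _ in range(2)]

# One update step for both tensors; extra keyword arguments are forwarded to
# the fused kernel (see the keys passed by fused_step in lans.py below).
multi_lans_update(weights, grads, mean, var,
                  step_count=[1, 1], lrs=[0.01, 0.01], wds=[0.0, 0.0],
                  out=weights,
                  beta1=0.9, beta2=0.999, epsilon=1e-6, rescale_grad=1.0)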

python/mxnet/optimizer/__init__.py

+4 -2

@@ -19,7 +19,7 @@
from . import (optimizer, contrib, updater, utils, sgd,
               sgld, signum, dcasgd, nag, adagrad,
               adadelta, adam, adamax, nadam, ftrl,
-              ftml, lars, lamb, rmsprop)
+              ftml, lars, lamb, rmsprop, lans)
# pylint: disable=wildcard-import
from .optimizer import *

@@ -57,7 +57,9 @@

from .rmsprop import *

+from .lans import *
+
__all__ = optimizer.__all__ + updater.__all__ + ['contrib'] + sgd.__all__ + sgld.__all__ \
          + signum.__all__ + dcasgd.__all__ + nag.__all__ + adagrad.__all__ + adadelta.__all__ \
          + adam.__all__ + adamax.__all__ + nadam.__all__ + ftrl.__all__ + ftml.__all__ \
-          + lars.__all__ + lamb.__all__ + rmsprop.__all__
+          + lars.__all__ + lamb.__all__ + rmsprop.__all__ + lans.__all__
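With this wiring in place, the optimizer is reachable both as a class export and through the optimizer registry entry created by @register (the lowercase class name). A small sketch, assuming a build that contains this commit:

import mxnet as mx

opt_a = mx.optimizer.LANS(learning_rate=0.001)            # exported via `from .lans import *`
opt_b = mx.optimizer.create('lans', learning_rate=0.001)  # looked up by its registered name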

python/mxnet/optimizer/lans.py

+220

@@ -0,0 +1,220 @@
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""LANS optimizer."""
from __future__ import absolute_import
import numpy
from ..ndarray import (zeros, clip, sqrt, where, square, ones_like,
                       maximum, minimum)
from ..ndarray.contrib import (multi_lans_update, multi_mp_lans_update)
from .optimizer import Optimizer, register

__all__ = ['LANS']


@register
class LANS(Optimizer):
    """LANS Optimizer.

    Referenced from 'Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes'
    (http://arxiv.org/abs/2006.13484)

    Parameters
    ----------
    learning_rate : float, default 0.001
        The initial learning rate. If None, the optimization will use the
        learning rate from ``lr_scheduler``. If not None, it will overwrite
        the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
        is also None, then it will be set to 0.01 by default.
    beta1 : float, default 0.9
        Exponential decay rate for the first moment estimates.
    beta2 : float, default 0.999
        Exponential decay rate for the second moment estimates.
    epsilon : float, default 1e-6
        Small value to avoid division by 0.
    lower_bound : float, default None
        Lower limit of the norm of weights.
    upper_bound : float, default None
        Upper limit of the norm of weights.
    aggregate_num : int, default 4
        Number of weights to be aggregated in a list.
        They are passed to the optimizer for a single optimization step.
        By default, all the weights are aggregated.
    use_fused_step : bool, default True
        Whether or not to use fused kernels for optimizer.
        When use_fused_step=False, step is called,
        otherwise, fused_step is called.
    """
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
                 lower_bound=None, upper_bound=None, aggregate_num=4, use_fused_step=True,
                 **kwargs):
        assert aggregate_num <= 45,\
            'When use_fused_step is True, LANS only supports aggregate_num <= 45,' \
            ' and receives {}'.format(aggregate_num)
        super(LANS, self).__init__(learning_rate=learning_rate,
                                   aggregate_num=aggregate_num,
                                   use_fused_step=use_fused_step,
                                   **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def create_state(self, index, weight):
        stype = weight.stype
        return (zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype),  # mean
                zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype))  # var

    def step(self, indices, weights, grads, states):
        """Perform an optimization step using gradients and states.

        Parameters
        ----------
        indices : list of int
            List of unique indices of the parameters into the individual learning rates
            and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
            and `set_wd_mult()`, respectively.
        weights : list of NDArray
            List of parameters to be updated.
        grads : list of NDArray
            List of gradients of the objective with respect to this parameter.
        states : List of any obj
            List of state returned by `create_state()`.
        """
        for index, weight, grad, state in zip(indices, weights, grads, states):
            self._update_count(index)
            lr = self._get_lr(index)
            wd = self._get_wd(index)
            t = self._index_update_count[index]

            # preprocess grad
            grad *= self.rescale_grad
            grad /= grad.norm()
            if self.clip_gradient is not None:
                grad = clip(grad, -self.clip_gradient, self.clip_gradient)

            # update mean, var
            mean, var = state
            mean[:] *= self.beta1
            mean[:] += (1. - self.beta1) * grad
            var[:] *= self.beta2
            var[:] += (1. - self.beta2) * square(grad)

            r1 = weight.norm()
            if self.lower_bound is not None:
                r1 = maximum(r1, self.lower_bound)
            if self.upper_bound is not None:
                r1 = minimum(r1, self.upper_bound)

            # apply bias correction
            coef1 = 1. - self.beta1 ** t
            coef2 = 1. - self.beta2 ** t
            mean_hat = mean / coef1
            var_hat = var / coef2
            sqrt(var_hat, out=var_hat)
            var_hat += self.epsilon
            mean_hat /= var_hat
            mean_hat += wd * weight

            g = mean_hat
            r2 = g.norm()

            # calculate lans_trust_ratio for first part
            ratio_m = r1 / r2
            # becomes NaN if ratio == NaN or 0, otherwise 0
            nan_or_zero = 1 - ratio_m / ratio_m
            r_m = where(nan_or_zero, ones_like(ratio_m), ratio_m)

            # update weight using first part of the estimator
            g *= lr * r_m * self.beta1
            weight[:] -= g

            # calculate the second part of the estimator
            mean_hat = grad / var_hat
            mean_hat += wd * weight

            g = mean_hat
            r2 = g.norm()

            # calculate lans_trust_ratio for second part
            ratio_g = r1 / r2
            # becomes NaN if ratio == NaN or 0, otherwise 0
            nan_or_zero = 1 - ratio_g / ratio_g
            r_g = where(nan_or_zero, ones_like(ratio_g), ratio_g)

            # update weight using second part of the estimator
            g *= lr * r_g * (1 - self.beta1)
            weight[:] -= g

    def fused_step(self, indices, weights, grads, states):
        """Perform a fused optimization step using gradients and states.
        Fused kernel is used for update.

        Parameters
        ----------
        indices : list of int
            List of unique indices of the parameters into the individual learning rates
            and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
            and `set_wd_mult()`, respectively.
        weights : list of NDArray
            List of parameters to be updated.
        grads : list of NDArray
            List of gradients of the objective with respect to this parameter.
        states : List of any obj
            List of state returned by `create_state()`.
        """
        self._update_count(indices)
        lrs = self._get_lrs(indices)
        wds = self._get_wds(indices)

        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
                  'rescale_grad': self.rescale_grad}
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient
        if self.lower_bound:
            kwargs['lower_bound'] = self.lower_bound
        if self.upper_bound:
            kwargs['upper_bound'] = self.upper_bound

        step_counts = []
        for index in indices:
            step_counts.append(self._index_update_count[index])

        multi_precision = self.multi_precision and weights[0].dtype == numpy.float16

        if not multi_precision:
            mean, var = list(zip(*states))
            multi_lans_update(weights, grads, mean, var,
                              out=weights, step_count=step_counts,
                              lrs=lrs, wds=wds, **kwargs)
        else:
            weights32, mean_var = list(zip(*states))
            mean, var = list(zip(*mean_var))
            multi_mp_lans_update(weights, grads,
                                 mean, var, weights32,
                                 out=weights, step_count=step_counts,
                                 lrs=lrs, wds=wds, **kwargs)

    def update_multi_precision(self, indices, weights, grads, states):
        """Override update_multi_precision.
        """
        if self.use_fused_step:
            self.update(indices, weights, grads, states)
        else:
            super(LANS, self).update_multi_precision(indices, weights, grads, states)
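Two short sketches may help when reading the file above; neither is part of the commit. First, a NumPy rendering of a single non-fused step() for one tensor. The values are made up, rescale_grad, gradient clipping, and the lower_bound/upper_bound clamping are omitted, and the where/ones_like trust-ratio guard is written as a plain conditional.

import numpy as np

def lans_step_numpy(weight, grad, mean, var, t, lr=0.001, wd=0.0,
                    beta1=0.9, beta2=0.999, eps=1e-6):
    # Mirrors LANS.step() above for a single tensor (no bounds, no clipping).
    grad = grad / np.linalg.norm(grad)                 # gradient is normalized first
    mean[:] = beta1 * mean + (1. - beta1) * grad       # first moment
    var[:] = beta2 * var + (1. - beta2) * grad ** 2    # second moment

    r1 = np.linalg.norm(weight)
    denom = np.sqrt(var / (1. - beta2 ** t)) + eps     # bias-corrected sqrt(var) + eps

    # first part of the estimator: bias-corrected momentum direction
    g = mean / (1. - beta1 ** t) / denom + wd * weight
    r2 = np.linalg.norm(g)
    ratio = r1 / r2 if r1 > 0 and r2 > 0 else 1.0      # trust ratio, guarded like where(...)
    weight -= lr * ratio * beta1 * g

    # second part: normalized-gradient direction, with weight decay on the updated weight
    g = grad / denom + wd * weight
    r2 = np.linalg.norm(g)
    ratio = r1 / r2 if r1 > 0 and r2 > 0 else 1.0
    weight -= lr * ratio * (1. - beta1) * g
    return weight

w = np.random.randn(8).astype(np.float32)
w = lans_step_numpy(w, np.random.randn(8).astype(np.float32),
                    np.zeros(8, np.float32), np.zeros(8, np.float32), t=1)

Second, a hedged end-to-end usage sketch with Gluon. The network, data, and hyperparameter values are placeholders, and use_fused_step=True requires the fused kernels added elsewhere in this commit.

import mxnet as mx
from mxnet import gluon, autograd

net = gluon.nn.Dense(10)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'lans',
                        {'learning_rate': 0.001, 'beta1': 0.9, 'beta2': 0.999,
                         'aggregate_num': 4, 'use_fused_step': True})

x = mx.nd.random.uniform(shape=(8, 16))
y = mx.nd.random.uniform(shape=(8, 10))
loss_fn = gluon.loss.L2Loss()
with autograd.record():
    loss = loss_fn(net(x), y)
loss.backward()
trainer.step(batch_size=8)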
