| 1 | +# coding: utf-8 |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 3 | +# or more contributor license agreements. See the NOTICE file |
| 4 | +# distributed with this work for additional information |
| 5 | +# regarding copyright ownership. The ASF licenses this file |
| 6 | +# to you under the Apache License, Version 2.0 (the |
| 7 | +# "License"); you may not use this file except in compliance |
| 8 | +# with the License. You may obtain a copy of the License at |
| 9 | +# |
| 10 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +# |
| 12 | +# Unless required by applicable law or agreed to in writing, |
| 13 | +# software distributed under the License is distributed on an |
| 14 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | +# KIND, either express or implied. See the License for the |
| 16 | +# specific language governing permissions and limitations |
| 17 | +# under the License. |
| 18 | +"""LANS optimizer.""" |
| 19 | +from __future__ import absolute_import |
| 20 | +import numpy |
| 21 | +from ..ndarray import (zeros, clip, sqrt, where, square, ones_like, |
| 22 | + maximum, minimum) |
| 23 | +from ..ndarray.contrib import (multi_lans_update, multi_mp_lans_update) |
| 24 | +from .optimizer import Optimizer, register |
| 25 | + |
| 26 | +__all__ = ['LANS'] |
| 27 | + |
| 28 | + |
| 29 | +@register |
| 30 | +class LANS(Optimizer): |
| 31 | + """LANS Optimizer. |
| 32 | + |
| 33 | + Referenced from 'Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes' |
| 34 | + (http://arxiv.org/abs/2006.13484) |
| 35 | + |
| 36 | + Parameters |
| 37 | + ---------- |
| 38 | + learning_rate : float, default 0.001 |
| 39 | + The initial learning rate. If None, the optimization will use the |
| 40 | + learning rate from ``lr_scheduler``. If not None, it will overwrite |
| 41 | + the learning rate in ``lr_scheduler``. If None and ``lr_scheduler`` |
| 42 | + is also None, then it will be set to 0.01 by default. |
| 43 | + beta1 : float, default 0.9 |
| 44 | + Exponential decay rate for the first moment estimates. |
| 45 | + beta2 : float, default 0.999 |
| 46 | + Exponential decay rate for the second moment estimates. |
| 47 | + epsilon : float, default 1e-6 |
| 48 | + Small value to avoid division by 0. |
| 49 | + lower_bound : float, default None |
| 50 | + Lower limit of the norm of the weight. |
| 51 | + upper_bound : float, default None |
| 52 | + Upper limit of the norm of the weight. |
| 53 | + aggregate_num : int, default 4 |
| 54 | + Number of weights to be aggregated in a list. |
| 55 | + They are passed to the optimizer for a single optimization step. |
| 56 | + By default, all the weights are aggregated. |
| 57 | + use_fused_step : bool, default True |
| 58 | + Whether or not to use fused kernels for the optimizer. |
| 59 | + When use_fused_step=False, step is called; |
| 60 | + otherwise, fused_step is called. |
| 61 | + """ |
| 62 | + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6, |
| 63 | + lower_bound=None, upper_bound=None, aggregate_num=4, use_fused_step=True, |
| 64 | + **kwargs): |
| 65 | + assert aggregate_num <= 45,\ |
| 66 | + 'When use_fused_step is True, LANS only supports aggregate_num <= 45,' \ |
| 67 | + ' but received {}'.format(aggregate_num) |
| 68 | + super(LANS, self).__init__(learning_rate=learning_rate, |
| 69 | + aggregate_num=aggregate_num, |
| 70 | + use_fused_step=use_fused_step, |
| 71 | + **kwargs) |
| 72 | + self.beta1 = beta1 |
| 73 | + self.beta2 = beta2 |
| 74 | + self.epsilon = epsilon |
| 75 | + self.lower_bound = lower_bound |
| 76 | + self.upper_bound = upper_bound |
| 77 | + |
| 78 | + def create_state(self, index, weight): |
| 79 | + stype = weight.stype |
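|  | + # The state is a (mean, var) pair kept in float32. With |
|  | + # multi_precision and float16 weights, the base class additionally |
|  | + # pairs this state with a float32 copy of the weight, which |
|  | + # fused_step unpacks separately. |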
| 80 | + return (zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype), # mean |
| 81 | + zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype)) # var |
| 82 | + |
| 83 | + def step(self, indices, weights, grads, states): |
| 84 | + """Perform a fused optimization step using gradients and states. |
| 85 | + Fused kernel is used for update. |
| 86 | + |
| 87 | + Parameters |
| 88 | + ---------- |
| 89 | + indices : list of int |
| 90 | + List of unique indices of the parameters into the individual learning rates |
| 91 | + and weight decays. Learning rates and weight decay may be set via `set_lr_mult()` |
| 92 | + and `set_wd_mult()`, respectively. |
| 93 | + weights : list of NDArray |
| 94 | + List of parameters to be updated. |
| 95 | + grads : list of NDArray |
| 96 | + List of gradients of the objective with respect to the parameters. |
| 97 | + states : list of any obj |
| 98 | + List of states returned by `create_state()`. |
| 99 | + """ |
| 100 | + for index, weight, grad, state in zip(indices, weights, grads, states): |
| 101 | + self._update_count(index) |
| 102 | + lr = self._get_lr(index) |
| 103 | + wd = self._get_wd(index) |
| 104 | + t = self._index_update_count[index] |
| 105 | + |
| 106 | + # preprocess grad |
| 107 | + grad *= self.rescale_grad |
| 108 | + grad /= grad.norm() |
| 109 | + if self.clip_gradient is not None: |
| 110 | + grad = clip(grad, -self.clip_gradient, self.clip_gradient) |
| 111 | + |
| 112 | + # update mean, var |
| 113 | + mean, var = state |
| 114 | + mean[:] *= self.beta1 |
| 115 | + mean[:] += (1. - self.beta1) * grad |
| 116 | + var[:] *= self.beta2 |
| 117 | + var[:] += (1. - self.beta2) * square(grad) |
| 118 | + |
| 119 | + r1 = weight.norm() |
| 120 | + if self.lower_bound is not None: |
| 121 | + r1 = maximum(r1, self.lower_bound) |
| 122 | + if self.upper_bound is not None: |
| 123 | + r1 = minimum(r1, self.upper_bound) |
| 124 | + |
| 125 | + # apply bias correction |
| 126 | + coef1 = 1. - self.beta1 ** t |
| 127 | + coef2 = 1. - self.beta2 ** t |
| 128 | + mean_hat = mean / coef1 |
| 129 | + var_hat = var / coef2 |
| 130 | + sqrt(var_hat, out=var_hat) |
| 131 | + var_hat += self.epsilon |
| 132 | + mean_hat /= var_hat |
| 133 | + mean_hat += wd * weight |
| 134 | + |
| 135 | + g = mean_hat |
| 136 | + r2 = g.norm() |
| 137 | + |
| 138 | + # calculate lans_trust_ratio for first part |
| 139 | + ratio_m = r1 / r2 |
| 140 | + # becomes NaN if ratio is NaN or 0, otherwise 0 |
| 141 | + nan_or_zero = 1 - ratio_m / ratio_m |
| 142 | + r_m = where(nan_or_zero, ones_like(ratio_m), ratio_m) |
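|  | + # where() evaluates a NaN condition as non-zero (true), so a |
|  | + # degenerate ratio falls back to a trust ratio of 1. |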
| 143 | + |
| 144 | + # update weight using first part of the estimator |
| 145 | + g *= lr * r_m * self.beta1 |
| 146 | + weight[:] -= g |
| 147 | + |
| 148 | + # calculate the second part of the estimator |
| 149 | + mean_hat = grad / var_hat |
| 150 | + mean_hat += wd * weight |
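|  | + # Note: var_hat already holds sqrt(v_hat) + eps from the first |
|  | + # part, and the wd term uses the weight updated just above. |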
| 151 | + |
| 152 | + g = mean_hat |
| 153 | + r2 = g.norm() |
| 154 | + |
| 155 | + # calculate lans_trust_ratio for second part |
| 156 | + ratio_g = r1 / r2 |
| 157 | + # becomes NaN if ratio == NaN or 0, otherwise 0 |
| 158 | + nan_or_zero = 1 - ratio_g / ratio_g |
| 159 | + r_g = where(nan_or_zero, ones_like(ratio_g), ratio_g) |
| 160 | + |
| 161 | + # update weight using second part of the estimator |
| 162 | + g *= lr * r_g * (1 - self.beta1) |
| 163 | + weight[:] -= g |
| 164 | + |
| 165 | + def fused_step(self, indices, weights, grads, states): |
| 166 | + """Perform a fused optimization step using gradients and states. |
| 167 | + Fused kernel is used for update. |
| 168 | + |
| 169 | + Parameters |
| 170 | + ---------- |
| 171 | + indices : list of int |
| 172 | + List of unique indices of the parameters into the individual learning rates |
| 173 | + and weight decays. Learning rates and weight decay may be set via `set_lr_mult()` |
| 174 | + and `set_wd_mult()`, respectively. |
| 175 | + weights : list of NDArray |
| 176 | + List of parameters to be updated. |
| 177 | + grads : list of NDArray |
| 178 | + List of gradients of the objective with respect to the parameters. |
| 179 | + states : list of any obj |
| 180 | + List of states returned by `create_state()`. |
| 181 | + """ |
| 182 | + self._update_count(indices) |
| 183 | + lrs = self._get_lrs(indices) |
| 184 | + wds = self._get_wds(indices) |
| 185 | + |
| 186 | + kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, |
| 187 | + 'rescale_grad': self.rescale_grad} |
| 188 | + if self.clip_gradient is not None: |
| 189 | + kwargs['clip_gradient'] = self.clip_gradient |
| 190 | + if self.lower_bound is not None: |
| 191 | + kwargs['lower_bound'] = self.lower_bound |
| 192 | + if self.upper_bound is not None: |
| 193 | + kwargs['upper_bound'] = self.upper_bound |
| 194 | + |
| 195 | + step_counts = [] |
| 196 | + for index in indices: |
| 197 | + step_counts.append(self._index_update_count[index]) |
| 198 | + |
| 199 | + multi_precision = self.multi_precision and weights[0].dtype == numpy.float16 |
| 200 | + |
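|  | + # Each state is (mean, var) without multi-precision; with it, each |
|  | + # state is (float32 weight copy, (mean, var)), hence the different |
|  | + # unpacking in the two branches below. |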
| 201 | + if not multi_precision: |
| 202 | + mean, var = list(zip(*states)) |
| 203 | + multi_lans_update(weights, grads, mean, var, |
| 204 | + out=weights, step_count=step_counts, |
| 205 | + lrs=lrs, wds=wds, **kwargs) |
| 206 | + else: |
| 207 | + weights32, mean_var = list(zip(*states)) |
| 208 | + mean, var = list(zip(*mean_var)) |
| 209 | + multi_mp_lans_update(weights, grads, |
| 210 | + mean, var, weights32, |
| 211 | + out=weights, step_count=step_counts, |
| 212 | + lrs=lrs, wds=wds, **kwargs) |
| 213 | + |
| 214 | + def update_multi_precision(self, indices, weights, grads, states): |
| 215 | + """Override update_multi_precision. |
| 216 | + """ |
| 217 | + if self.use_fused_step: |
| 218 | + self.update(indices, weights, grads, states) |
| 219 | + else: |
| 220 | + super(LANS, self).update_multi_precision(indices, weights, grads, states) |
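
A minimal usage sketch, assuming the standard Gluon `Trainer` API: `@register` exposes the class under its lowercase name, so `'lans'` can be passed as the optimizer string. The toy network and hyperparameter values below are illustrative assumptions, not part of the file above.

```python
import mxnet as mx
from mxnet import gluon, autograd

# Tiny illustrative network; any Gluon block is handled the same way.
net = gluon.nn.Dense(2)
net.initialize()

# '@register' exposes the optimizer under its lowercase class name,
# so the string 'lans' resolves to the LANS class defined above.
trainer = gluon.Trainer(net.collect_params(), 'lans',
                        {'learning_rate': 0.001, 'beta1': 0.9,
                         'beta2': 0.999, 'epsilon': 1e-6})

x = mx.nd.ones((4, 8))
with autograd.record():
    loss = net(x).sum()
loss.backward()
trainer.step(batch_size=4)  # one LANS update over all parameters
```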