
Commit 73826fc

sifmelcara authored and piiswrong committed
Fix RMSProp update rule (apache#6235)
* Fix the RMSProp update rule to follow the formula presented in Alex Graves's paper; this prevents taking the square root of a negative value (caused by floating-point rounding error).
* Fix the formula of the non-centered version of RMSProp.
* Fix the RMSProp update rule in the Python test.
* Fix the RMSProp update rule in the Perl test.
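For reference, the corrected update rules implemented by this commit can be written as follows (n: running mean of the squared gradient, g: running mean of the gradient, Δ: accumulated update, w: weight, η: learning rate, γ₁/γ₂: decay rates):

Non-centered RMSProp:

$$n \leftarrow (1-\gamma_1)\,\mathrm{grad}^2 + \gamma_1 n, \qquad w \leftarrow w - \eta \cdot \frac{\mathrm{grad}}{\sqrt{n + \epsilon}}$$

Centered RMSProp (RMSPropAlex), with n updated the same way:

$$g \leftarrow (1-\gamma_1)\,\mathrm{grad} + \gamma_1 g, \qquad \Delta \leftarrow \gamma_2 \Delta - \eta \cdot \frac{\mathrm{grad}}{\sqrt{n - g^2 + \epsilon}}, \qquad w \leftarrow w + \Delta$$

Moving ε inside the square root keeps the radicand strictly positive even when n − g² rounds to a tiny negative number in float32.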
1 parent: b9d491e

File tree: 3 files changed, +24 −22 lines

perl-package/AI-MXNet/t/test_optimizers.t

+2 −2

@@ -166,7 +166,7 @@ method update($index, $weight, $grad, $state)
             $grad = mx->nd->clip($grad, -$self->clip_gradient, $self->clip_gradient);
         }
         $n .= (1 - $self->gamma1) * ($grad * $grad) + $self->gamma1 * $n;
-        $weight -= $lr * $grad/(mx->nd->sqrt($n) + $self->epsilon);
+        $weight -= $lr * $grad/(mx->nd->sqrt($n + $self->epsilon));
     }
     else
     {
@@ -177,7 +177,7 @@ method update($index, $weight, $grad, $state)
         }
         $n .= (1 - $self->gamma1) * ($grad * $grad) + $self->gamma1 * $n;
         $g .= (1 - $self->gamma1) * $grad + $self->gamma1 * $g;
-        $delta .= ($self->gamma2) * $delta - $lr * $grad/(mx->nd->sqrt($n - $g*$g) + $self->epsilon);
+        $delta .= ($self->gamma2) * $delta - $lr * $grad/(mx->nd->sqrt($n - $g*$g + $self->epsilon));
         $weight += $delta;
     }
     if($self->clip_weights)
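The negative-radicand failure the centered branch guards against is easy to reproduce outside MXNet. A minimal NumPy sketch with hypothetical values (not taken from the test suite): mathematically n − g² ≥ 0 always holds for these running means, but float32 rounding can push the difference just below zero, and the old formula then took the square root of a negative number:

```python
import numpy as np

# Running means of the gradient and of the squared gradient; with a
# constant gradient of 0.1 both converge so that, mathematically, n == g*g.
g = np.float32(0.1)
n = np.float32(0.01)

print(n - g * g)                 # ~ -9.3e-10: negative purely from float32 rounding
print(np.sqrt(n - g * g))        # nan -- the bug this commit fixes

eps = np.float32(1e-8)
print(np.sqrt(n - g * g + eps))  # ~ 9.5e-05: epsilon inside the sqrt absorbs the error
```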

src/operator/optimizer_op-inl.h

+20 −18

@@ -300,17 +300,17 @@ inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs,
       delta = scalar<DType>(param.gamma2) * delta -
               scalar<DType>(param.lr) *
                   (F<clip>(grad, DType(param.clip_gradient)) /
-                   (F<square_root>(state_n - state_g * state_g) +
-                    scalar<DType>(param.epsilon)));
+                   (F<square_root>(state_n - state_g * state_g +
+                                   scalar<DType>(param.epsilon))));
     } else {
       state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
                 scalar<DType>(param.gamma1) * state_n;
       state_g = scalar<DType>(1.f - param.gamma1) * grad +
                 scalar<DType>(param.gamma1) * state_g;
       delta = scalar<DType>(param.gamma2) * delta -
               scalar<DType>(param.lr) *
-                  (grad / (F<square_root>(state_n - state_g * state_g) +
-                           scalar<DType>(param.epsilon)));
+                  (grad / (F<square_root>(state_n - state_g * state_g +
+                                          scalar<DType>(param.epsilon))));
     }

     if (param.clip_weights >= 0.0f) {
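To make the centered branch above easier to follow, here is a rough NumPy transcription of RMSPropAlexUpdate, assuming gradient and weight clipping are disabled; the hyperparameter defaults are illustrative, and this is a sketch rather than the operator itself:

```python
import numpy as np

def rmsprop_alex_step(weight, grad, state_n, state_g, delta,
                      lr=1e-3, gamma1=0.95, gamma2=0.9, epsilon=1e-8):
    """One centered RMSProp step, mirroring RMSPropAlexUpdate (no clipping)."""
    state_n[:] = (1 - gamma1) * grad * grad + gamma1 * state_n  # mean of grad^2
    state_g[:] = (1 - gamma1) * grad + gamma1 * state_g         # mean of grad
    # epsilon inside the square root keeps the radicand positive even
    # when state_n - state_g**2 rounds slightly below zero
    delta[:] = gamma2 * delta - lr * grad / np.sqrt(state_n - state_g ** 2 + epsilon)
    weight[:] += delta

w = np.ones(3, dtype=np.float32)
n = np.zeros(3, dtype=np.float32)
g = np.zeros(3, dtype=np.float32)
d = np.zeros(3, dtype=np.float32)
rmsprop_alex_step(w, np.float32(0.1) * np.ones(3, dtype=np.float32), n, g, d)
```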
@@ -386,33 +386,35 @@ inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       if (param.clip_weights >= 0.0f) {
         Assign(out, req[0],
                F<clip>(weight -
-                           scalar<DType>(param.lr) *
-                               (F<clip>(grad, DType(param.clip_gradient)) /
-                                (F<square_root>(state_n) +
-                                 scalar<DType>(param.epsilon))),
+                           scalar<DType>(param.lr) *
+                               (F<clip>(grad, DType(param.clip_gradient)) /
+                                (F<square_root>(state_n +
+                                                scalar<DType>(param.epsilon)))),
                        DType(param.clip_weights)));
       } else {
         Assign(out, req[0], weight -
-                                scalar<DType>(param.lr) *
-                                    (F<clip>(grad, DType(param.clip_gradient)) /
-                                     (F<square_root>(state_n) +
-                                      scalar<DType>(param.epsilon))));
+                                scalar<DType>(param.lr) *
+                                    (F<clip>(grad, DType(param.clip_gradient)) /
+                                     (F<square_root>(state_n +
+                                                     scalar<DType>(param.epsilon)))));
       }
     } else {
       state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
                 scalar<DType>(param.gamma1) * state_n;
       if (param.clip_weights >= 0.0f) {
         Assign(out, req[0],
                F<clip>(weight -
-                           scalar<DType>(param.lr) *
-                               (grad / (F<square_root>(state_n) +
-                                        scalar<DType>(param.epsilon))),
+                           scalar<DType>(param.lr) *
+                               (grad /
+                                (F<square_root>(state_n +
+                                                scalar<DType>(param.epsilon)))),
                        DType(param.clip_weights)));
       } else {
         Assign(out, req[0], weight -
-                                scalar<DType>(param.lr) *
-                                    (grad / (F<square_root>(state_n) +
-                                             scalar<DType>(param.epsilon))));
+                                scalar<DType>(param.lr) *
+                                    (grad /
+                                     (F<square_root>(state_n +
+                                                     scalar<DType>(param.epsilon)))));
      }
    }
  });
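Likewise, the non-centered RMSPropUpdate path reduces to the following NumPy sketch, again assuming clipping is disabled:

```python
import numpy as np

def rmsprop_step(weight, grad, state_n, lr=1e-3, gamma1=0.95, epsilon=1e-8):
    """One plain RMSProp step, mirroring RMSPropUpdate (no clipping)."""
    state_n[:] = (1 - gamma1) * grad * grad + gamma1 * state_n
    # the denominator is now sqrt(n + eps), whose floor is sqrt(eps), not eps
    weight[:] -= lr * grad / np.sqrt(state_n + epsilon)
```

One side effect worth noting: with ε inside the root, the denominator is bounded below by √ε rather than ε, so an ε value tuned for the old formula effectively acts much larger here.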

tests/python/unittest/test_optimizer.py

+2 −2

@@ -301,15 +301,15 @@ def update(self, index, weight, grad, state):
         if self.clip_gradient is not None:
             grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
         n[:] = (1 - self.gamma1) * (grad * grad) + self.gamma1 * n
-        weight[:] -= lr * grad/(mx.nd.sqrt(n) + self.epsilon)
+        weight[:] -= lr * grad/(mx.nd.sqrt(n + self.epsilon))

     else:
         n, g, delta = state
         if self.clip_gradient is not None:
             grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
         n[:] = (1 - self.gamma1) * (grad * grad) + self.gamma1 * n
         g[:] = (1 - self.gamma1) * grad + self.gamma1 * g
-        delta[:] = (self.gamma2) * delta - lr * grad/(mx.nd.sqrt(n - g*g) + self.epsilon)
+        delta[:] = (self.gamma2) * delta - lr * grad/(mx.nd.sqrt(n - g*g + self.epsilon))
         weight[:] += delta

     if self.clip_weights:
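A quick smoke check of the fixed formula using the same mx.nd calls as the test above (the values are hypothetical, chosen so that n − g·g rounds negative in float32):

```python
import mxnet as mx

n = mx.nd.array([0.01])   # running mean of grad**2 (float32 by default)
g = mx.nd.array([0.1])    # running mean of grad
epsilon = 1e-8

# old formula: mx.nd.sqrt(n - g*g) + epsilon  -> nan in the numerator's denominator
# new formula: epsilon inside the sqrt stays finite
print(mx.nd.sqrt(n - g * g + epsilon).asnumpy())
```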
