@@ -300,17 +300,17 @@ inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs,
300
300
delta = scalar<DType>(param.gamma2 ) * delta -
301
301
scalar<DType>(param.lr ) *
302
302
(F<clip>(grad, DType (param.clip_gradient )) /
303
- (F<square_root>(state_n - state_g * state_g) +
304
- scalar<DType>(param.epsilon )));
303
+ (F<square_root>(state_n - state_g * state_g +
304
+ scalar<DType>(param.epsilon ) )));
305
305
} else {
306
306
state_n = scalar<DType>(1 .f - param.gamma1 ) * (grad * grad) +
307
307
scalar<DType>(param.gamma1 ) * state_n;
308
308
state_g = scalar<DType>(1 .f - param.gamma1 ) * grad +
309
309
scalar<DType>(param.gamma1 ) * state_g;
310
310
delta = scalar<DType>(param.gamma2 ) * delta -
311
311
scalar<DType>(param.lr ) *
312
- (grad / (F<square_root>(state_n - state_g * state_g) +
313
- scalar<DType>(param.epsilon )));
312
+ (grad / (F<square_root>(state_n - state_g * state_g +
313
+ scalar<DType>(param.epsilon ) )));
314
314
}
315
315
316
316
if (param.clip_weights >= 0 .0f ) {
@@ -386,33 +386,35 @@ inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
386
386
if (param.clip_weights >= 0 .0f ) {
387
387
Assign (out, req[0 ],
388
388
F<clip>(weight -
389
- scalar<DType>(param.lr ) *
390
- (F<clip>(grad, DType (param.clip_gradient )) /
391
- (F<square_root>(state_n) +
392
- scalar<DType>(param.epsilon ))),
389
+ scalar<DType>(param.lr ) *
390
+ (F<clip>(grad, DType (param.clip_gradient )) /
391
+ (F<square_root>(state_n +
392
+ scalar<DType>(param.epsilon ) ))),
393
393
DType (param.clip_weights )));
394
394
} else {
395
395
Assign (out, req[0 ], weight -
396
- scalar<DType>(param.lr ) *
397
- (F<clip>(grad, DType (param.clip_gradient )) /
398
- (F<square_root>(state_n) +
399
- scalar<DType>(param.epsilon ))));
396
+ scalar<DType>(param.lr ) *
397
+ (F<clip>(grad, DType (param.clip_gradient )) /
398
+ (F<square_root>(state_n +
399
+ scalar<DType>(param.epsilon ) ))));
400
400
}
401
401
} else {
402
402
state_n = scalar<DType>(1 .f - param.gamma1 ) * (grad * grad) +
403
403
scalar<DType>(param.gamma1 ) * state_n;
404
404
if (param.clip_weights >= 0 .0f ) {
405
405
Assign (out, req[0 ],
406
406
F<clip>(weight -
407
- scalar<DType>(param.lr ) *
408
- (grad / (F<square_root>(state_n) +
409
- scalar<DType>(param.epsilon ))),
407
+ scalar<DType>(param.lr ) *
408
+ (grad /
409
+ (F<square_root>(state_n +
410
+ scalar<DType>(param.epsilon )))),
410
411
DType (param.clip_weights )));
411
412
} else {
412
413
Assign (out, req[0 ], weight -
413
- scalar<DType>(param.lr ) *
414
- (grad / (F<square_root>(state_n) +
415
- scalar<DType>(param.epsilon ))));
414
+ scalar<DType>(param.lr ) *
415
+ (grad /
416
+ (F<square_root>(state_n +
417
+ scalar<DType>(param.epsilon )))));
416
418
}
417
419
}
418
420
});
0 commit comments