@@ -35,7 +35,8 @@ SOFTWARE
35
35
void Adam_Optimizer::Step_1 (float *_params, float *grads, float *_exp_avg,
36
36
float *_exp_avg_sq, size_t _param_size,
37
37
bool param_half_precision, bool grad_half_precision,
38
- float loss_scale) {
38
+ bool momentum_half_precision,
39
+ bool variance_half_precision, float loss_scale) {
39
40
size_t rounded_size = 0 ;
40
41
41
42
float betta1_minus1 = 1 - _betta1;
@@ -45,13 +46,21 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
45
46
46
47
__half *params_cast_h = NULL ;
47
48
__half *grads_cast_h = NULL ;
49
+ __half *momentum_cast_h = NULL ;
50
+ __half *variance_cast_h = NULL ;
48
51
49
52
if (param_half_precision) {
50
53
params_cast_h = reinterpret_cast <__half *>(_params);
51
54
}
52
55
if (grad_half_precision) {
53
56
grads_cast_h = reinterpret_cast <__half *>(grads);
54
57
}
58
+ if (momentum_half_precision) {
59
+ momentum_cast_h = reinterpret_cast <__half *>(_exp_avg);
60
+ }
61
+ if (variance_half_precision) {
62
+ variance_cast_h = reinterpret_cast <__half *>(_exp_avg_sq);
63
+ }
55
64
56
65
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
57
66
AVX_Data betta1_4;
@@ -98,10 +107,18 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
98
107
grad_4.data = SIMD_DIV (grad_4.data , loss_scale_vec.data );
99
108
}
100
109
AVX_Data momentum_4;
101
- momentum_4.data = SIMD_LOAD (_exp_avg + i);
110
+ if (momentum_half_precision) {
111
+ momentum_4.data = SIMD_LOAD_HALF (momentum_cast_h + i);
112
+ } else {
113
+ momentum_4.data = SIMD_LOAD (_exp_avg + i);
114
+ }
102
115
103
116
AVX_Data variance_4;
104
- variance_4.data = SIMD_LOAD (_exp_avg_sq + i);
117
+ if (variance_half_precision) {
118
+ variance_4.data = SIMD_LOAD_HALF (variance_cast_h + i);
119
+ } else {
120
+ variance_4.data = SIMD_LOAD (_exp_avg_sq + i);
121
+ }
105
122
106
123
AVX_Data param_4;
107
124
if (param_half_precision) {
@@ -135,8 +152,16 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
135
152
} else {
136
153
SIMD_STORE (_params + i, param_4.data );
137
154
}
138
- SIMD_STORE (_exp_avg + i, momentum_4.data );
139
- SIMD_STORE (_exp_avg_sq + i, variance_4.data );
155
+ if (momentum_half_precision) {
156
+ SIMD_STORE_HALF ((float *)(momentum_cast_h + i), momentum_4.data );
157
+ } else {
158
+ SIMD_STORE (_exp_avg + i, momentum_4.data );
159
+ }
160
+ if (variance_half_precision) {
161
+ SIMD_STORE_HALF ((float *)(variance_cast_h + i), variance_4.data );
162
+ } else {
163
+ SIMD_STORE (_exp_avg_sq + i, variance_4.data );
164
+ }
140
165
}
141
166
}
142
167
#endif
@@ -154,8 +179,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
154
179
}
155
180
float param =
156
181
param_half_precision ? (float )params_cast_h[k] : _params[k];
157
- float momentum = _exp_avg[k];
158
- float variance = _exp_avg_sq[k];
182
+ float momentum =
183
+ momentum_half_precision ? (float )momentum_cast_h[k] : _exp_avg[k];
184
+ float variance = variance_half_precision ? (float )variance_cast_h[k]
185
+ : _exp_avg_sq[k];
159
186
if (_weight_decay > 0 && !_adamw_mode) {
160
187
grad = param * _weight_decay + grad;
161
188
}
@@ -178,8 +205,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
178
205
params_cast_h[k] = (__half)param;
179
206
else
180
207
_params[k] = param;
181
- _exp_avg[k] = momentum;
182
- _exp_avg_sq[k] = variance;
208
+ if (momentum_half_precision)
209
+ momentum_cast_h[k] = (__half)(momentum);
210
+ else
211
+ _exp_avg[k] = momentum;
212
+ if (variance_half_precision)
213
+ variance_cast_h[k] = (__half)(variance);
214
+ else
215
+ _exp_avg_sq[k] = variance;
183
216
}
184
217
}
185
218
}
@@ -188,17 +221,26 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
188
221
void Adam_Optimizer::Step_4 (float *_params, float *grads, float *_exp_avg,
189
222
float *_exp_avg_sq, size_t _param_size,
190
223
bool param_half_precision, bool grad_half_precision,
191
- float loss_scale) {
224
+ bool momentum_half_precision,
225
+ bool variance_half_precision, float loss_scale) {
192
226
size_t rounded_size = 0 ;
193
227
194
228
__half *params_cast_h = NULL ;
195
229
__half *grads_cast_h = NULL ;
230
+ __half *momentum_cast_h = NULL ;
231
+ __half *variance_cast_h = NULL ;
196
232
if (param_half_precision) {
197
233
params_cast_h = reinterpret_cast <__half *>(_params);
198
234
}
199
235
if (grad_half_precision) {
200
236
grads_cast_h = reinterpret_cast <__half *>(grads);
201
237
}
238
+ if (momentum_half_precision) {
239
+ momentum_cast_h = reinterpret_cast <__half *>(_exp_avg);
240
+ }
241
+ if (variance_half_precision) {
242
+ variance_cast_h = reinterpret_cast <__half *>(_exp_avg_sq);
243
+ }
202
244
203
245
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
204
246
AVX_Data betta1_4;
@@ -255,8 +297,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
255
297
grad_4[j].data = SIMD_DIV (grad_4[j].data , loss_scale_vec.data );
256
298
}
257
299
258
- momentum_4[j].data = SIMD_LOAD (_exp_avg + i + SIMD_WIDTH * j);
259
- variance_4[j].data = SIMD_LOAD (_exp_avg_sq + i + SIMD_WIDTH * j);
300
+ if (momentum_half_precision) {
301
+ momentum_4[j].data =
302
+ SIMD_LOAD_HALF (momentum_cast_h + i + SIMD_WIDTH * j);
303
+ } else {
304
+ momentum_4[j].data = SIMD_LOAD (_exp_avg + i + SIMD_WIDTH * j);
305
+ }
306
+ if (variance_half_precision) {
307
+ variance_4[j].data =
308
+ SIMD_LOAD_HALF (variance_cast_h + i + SIMD_WIDTH * j);
309
+ } else {
310
+ variance_4[j].data = SIMD_LOAD (_exp_avg_sq + i + SIMD_WIDTH * j);
311
+ }
260
312
261
313
if (param_half_precision) {
262
314
param_4[j].data = SIMD_LOAD_HALF (params_cast_h + i + SIMD_WIDTH * j);
@@ -291,8 +343,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
291
343
} else {
292
344
SIMD_STORE (_params + i + SIMD_WIDTH * j, param_4[j].data );
293
345
}
294
- SIMD_STORE (_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data );
295
- SIMD_STORE (_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data );
346
+ if (momentum_half_precision) {
347
+ SIMD_STORE_HALF ((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
348
+ momentum_4[j].data );
349
+ } else {
350
+ SIMD_STORE (_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data );
351
+ }
352
+ if (variance_half_precision) {
353
+ SIMD_STORE_HALF ((float *)(variance_cast_h + i + SIMD_WIDTH * j),
354
+ variance_4[j].data );
355
+ } else {
356
+ SIMD_STORE (_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data );
357
+ }
296
358
}
297
359
}
298
360
}
@@ -302,24 +364,37 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
302
364
: _params + rounded_size),
303
365
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
304
366
: grads + rounded_size),
305
- (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
367
+ (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
368
+ : _exp_avg + rounded_size),
369
+ (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
370
+ : _exp_avg_sq + rounded_size),
306
371
(_param_size - rounded_size), param_half_precision,
307
- grad_half_precision, loss_scale);
372
+ grad_half_precision, momentum_half_precision,
373
+ variance_half_precision, loss_scale);
308
374
}
309
375
310
376
void Adam_Optimizer::Step_8 (float *_params, float *grads, float *_exp_avg,
311
377
float *_exp_avg_sq, size_t _param_size,
312
378
bool param_half_precision, bool grad_half_precision,
313
- float loss_scale) {
379
+ bool momentum_half_precision,
380
+ bool variance_half_precision, float loss_scale) {
314
381
size_t rounded_size = 0 ;
315
382
__half *params_cast_h = NULL ;
316
383
__half *grads_cast_h = NULL ;
384
+ __half *momentum_cast_h = NULL ;
385
+ __half *variance_cast_h = NULL ;
317
386
if (param_half_precision) {
318
387
params_cast_h = reinterpret_cast <__half *>(_params);
319
388
}
320
389
if (grad_half_precision) {
321
390
grads_cast_h = reinterpret_cast <__half *>(grads);
322
391
}
392
+ if (momentum_half_precision) {
393
+ momentum_cast_h = reinterpret_cast <__half *>(_exp_avg);
394
+ }
395
+ if (variance_half_precision) {
396
+ variance_cast_h = reinterpret_cast <__half *>(_exp_avg_sq);
397
+ }
323
398
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
324
399
AVX_Data betta1_4;
325
400
betta1_4.data = SIMD_SET (_betta1);
@@ -375,8 +450,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
375
450
grad_4[j].data = SIMD_DIV (grad_4[j].data , loss_scale_vec.data );
376
451
}
377
452
378
- momentum_4[j].data = SIMD_LOAD (_exp_avg + i + SIMD_WIDTH * j);
379
- variance_4[j].data = SIMD_LOAD (_exp_avg_sq + i + SIMD_WIDTH * j);
453
+ if (momentum_half_precision) {
454
+ momentum_4[j].data =
455
+ SIMD_LOAD_HALF (momentum_cast_h + i + SIMD_WIDTH * j);
456
+ } else {
457
+ momentum_4[j].data = SIMD_LOAD (_exp_avg + i + SIMD_WIDTH * j);
458
+ }
459
+ if (variance_half_precision) {
460
+ variance_4[j].data =
461
+ SIMD_LOAD_HALF (variance_cast_h + i + SIMD_WIDTH * j);
462
+ } else {
463
+ variance_4[j].data = SIMD_LOAD (_exp_avg_sq + i + SIMD_WIDTH * j);
464
+ }
380
465
381
466
if (param_half_precision) {
382
467
param_4[j].data = SIMD_LOAD_HALF (params_cast_h + i + SIMD_WIDTH * j);
@@ -412,8 +497,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
412
497
SIMD_STORE (_params + i + SIMD_WIDTH * j, param_4[j].data );
413
498
}
414
499
415
- SIMD_STORE (_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data );
416
- SIMD_STORE (_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data );
500
+ if (momentum_half_precision) {
501
+ SIMD_STORE_HALF ((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
502
+ momentum_4[j].data );
503
+ } else {
504
+ SIMD_STORE (_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data );
505
+ }
506
+ if (variance_half_precision) {
507
+ SIMD_STORE_HALF ((float *)(variance_cast_h + i + SIMD_WIDTH * j),
508
+ variance_4[j].data );
509
+ } else {
510
+ SIMD_STORE (_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data );
511
+ }
417
512
}
418
513
}
419
514
}
@@ -423,9 +518,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
423
518
: _params + rounded_size),
424
519
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
425
520
: grads + rounded_size),
426
- (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
521
+ (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
522
+ : _exp_avg + rounded_size),
523
+ (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
524
+ : _exp_avg_sq + rounded_size),
427
525
(_param_size - rounded_size), param_half_precision,
428
- grad_half_precision, loss_scale);
526
+ grad_half_precision, momentum_half_precision,
527
+ variance_half_precision, loss_scale);
429
528
}
430
529
431
530
void Adam_Optimizer::step (size_t step, float lr, float beta1, float beta2,
@@ -447,7 +546,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
447
546
this ->update_state (lr, epsilon, weight_decay, bias_correction);
448
547
this ->Step_8 (params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
449
548
params_c.numel (), (params.options ().dtype () == at::kHalf ),
450
- (grads.options ().dtype () == at::kHalf ), loss_scale);
549
+ (grads.options ().dtype () == at::kHalf ),
550
+ (exp_avg.options ().dtype () == at::kHalf ),
551
+ (exp_avg_sq.options ().dtype () == at::kHalf ), loss_scale);
451
552
}
452
553
453
554
namespace py = pybind11;
0 commit comments