Commit 9c88c87
[kernel] support pure fp16 for cpu adam (#4896)
1 parent: 83b52c5

File tree: 6 files changed, +135 -40 lines

colossalai/kernel/cuda_native/csrc/cpu_adam.cpp (+125 -24)

@@ -35,7 +35,8 @@ SOFTWARE
 void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
   size_t rounded_size = 0;
 
   float betta1_minus1 = 1 - _betta1;
@@ -45,13 +46,21 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 
   __half *params_cast_h = NULL;
   __half *grads_cast_h = NULL;
+  __half *momentum_cast_h = NULL;
+  __half *variance_cast_h = NULL;
 
   if (param_half_precision) {
     params_cast_h = reinterpret_cast<__half *>(_params);
   }
   if (grad_half_precision) {
     grads_cast_h = reinterpret_cast<__half *>(grads);
   }
+  if (momentum_half_precision) {
+    momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  }
+  if (variance_half_precision) {
+    variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
+  }
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
@@ -98,10 +107,18 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
       grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
     }
     AVX_Data momentum_4;
-    momentum_4.data = SIMD_LOAD(_exp_avg + i);
+    if (momentum_half_precision) {
+      momentum_4.data = SIMD_LOAD_HALF(momentum_cast_h + i);
+    } else {
+      momentum_4.data = SIMD_LOAD(_exp_avg + i);
+    }
 
     AVX_Data variance_4;
-    variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
+    if (variance_half_precision) {
+      variance_4.data = SIMD_LOAD_HALF(variance_cast_h + i);
+    } else {
+      variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
+    }
 
     AVX_Data param_4;
     if (param_half_precision) {
@@ -135,8 +152,16 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
     } else {
       SIMD_STORE(_params + i, param_4.data);
     }
-    SIMD_STORE(_exp_avg + i, momentum_4.data);
-    SIMD_STORE(_exp_avg_sq + i, variance_4.data);
+    if (momentum_half_precision) {
+      SIMD_STORE_HALF((float *)(momentum_cast_h + i), momentum_4.data);
+    } else {
+      SIMD_STORE(_exp_avg + i, momentum_4.data);
+    }
+    if (variance_half_precision) {
+      SIMD_STORE_HALF((float *)(variance_cast_h + i), variance_4.data);
+    } else {
+      SIMD_STORE(_exp_avg_sq + i, variance_4.data);
+    }
   }
 }
 #endif
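
An aside before the scalar fallback hunks below: the new branches lean on SIMD_LOAD_HALF and SIMD_STORE_HALF, which are defined in cpu_adam.h and untouched by this commit. Conceptually they are F16C up-/down-conversions between packed fp16 in memory and fp32 vector lanes. A rough self-contained sketch, assuming AVX2 with F16C (the helper names here are illustrative, not the real macro bodies):

#include <immintrin.h>

// Widen 8 packed fp16 values to 8 fp32 lanes (vcvtph2ps).
static inline __m256 load_half_8(const void *src) {
  return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)src));
}

// Narrow 8 fp32 lanes back to packed fp16 and store them (vcvtps2ph).
static inline void store_half_8(void *dst, __m256 v) {
  _mm_storeu_si128((__m128i *)dst,
                   _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT));
}

In other words, the arithmetic always runs in fp32 vectors; the new flags only change how each tensor is read from and written back to memory.
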
@@ -154,8 +179,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
       }
       float param =
           param_half_precision ? (float)params_cast_h[k] : _params[k];
-      float momentum = _exp_avg[k];
-      float variance = _exp_avg_sq[k];
+      float momentum =
+          momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
+      float variance = variance_half_precision ? (float)variance_cast_h[k]
+                                               : _exp_avg_sq[k];
       if (_weight_decay > 0 && !_adamw_mode) {
         grad = param * _weight_decay + grad;
       }
@@ -178,8 +205,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
         params_cast_h[k] = (__half)param;
       else
         _params[k] = param;
-      _exp_avg[k] = momentum;
-      _exp_avg_sq[k] = variance;
+      if (momentum_half_precision)
+        momentum_cast_h[k] = (__half)(momentum);
+      else
+        _exp_avg[k] = momentum;
+      if (variance_half_precision)
+        variance_cast_h[k] = (__half)(variance);
+      else
+        _exp_avg_sq[k] = variance;
     }
   }
 }
@@ -188,17 +221,26 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
   size_t rounded_size = 0;
 
   __half *params_cast_h = NULL;
   __half *grads_cast_h = NULL;
+  __half *momentum_cast_h = NULL;
+  __half *variance_cast_h = NULL;
   if (param_half_precision) {
     params_cast_h = reinterpret_cast<__half *>(_params);
   }
   if (grad_half_precision) {
     grads_cast_h = reinterpret_cast<__half *>(grads);
   }
+  if (momentum_half_precision) {
+    momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  }
+  if (variance_half_precision) {
+    variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
+  }
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
@@ -255,8 +297,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
           grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
         }
 
-        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
+        if (momentum_half_precision) {
+          momentum_4[j].data =
+              SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j);
+        } else {
+          momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
+        }
+        if (variance_half_precision) {
+          variance_4[j].data =
+              SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j);
+        } else {
+          variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
+        }
 
         if (param_half_precision) {
           param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
@@ -291,8 +343,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
         } else {
           SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
         }
-        SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
-        SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
+        if (momentum_half_precision) {
+          SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
+                          momentum_4[j].data);
+        } else {
+          SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
+        }
+        if (variance_half_precision) {
+          SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j),
+                          variance_4[j].data);
+        } else {
+          SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
+        }
       }
     }
   }
@@ -302,24 +364,37 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                                  : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
-          (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
+          (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
+                                   : _exp_avg + rounded_size),
+          (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
+                                   : _exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
-          grad_half_precision, loss_scale);
+          grad_half_precision, momentum_half_precision,
+          variance_half_precision, loss_scale);
 }
 
 void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
   size_t rounded_size = 0;
   __half *params_cast_h = NULL;
   __half *grads_cast_h = NULL;
+  __half *momentum_cast_h = NULL;
+  __half *variance_cast_h = NULL;
   if (param_half_precision) {
     params_cast_h = reinterpret_cast<__half *>(_params);
   }
   if (grad_half_precision) {
     grads_cast_h = reinterpret_cast<__half *>(grads);
   }
+  if (momentum_half_precision) {
+    momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  }
+  if (variance_half_precision) {
+    variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
+  }
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
   betta1_4.data = SIMD_SET(_betta1);
@@ -375,8 +450,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
           grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
         }
 
-        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
+        if (momentum_half_precision) {
+          momentum_4[j].data =
+              SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j);
+        } else {
+          momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
+        }
+        if (variance_half_precision) {
+          variance_4[j].data =
+              SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j);
+        } else {
+          variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
+        }
 
         if (param_half_precision) {
           param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
@@ -412,8 +497,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
           SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
         }
 
-        SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data);
-        SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data);
+        if (momentum_half_precision) {
+          SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
+                          momentum_4[j].data);
+        } else {
+          SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
+        }
+        if (variance_half_precision) {
+          SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j),
+                          variance_4[j].data);
+        } else {
+          SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
+        }
       }
     }
   }
@@ -423,9 +518,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                                  : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
-          (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
+          (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
+                                   : _exp_avg + rounded_size),
+          (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
+                                   : _exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
-          grad_half_precision, loss_scale);
+          grad_half_precision, momentum_half_precision,
+          variance_half_precision, loss_scale);
 }
 
 void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
@@ -447,7 +546,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
   this->update_state(lr, epsilon, weight_decay, bias_correction);
   this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
                params_c.numel(), (params.options().dtype() == at::kHalf),
-               (grads.options().dtype() == at::kHalf), loss_scale);
+               (grads.options().dtype() == at::kHalf),
+               (exp_avg.options().dtype() == at::kHalf),
+               (exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
 }
 
 namespace py = pybind11;
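
Taken together, the pattern in all three Step_ variants is the same: each of the four tensors (params, grads, exp_avg, exp_avg_sq) carries its own half-precision flag, values are widened to fp32 for the update math, and results are written back in each buffer's own precision. A minimal single-element sketch of the scalar tail path, assuming host-side __half/float conversions from <cuda_fp16.h>, and omitting the bias correction, weight decay, AdamW mode, and loss scaling that the real kernel handles:

#include <cuda_fp16.h>
#include <cmath>
#include <cstddef>

// Read element k of a buffer that is physically fp16 or fp32.
static float read_elem(float *buf, size_t k, bool half) {
  return half ? (float)reinterpret_cast<__half *>(buf)[k] : buf[k];
}

// Write element k back in the buffer's own precision.
static void write_elem(float *buf, size_t k, bool half, float val) {
  if (half)
    reinterpret_cast<__half *>(buf)[k] = (__half)val;
  else
    buf[k] = val;
}

static void adam_elem(float *p, float *g, float *m, float *v, size_t k,
                      bool p_h, bool g_h, bool m_h, bool v_h, float beta1,
                      float beta2, float eps, float step_size) {
  float grad = read_elem(g, k, g_h);
  float param = read_elem(p, k, p_h);
  float mom = read_elem(m, k, m_h);
  float var = read_elem(v, k, v_h);
  mom = beta1 * mom + (1 - beta1) * grad;              // first moment
  var = beta2 * var + (1 - beta2) * grad * grad;       // second moment
  param -= step_size * mom / (std::sqrt(var) + eps);   // parameter update
  write_elem(p, k, p_h, param);
  write_elem(m, k, m_h, mom);
  write_elem(v, k, v_h, var);
}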

colossalai/kernel/cuda_native/csrc/cpu_adam.h (+6 -5)

@@ -83,11 +83,12 @@ union AVX_Data {
 
 #endif
 
-#define STEP(SPAN)                                                \
-  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
-                   float *_exp_avg_sq, size_t _param_size,        \
-                   bool param_half_precision = false,             \
-                   bool grad_half_precision = false, float loss_scale = -1);
+#define STEP(SPAN)                                                            \
+  void Step_##SPAN(                                                           \
+      float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq,      \
+      size_t _param_size, bool param_half_precision = false,                  \
+      bool grad_half_precision = false, bool momentum_half_precision = false, \
+      bool variance_half_precision = false, float loss_scale = -1);
 
 class Adam_Optimizer {
  public:
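
Because the two new flags default to false, existing call sites of the Step_ functions compile unchanged. For reference, STEP(1) now expands to:

void Step_1(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq,
            size_t _param_size, bool param_half_precision = false,
            bool grad_half_precision = false,
            bool momentum_half_precision = false,
            bool variance_half_precision = false, float loss_scale = -1);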

colossalai/nn/optimizer/cpu_adam.py (+1 -2)

@@ -146,8 +146,7 @@ def step(self, closure=None, div_scale: float = -1):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
+                    if p.grad.dtype is torch.bfloat16:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]

colossalai/nn/optimizer/hybrid_adam.py (+1 -2)

@@ -122,8 +122,7 @@ def step(self, closure=None, div_scale: float = -1):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
+                    if p.grad.dtype is torch.bfloat16:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]

tests/test_optimizer/test_adam_kernel.py (+2 -5)

@@ -13,17 +13,14 @@
 _FUSED_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
-    (torch.bfloat16, torch.float),
     (torch.float, torch.bfloat16),
     (torch.bfloat16, torch.bfloat16),
 ]
 
 _CPU_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
 ]
 
@@ -138,8 +135,8 @@ def check_adam_kernel(
     master_exp_avg_sq = torch.zeros_like(master_p)
     p = master_p.clone().to(p_dtype)
     g = master_g.clone().to(g_dtype)
-    exp_avg = master_exp_avg.clone()
-    exp_avg_sq = master_exp_avg_sq.clone()
+    exp_avg = master_exp_avg.clone().to(p_dtype)
+    exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)
 
     for step in range(1, 1 + n_steps):
         torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)

tests/test_optimizer/test_adam_optim.py (+0 -2)

@@ -21,8 +21,6 @@
     (torch.float, torch.float),  # pure fp32
     (torch.float, torch.half),  # fp16 amp
     (torch.float, torch.bfloat16),  # bfloat16 amp
-    # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16
-    # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16
 ]
 
 N_STEPS = 3
