
Commit 61ec9f7

ver217 authored and flybird11111 committed
[kernel] support pure fp16 for cpu adam and update gemini optim tests (hpcaitech#4921)
* [kernel] support pure fp16 for cpu adam (hpcaitech#4896)
* [kernel] fix cpu adam kernel for pure fp16 and update tests (hpcaitech#4919)
* [kernel] fix cpu adam
* [test] update gemini optim test
1 parent 52707c6 commit 61ec9f7

File tree

8 files changed: +148 -136 lines


colossalai/kernel/cuda_native/csrc/cpu_adam.cpp

+94 -107
Large diffs are not rendered by default.

colossalai/kernel/cuda_native/csrc/cpu_adam.h

+30 -11
@@ -50,9 +50,9 @@ SOFTWARE
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) \
   _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm256_store_ps( \
-      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
+                                     d, _MM_FROUND_TO_NEAREST_INT)))

 #elif defined(__AVX256__) or defined(__AVX2__)
 #define SIMD_WIDTH 8

@@ -66,9 +66,9 @@ SOFTWARE
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm_store_ps( \
-      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
+                                  d, _MM_FROUND_TO_NEAREST_INT)))

 #endif

@@ -83,11 +83,12 @@ union AVX_Data {

 #endif

-#define STEP(SPAN) \
-  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
-                   float *_exp_avg_sq, size_t _param_size, \
-                   bool param_half_precision = false, \
-                   bool grad_half_precision = false, float loss_scale = -1);
+#define STEP(SPAN) \
+  void Step_##SPAN( \
+      float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
+      size_t _param_size, bool param_half_precision = false, \
+      bool grad_half_precision = false, bool momentum_half_precision = false, \
+      bool variance_half_precision = false, float loss_scale = -1);

 class Adam_Optimizer {
  public:

@@ -141,6 +142,24 @@ class Adam_Optimizer {
     }
   }

+  inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
+                        AVX_Data &data) {
+    if (is_half) {
+      data.data = SIMD_LOAD_HALF(h_ptr);
+    } else {
+      data.data = SIMD_LOAD(ptr);
+    }
+  }
+
+  inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
+                         AVX_Data &data) {
+    if (is_half) {
+      SIMD_STORE_HALF(h_ptr, data.data);
+    } else {
+      SIMD_STORE(ptr, data.data);
+    }
+  }
+
   void step(size_t step, float lr, float beta1, float beta2, float epsilon,
             float weight_decay, bool bias_correction, torch::Tensor &params,
             torch::Tensor &grads, torch::Tensor &exp_avg,
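The new simd_load / simd_store helpers, together with the extra momentum_half_precision and variance_half_precision flags, let every tensor in the update (params, grads, exp_avg, exp_avg_sq) be stored in fp16 while the arithmetic stays in fp32. A rough pure-PyTorch sketch of that load-wide / compute / store-narrow pattern (illustrative only; adam_step_ref is a made-up name, not the kernel):

import torch

def adam_step_ref(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
                  beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    # widen everything to fp32, mirroring SIMD_LOAD_HALF / simd_load
    p32, g32 = param.float(), grad.float()
    m32, v32 = exp_avg.float(), exp_avg_sq.float()

    if weight_decay != 0.0:
        g32 = g32 + weight_decay * p32
    m32.mul_(beta1).add_(g32, alpha=1 - beta1)
    v32.mul_(beta2).addcmul_(g32, g32, value=1 - beta2)

    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    denom = (v32 / bias_correction2).sqrt().add_(eps)
    p32.addcdiv_(m32, denom, value=-lr / bias_correction1)

    # narrow back to each tensor's storage dtype, mirroring SIMD_STORE_HALF / simd_store
    param.copy_(p32)
    exp_avg.copy_(m32)
    exp_avg_sq.copy_(v32)

# pure fp16: params, grads and both optimizer states stored in half precision
p = torch.randn(8, dtype=torch.half)
g = torch.randn(8, dtype=torch.half)
m, v = torch.zeros_like(p), torch.zeros_like(p)
adam_step_ref(p, g, m, v, step=1)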

colossalai/nn/optimizer/cpu_adam.py

+1 -2
@@ -146,8 +146,7 @@ def step(self, closure=None, div_scale: float = -1):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
+                    if p.grad.dtype is torch.bfloat16:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]
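With the FIXME gone, only bf16 gradients still take the PyTorch fallback; fp32 and pure-fp16 params, grads, and optimizer states are all handed to the fused CPU kernel. Schematically (fused_cpu_adam_step and torch_adam_update below are hypothetical stand-ins for the kernel call and the fallback, not actual ColossalAI APIs):

import torch

def dispatch_cpu_adam(p, state, fused_cpu_adam_step, torch_adam_update):
    # bf16 still falls back to a plain PyTorch update
    if p.grad.dtype is torch.bfloat16:
        torch_adam_update(p, state)
    # fp32 and, after this change, pure fp16 go through the fused kernel
    else:
        fused_cpu_adam_step(p, state)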

colossalai/nn/optimizer/hybrid_adam.py

+1 -2
@@ -122,8 +122,7 @@ def step(self, closure=None, div_scale: float = -1):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
+                    if p.grad.dtype is torch.bfloat16:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]

tests/test_optimizer/test_adam_kernel.py

+2 -5
@@ -13,17 +13,14 @@
 _FUSED_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
-    (torch.bfloat16, torch.float),
     (torch.float, torch.bfloat16),
     (torch.bfloat16, torch.bfloat16),
 ]

 _CPU_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
 ]

@@ -138,8 +135,8 @@ def check_adam_kernel(
     master_exp_avg_sq = torch.zeros_like(master_p)
     p = master_p.clone().to(p_dtype)
     g = master_g.clone().to(g_dtype)
-    exp_avg = master_exp_avg.clone()
-    exp_avg_sq = master_exp_avg_sq.clone()
+    exp_avg = master_exp_avg.clone().to(p_dtype)
+    exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

     for step in range(1, 1 + n_steps):
         torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
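Because exp_avg and exp_avg_sq are now cloned into p_dtype, the remaining (torch.half, torch.half) entries in _FUSED_ALLOWED_P_G_TYPES and _CPU_ALLOWED_P_G_TYPES exercise half-precision optimizer states end to end. The setup pattern, condensed (tensor sizes and tolerances here are illustrative, not the test's exact values):

import torch

p_dtype, g_dtype = torch.half, torch.half  # pure fp16 case

master_p = torch.rand(64)                      # fp32 reference copies
master_g = torch.rand(64)
master_exp_avg = torch.zeros_like(master_p)
master_exp_avg_sq = torch.zeros_like(master_p)

p = master_p.clone().to(p_dtype)               # low-precision copies fed to the kernel
g = master_g.clone().to(g_dtype)
exp_avg = master_exp_avg.clone().to(p_dtype)   # optimizer states now follow p_dtype
exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

# run an fp32 reference update on the master tensors, the kernel on the copies,
# then compare with a dtype-dependent tolerance, e.g.:
rtol, atol = (1e-3, 1e-3) if p_dtype is torch.half else (1e-5, 1e-8)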

tests/test_optimizer/test_adam_optim.py

-2
@@ -21,8 +21,6 @@
     (torch.float, torch.float),  # pure fp32
     (torch.float, torch.half),  # fp16 amp
     (torch.float, torch.bfloat16),  # bfloat16 amp
-    # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16
-    # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16
 ]

 N_STEPS = 3

tests/test_zero/test_gemini/test_grad_clip.py

+9 -3
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):

 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["gpt2"])
-def exam_grad_clipping(placement_config, model_name: str):
+@parameterize("master_weights", [True, False])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     set_seed(1912)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str):
         chunk_config_dict=config_dict,
         chunk_init_device=init_device,
         pin_memory=True,
+        master_weights=master_weights,
         **placement_config,
     )

@@ -103,15 +105,19 @@ def exam_grad_clipping(placement_config, model_name: str):

         torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
-        assert_close(torch_loss, loss)
+
+        # as no master weights leads to error accumulation, we don't check the loss
+        if master_weights:
+            assert_close(torch_loss, loss)

         import apex.amp as apex_amp

         torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0)
         torch_optim.step()
         zero_optim.step()

-    check_param(model, torch_model)
+    if master_weights:
+        check_param(model, torch_model)


 def run_dist(rank, world_size, port):
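The relaxed checks follow from the comment in the diff: without fp32 master weights, every update is rounded back to the working precision in place, and the rounding error compounds across steps. A small self-contained illustration of that drift (unrelated to Gemini itself):

import torch

torch.manual_seed(0)
w_master = torch.randn(1024, dtype=torch.float)   # fp32 master copy
w_fp16 = w_master.clone().to(torch.half)          # updated directly in fp16

for _ in range(100):
    update = 1e-4 * torch.randn_like(w_master)
    w_master += update
    w_fp16 += update.to(torch.half)               # rounded to fp16 every step

# the fp16-updated weights drift away from the fp32 reference
print((w_fp16.float() - w_master).abs().max())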

tests/test_zero/test_gemini/test_optim.py

+11 -4
@@ -70,12 +70,14 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", TEST_MODELS)
 @parameterize("mixed_precision", [torch.half, torch.bfloat16])
-def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
+@parameterize("master_weights", [True, False])
+def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

     torch_model = model_builder().cuda()
+    # apex no master weights leads to nan, so we don't use it
     amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)

@@ -90,7 +92,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = False
-    model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
+    model = GeminiDDP(
+        model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
+    )

     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)

@@ -109,12 +113,15 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt

         torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
-        assert_close(torch_loss, loss, rtol=rtol, atol=atol)
+        # as no master weights leads to error accumulation, we don't check the loss
+        if master_weights:
+            assert_close(torch_loss, loss, rtol=rtol, atol=atol)

         zero_optim.step()
         torch_optim.step()

-        check_param(model, torch_model, mixed_precision)
+        if master_weights:
+            check_param(model, torch_model, mixed_precision)


 @parameterize("placement_config", PLACEMENT_CONFIGS)
