Commit cbf0112

fix: use bf16 to avoid overflow
1 parent 6257d0b commit cbf0112

File tree

3 files changed: +6 -6 lines changed

applications/Chat/examples/train_prompts.py

+4 -4

@@ -65,8 +65,8 @@ def main(args):
     if args.rm_path is not None:
         reward_model.load_state_dict(state_dict, strict=False)

-    initial_model.to(torch.float16).to(torch.cuda.current_device())
-    reward_model.to(torch.float16).to(torch.cuda.current_device())
+    initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
+    reward_model.to(torch.bfloat16).to(torch.cuda.current_device())

     if args.model == 'gpt2':
         actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
@@ -95,8 +95,8 @@ def main(args):
     del state_dict

     if args.strategy != 'colossalai_gemini':
-        critic.to(torch.float16).to(torch.cuda.current_device())
-        actor.to(torch.float16).to(torch.cuda.current_device())
+        critic.to(torch.bfloat16).to(torch.cuda.current_device())
+        actor.to(torch.bfloat16).to(torch.cuda.current_device())

     # configure optimizer
     if args.strategy.startswith('colossalai'):

applications/Chat/examples/train_reward_model.py

+1 -1

@@ -45,7 +45,7 @@ def train(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

-    model.to(torch.float16).to(torch.cuda.current_device())
+    model.to(torch.bfloat16).to(torch.cuda.current_device())

     if args.model_path is not None:
         state_dict = torch.load(args.model_path)

applications/Chat/examples/train_sft.py

+1 -1

@@ -60,7 +60,7 @@ def train(args):
     else:
        raise ValueError(f'Unsupported model "{args.model}"')

-    model.to(torch.float16).to(torch.cuda.current_device())
+    model.to(torch.bfloat16).to(torch.cuda.current_device())

     # configure tokenizer
     if args.model == 'gpt2':
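Why the change helps (illustrative note, not part of the commit): all three scripts cast models to half precision before moving them to the GPU, and torch.float16 saturates to inf for magnitudes above roughly 65504, whereas torch.bfloat16 keeps float32's 8-bit exponent (range up to ~3.4e38) at the cost of mantissa precision. A minimal PyTorch sketch of the difference:

import torch

x = torch.tensor([100000.0])   # exceeds the float16 maximum (~65504)
print(x.to(torch.float16))     # overflows to inf
print(x.to(torch.bfloat16))    # stays finite, only rounded: bf16 trades precision for fp32's exponent range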
