Commit cbf0112 1 parent 6257d0b commit cbf0112 Copy full SHA for cbf0112
File tree 3 files changed +6
-6
lines changed
applications/Chat/examples
3 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -65,8 +65,8 @@ def main(args):
65
65
if args .rm_path is not None :
66
66
reward_model .load_state_dict (state_dict , strict = False )
67
67
68
- initial_model .to (torch .float16 ).to (torch .cuda .current_device ())
69
- reward_model .to (torch .float16 ).to (torch .cuda .current_device ())
68
+ initial_model .to (torch .bfloat16 ).to (torch .cuda .current_device ())
69
+ reward_model .to (torch .bfloat16 ).to (torch .cuda .current_device ())
70
70
71
71
if args .model == 'gpt2' :
72
72
actor = GPTActor (pretrained = args .pretrain , lora_rank = args .lora_rank )
@@ -95,8 +95,8 @@ def main(args):
95
95
del state_dict
96
96
97
97
if args .strategy != 'colossalai_gemini' :
98
- critic .to (torch .float16 ).to (torch .cuda .current_device ())
99
- actor .to (torch .float16 ).to (torch .cuda .current_device ())
98
+ critic .to (torch .bfloat16 ).to (torch .cuda .current_device ())
99
+ actor .to (torch .bfloat16 ).to (torch .cuda .current_device ())
100
100
101
101
# configure optimizer
102
102
if args .strategy .startswith ('colossalai' ):
Original file line number Diff line number Diff line change @@ -45,7 +45,7 @@ def train(args):
45
45
else :
46
46
raise ValueError (f'Unsupported model "{ args .model } "' )
47
47
48
- model .to (torch .float16 ).to (torch .cuda .current_device ())
48
+ model .to (torch .bfloat16 ).to (torch .cuda .current_device ())
49
49
50
50
if args .model_path is not None :
51
51
state_dict = torch .load (args .model_path )
Original file line number Diff line number Diff line change @@ -60,7 +60,7 @@ def train(args):
60
60
else :
61
61
raise ValueError (f'Unsupported model "{ args .model } "' )
62
62
63
- model .to (torch .float16 ).to (torch .cuda .current_device ())
63
+ model .to (torch .bfloat16 ).to (torch .cuda .current_device ())
64
64
65
65
# configure tokenizer
66
66
if args .model == 'gpt2' :
You can’t perform that action at this time.
0 commit comments