We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 236aa92 commit 6fde13bCopy full SHA for 6fde13b
train.py
@@ -252,11 +252,6 @@ def loss_fn(pred, labels):
252
for m in model_parts
253
]
254
255
- # for ease of testing TP in lieu of FSDP
256
- if job_config.training.tensor_parallel_degree == world_size:
257
- for model in model_parts:
258
- model.to(torch.bfloat16)
259
-
260
init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
261
for model in model_parts:
262
model.to_empty(device=init_device)
0 commit comments