Skip to content

Commit 007430e

Browse files
Camille7777 authored and flybird11111 committed
[llama] fix neftune & pbar with start_step (hpcaitech#5364)
1 parent 1f2b457 commit 007430e

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
def unwrap(model):
1919
if hasattr(model, "module"):
20-
return unwrap_model(model.module)
20+
return model.unwrap()
2121
else:
2222
return model
2323

applications/Colossal-LLaMA-2/train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,9 @@ def main() -> None:
329329

330330
for epoch in range(start_epoch, args.num_epochs):
331331
dataloader.sampler.set_epoch(epoch=epoch)
332-
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
332+
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
333333
total_loss = torch.tensor(0.0, device=get_current_device())
334-
for step, batch in enumerate(dataloader):
334+
for step, batch in enumerate(dataloader, start=start_step):
335335
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
336336

337337
batch_output = model(**batch)

0 commit comments

Comments (0)