From 0da1d7a1d1f44ca466883253d4f4ee5863c686f1 Mon Sep 17 00:00:00 2001
From: ChaosCodes
Date: Sun, 5 Nov 2023 11:34:28 +0800
Subject: [PATCH] fix bugs

---
 lit_gpt/speed_monitor.py      | 3 ++-
 pretrain/tinyllama.py         | 5 +++--
 pretrain/tinyllama_code.py    | 5 +++--
 scripts/prepare_slimpajama.py | 3 ++-
 4 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/lit_gpt/speed_monitor.py b/lit_gpt/speed_monitor.py
index fa81b18..178e0d7 100644
--- a/lit_gpt/speed_monitor.py
+++ b/lit_gpt/speed_monitor.py
@@ -221,6 +221,7 @@ def on_train_batch_end(
         samples: int,  # total samples seen (per device)
         train_elapsed: float,  # total training time (seconds)
         world_size: int,
+        step_count: int,
         flops_per_batch: Optional[int] = None,  # (per device)
         lengths: Optional[int] = None,  # total length of the samples seen (per device)
         train_loss: Optional[float] = None,
@@ -291,7 +292,7 @@ def on_train_batch_end(
                 }
             )
         if self.iter % self.log_iter_interval == 0:
-            self.log_dict(metrics, self.iter//self.log_iter_interval)
+            self.log_dict(metrics, step_count)
 
     def eval_end(self, eval_elapsed: float):
         self.total_eval_wct += eval_elapsed  # seconds
diff --git a/pretrain/tinyllama.py b/pretrain/tinyllama.py
index d238a98..cc627b9 100644
--- a/pretrain/tinyllama.py
+++ b/pretrain/tinyllama.py
@@ -249,6 +249,7 @@ def train(fabric, state, train_dataloader, val_dataloader, monitor, resume):
             t1 - total_t0,
             # this assumes that device FLOPs are the same and that all devices have the same batch size
             fabric.world_size,
+            state["step_count"],
             flops_per_batch=estimated_flops,
             lengths=total_lengths,
             train_loss = loss.item()
@@ -264,8 +265,8 @@ def train(fabric, state, train_dataloader, val_dataloader, monitor, resume):
             t1 = time.perf_counter() - t0
             monitor.eval_end(t1)
             fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
-            fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"])
-            fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"])
+            fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
+            fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
             fabric.barrier()
         if not is_accumulating and state["step_count"] % save_step_interval == 0:
             checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
diff --git a/pretrain/tinyllama_code.py b/pretrain/tinyllama_code.py
index 24a526c..1a68e15 100644
--- a/pretrain/tinyllama_code.py
+++ b/pretrain/tinyllama_code.py
@@ -253,6 +253,7 @@ def train(fabric, state, train_dataloader, val_dataloader, monitor, resume):
             t1 - total_t0,
             # this assumes that device FLOPs are the same and that all devices have the same batch size
             fabric.world_size,
+            state["step_count"],
             flops_per_batch=estimated_flops,
             lengths=total_lengths,
             train_loss = loss.item()
@@ -268,8 +269,8 @@ def train(fabric, state, train_dataloader, val_dataloader, monitor, resume):
             t1 = time.perf_counter() - t0
             monitor.eval_end(t1)
             fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
-            fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"])
-            fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"])
+            fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
+            fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
             fabric.barrier()
         if not is_accumulating and state["step_count"] % save_step_interval == 0:
             checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
diff --git a/scripts/prepare_slimpajama.py b/scripts/prepare_slimpajama.py
index 24ec050..50231ff 100644
--- a/scripts/prepare_slimpajama.py
+++ b/scripts/prepare_slimpajama.py
@@ -66,7 +66,8 @@ def prepare_full(
                 text_ids = tokenizer.encode(text)
                 builder.add_array(np.array(text_ids, dtype=builder.dtype))
 
-    builder.write_reminder()
+    # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details
+    # builder.write_reminder()
 
 
 def prepare(