Skip to content

Commit

Permalink
Modify checkpoint inspector with training completion status
Browse files Browse the repository at this point in the history
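Add feature to monitor training completion: the inspector now also reports each checkpoint's NaN flag and the iteration number at which the NaN was recorded.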
  • Loading branch information
gkielian committed Apr 27, 2024
1 parent f859b41 commit 7676f4d
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions inspect_ckpts.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -23,7 +23,10 @@ def get_best_val_loss_and_iter_num(checkpoint_file):
best_val_loss = checkpoint['best_val_loss']
iter_num = checkpoint['iter_num']

return best_val_loss, iter_num
training_nan = checkpoint['nan']
training_nan_iter = checkpoint['nan_iter_num']

return best_val_loss, iter_num, training_nan, training_nan_iter

def find_ckpt_files(directory, path_regex=None):
"""
Expand Down Expand Up @@ -102,24 +105,26 @@ def main():
console = Console()

# Determine the maximum length of the checkpoint file paths
max_path_length = max(len(ckpt_file) for ckpt_file, _, _ in ckpt_data)
max_path_length = max(len(ckpt_file) for ckpt_file, _, _, _, _ in ckpt_data)

table = Table(show_header=True, header_style="bold magenta")
table.add_column("Checkpoint File", style="dim", width=max_path_length + 2)
table.add_column("Best Validation Loss", justify="right")
table.add_column("Iteration Number", justify="right")
table = Table(show_header=True, header_style="bold blue")
table.add_column("Ckpt File", style="", width=max_path_length + 2)
table.add_column("Best Val Loss", justify="right")
table.add_column("Iter Num", justify="right")
table.add_column("NaN Result", justify="right")
table.add_column("NaN Iter Num", justify="right")

if args.output:
with open(args.output, 'w', newline='') as csvfile:
csv_writer = csv.writer(csvfile)
csv_writer.writerow(["Checkpoint File", "Best Validation Loss", "Iteration Number"])
for ckpt_file, best_val_loss, iter_num in ckpt_data:
table.add_row(ckpt_file, f"{best_val_loss:.4f}", str(iter_num))
csv_writer.writerow([ckpt_file, f"{best_val_loss:.4f}", str(iter_num)])
csv_writer.writerow(["Checkpoint File", "Best Validation Loss", "Iteration Number", "NaN", "Nan Iter"])
for ckpt_file, best_val_loss, iter_num, training_nan, training_nan_iter in ckpt_data:
table.add_row(ckpt_file, f"{best_val_loss:.4f}", str(iter_num), str(training_nan), str(training_nan_iter))
csv_writer.writerow([ckpt_file, f"{best_val_loss:.4f}", str(iter_num), str(training_nan), str(training_nan_iter)])
print(f"Results exported to {args.output}")
else:
for ckpt_file, best_val_loss, iter_num in ckpt_data:
table.add_row(ckpt_file, f"{best_val_loss:.4f}", str(iter_num))
for ckpt_file, best_val_loss, iter_num, training_nan, training_nan_iter in ckpt_data:
table.add_row(ckpt_file, f"{best_val_loss:.4f}", str(iter_num), str(training_nan), str(training_nan_iter))

console.print(table)

Expand Down

0 comments on commit 7676f4d

Please sign in to comment.