diff --git a/inspect_ckpts.py b/inspect_ckpts.py index f6eebaeb70..6826847f42 100644 --- a/inspect_ckpts.py +++ b/inspect_ckpts.py @@ -43,16 +43,28 @@ def find_ckpt_files(directory): def main(): parser = argparse.ArgumentParser(description='Extract best validation loss and iteration number from PyTorch checkpoint files.') - parser.add_argument('directory', type=str, help='Path to the directory containing the checkpoint files.') + parser.add_argument('--directory', type=str, help='Path to the directory containing the checkpoint files.') + parser.add_argument('--csv_file', type=str, help='Path to the CSV file containing the checkpoint data.') parser.add_argument('--sort', type=str, choices=['path', 'loss', 'iter'], default='path', help='Sort the table by checkpoint file path, best validation loss, or iteration number.') parser.add_argument('--reverse', action='store_true', help='Reverse the sort order.') parser.add_argument('--output', type=str, help='Path to the output CSV file.') args = parser.parse_args() - ckpt_files = find_ckpt_files(args.directory) - - # Extract the best validation loss and iteration number for each checkpoint file - ckpt_data = [(ckpt_file, *get_best_val_loss_and_iter_num(ckpt_file)) for ckpt_file in ckpt_files] + if args.directory: + ckpt_files = find_ckpt_files(args.directory) + + # Extract the best validation loss and iteration number for each checkpoint file + ckpt_data = [(ckpt_file, *get_best_val_loss_and_iter_num(ckpt_file)) for ckpt_file in ckpt_files] + elif args.csv_file: + ckpt_data = [] + with open(args.csv_file, 'r') as csvfile: + csv_reader = csv.reader(csvfile) + next(csv_reader) # Skip the header row + for row in csv_reader: + ckpt_data.append((row[0], float(row[1]), int(row[2]))) + else: + print("Please provide either a directory or a CSV file.") + return # Sort the data based on the specified sort option if args.sort == 'path':