Skip to content

Commit

Permalink
Add option to view results from existing csv file
Browse files Browse the repository at this point in the history
  • Loading branch information
gkielian committed Apr 17, 2024
1 parent a55cf61 commit 8bd9346
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions inspect_ckpts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,28 @@ def find_ckpt_files(directory):

def main():
parser = argparse.ArgumentParser(description='Extract best validation loss and iteration number from PyTorch checkpoint files.')
parser.add_argument('directory', type=str, help='Path to the directory containing the checkpoint files.')
parser.add_argument('--directory', type=str, help='Path to the directory containing the checkpoint files.')
parser.add_argument('--csv_file', type=str, help='Path to the CSV file containing the checkpoint data.')
parser.add_argument('--sort', type=str, choices=['path', 'loss', 'iter'], default='path', help='Sort the table by checkpoint file path, best validation loss, or iteration number.')
parser.add_argument('--reverse', action='store_true', help='Reverse the sort order.')
parser.add_argument('--output', type=str, help='Path to the output CSV file.')
args = parser.parse_args()

ckpt_files = find_ckpt_files(args.directory)

# Extract the best validation loss and iteration number for each checkpoint file
ckpt_data = [(ckpt_file, *get_best_val_loss_and_iter_num(ckpt_file)) for ckpt_file in ckpt_files]
if args.directory:
ckpt_files = find_ckpt_files(args.directory)

# Extract the best validation loss and iteration number for each checkpoint file
ckpt_data = [(ckpt_file, *get_best_val_loss_and_iter_num(ckpt_file)) for ckpt_file in ckpt_files]
elif args.csv_file:
ckpt_data = []
with open(args.csv_file, 'r') as csvfile:
csv_reader = csv.reader(csvfile)
next(csv_reader) # Skip the header row
for row in csv_reader:
ckpt_data.append((row[0], float(row[1]), int(row[2])))
else:
print("Please provide either a directory or a CSV file.")
return

# Sort the data based on the specified sort option
if args.sort == 'path':
Expand Down

0 comments on commit 8bd9346

Please sign in to comment.