Skip to content

Commit

Permalink
Evaluation progress bar improvements, adding a outer wrapping bar tha…
Browse files Browse the repository at this point in the history
…t shows total progress and the number of samples in each batch
  • Loading branch information
bghira committed Jan 28, 2025
1 parent f5ae4af commit 7e8b4c5
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion helpers/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,12 @@ def would_evaluate(self, training_state: dict):

return False

def total_eval_batches(self):
"""
Return the total number of eval batches across all eval datasets.
"""
return sum([len(x["sampler"]) for _, x in StateTracker.get_data_backends(_type="eval").items()])

def execute_eval(
self, prepare_batch, model_predict, calculate_loss, get_prediction_target
):
Expand All @@ -1818,6 +1824,14 @@ def execute_eval(
accumulated_eval_losses = {}
eval_batch = True
evaluated_sample_count = 0
total_batches = self.total_eval_batches()
print(f"Working on {total_batches} evaluation batches.")
main_progress_bar = tqdm(
total=total_batches,
desc="Calculate validation loss",
position=2,
leave=True,
)
while eval_batch is not False:
try:
evaluated_sample_count += 1
Expand All @@ -1838,12 +1852,14 @@ def execute_eval(
prepared_eval_batch = prepare_batch(eval_batch)
if "latents" not in prepared_eval_batch:
raise ValueError(f"Error calculating eval batch.")
bsz = prepared_eval_batch["latents"].shape[0]
sample_text_str = "samples" if bsz > 1 else "sample"
with torch.no_grad():
eval_timestep_list = range(0, 1000, 25)
for eval_timestep in tqdm(
eval_timestep_list.__reversed__(),
total=len(eval_timestep_list),
desc="Calculate eval loss",
desc=f"Evaluating batch of {bsz} {sample_text_str}",
position=1,
leave=False,
):
Expand All @@ -1869,6 +1885,7 @@ def execute_eval(
apply_conditioning_mask=False,
)
accumulated_eval_losses[eval_timestep].append(eval_loss)
main_progress_bar.update(1)
torch.set_rng_state(training_random_states)
return accumulated_eval_losses

Expand Down

0 comments on commit 7e8b4c5

Please sign in to comment.