add --eval_timestep_interval to reduce compute density requirements, fix RNG save/load for CUDA
bghira committed Jan 28, 2025
1 parent 7e8b4c5 commit 4c7c968
Showing 2 changed files with 30 additions and 8 deletions.
17 changes: 15 additions & 2 deletions helpers/configuration/cmd_args.py
@@ -1679,14 +1679,26 @@ def get_argument_parser():
             " configured in your dataloader."
         ),
     )
+    parser.add_argument(
+        "--eval_timestep_interval",
+        type=int,
+        default=200,
+        help=(
+            "When evaluating batches, the entire 1000 timesteps may be sampled with a granularity of 1."
+            " To save time and reduce redundancy, a granularity of 200 is used by default."
+            " More granularity means more accurate charts, but it may not mean more interpretable results."
+        )
+    )
     parser.add_argument(
         "--num_eval_images",
         type=int,
         default=4,
         help=(
             "If possible, this many eval images will be selected from each dataset."
             " This is used when training super-resolution models such as DeepFloyd Stage II,"
-            " which will upscale input images from the training set."
+            " which will upscale input images from the training set during validation."
+            " If using --eval_steps_interval, this will be the number of batches sampled"
+            " for loss calculations."
         ),
     )
     parser.add_argument(
@@ -1695,7 +1707,8 @@ def get_argument_parser():
         default=None,
         help=(
             "When provided, only this dataset's images will be used as the eval set, to keep"
-            " the training and eval images split."
+            " the training and eval images split. This option only applies for img2img validations,"
+            " not validation loss calculations."
         ),
     )
     parser.add_argument(
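For context, the new flag only changes the stride of the timestep sweep used when computing validation loss. A minimal standalone Python sketch of the effect, using the old hard-coded stride of 25 and the new default of 200 (illustrative numbers from the diff below, not the trainer's actual code path):

    # How --eval_timestep_interval changes the number of eval forward passes per batch.
    old_timesteps = list(range(0, 1000, 25))   # previous hard-coded stride: 40 timesteps
    new_timesteps = list(range(0, 1000, 200))  # new default stride: 5 timesteps
    print(len(old_timesteps), len(new_timesteps))  # 40 5, an 8x reduction

A coarser stride trades per-timestep chart resolution for evaluation speed, matching the help text's caveat that more granularity does not necessarily mean more interpretable results.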
21 changes: 15 additions & 6 deletions helpers/training/validation.py
@@ -1825,14 +1825,18 @@ def execute_eval(
         eval_batch = True
         evaluated_sample_count = 0
         total_batches = self.total_eval_batches()
-        print(f"Working on {total_batches} evaluation batches.")
+        if self.config.num_eval_images is not None:
+            total_batches = min(self.config.num_eval_images, total_batches)
         main_progress_bar = tqdm(
             total=total_batches,
             desc="Calculate validation loss",
-            position=2,
+            position=0,
             leave=True,
         )
-        while eval_batch is not False:
+        cpu_rng_state = torch.get_rng_state()
+        if torch.cuda.is_available():
+            cuda_rng_state = torch.cuda.get_rng_state()
+        while eval_batch is not False and evaluated_sample_count < total_batches:
             try:
                 evaluated_sample_count += 1
                 if evaluated_sample_count > self.config.num_eval_images:
@@ -1846,7 +1850,6 @@ def execute_eval(
                     eval_batch = False
 
                 if eval_batch is not None and eval_batch is not False:
-                    training_random_states = torch.get_rng_state()
                     # this seed is set for the prepare_batch to correctly set the eval noise seed.
                     torch.manual_seed(0)
                     prepared_eval_batch = prepare_batch(eval_batch)
@@ -1855,7 +1858,7 @@ def execute_eval(
                     bsz = prepared_eval_batch["latents"].shape[0]
                     sample_text_str = "samples" if bsz > 1 else "sample"
                     with torch.no_grad():
-                        eval_timestep_list = range(0, 1000, 25)
+                        eval_timestep_list = range(0, 1000, self.config.eval_timestep_interval)
                         for eval_timestep in tqdm(
                             eval_timestep_list.__reversed__(),
                             total=len(eval_timestep_list),
@@ -1886,7 +1889,13 @@ def execute_eval(
                             )
                             accumulated_eval_losses[eval_timestep].append(eval_loss)
                             main_progress_bar.update(1)
-        torch.set_rng_state(training_random_states)
+        try:
+            reset_eval_datasets()
+        except:
+            pass
+        torch.set_rng_state(cpu_rng_state)
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(cuda_rng_state)
         return accumulated_eval_losses
 
     def generate_tracker_table(self, accumulated_evaluation_losses: dict):
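The RNG portion of the commit follows a save/seed/restore pattern, now covering the CUDA generator alongside the CPU one so that the fixed eval seed does not perturb a CUDA run's training randomness. A minimal standalone sketch of that pattern (illustrative; the real method wraps the evaluation loop shown above):

    import torch

    # Save both generator states before evaluation begins.
    cpu_rng_state = torch.get_rng_state()
    if torch.cuda.is_available():
        cuda_rng_state = torch.cuda.get_rng_state()

    torch.manual_seed(0)  # fixed seed so the eval noise is reproducible across runs
    # ... evaluation forward passes would run here ...

    # Restore the saved states so training resumes on its original RNG trajectory.
    torch.set_rng_state(cpu_rng_state)
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_rng_state)

Note that torch.cuda.get_rng_state() captures only the current device; a multi-GPU setup would likely want torch.cuda.get_rng_state_all() and its set_ counterpart instead.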
