From 3cf8bc32fe0f6de9e5caf41c318873a746021af7 Mon Sep 17 00:00:00 2001
From: feffy380 <114889020+feffy380@users.noreply.github.com>
Date: Wed, 16 Oct 2024 12:00:36 +0200
Subject: [PATCH] fix: on resume, preserve progress bar and current step/epoch

---
 train_network.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/train_network.py b/train_network.py
index 0eaa7ac7e..499dac2bf 100644
--- a/train_network.py
+++ b/train_network.py
@@ -964,7 +964,7 @@ def load_model_hook(models, input_dir):
         ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
 
         progress_bar = tqdm(
-            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
+            initial=initial_step, total=args.max_train_steps, smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
         )
 
         epoch_to_start = 0
@@ -976,7 +976,6 @@ def load_model_hook(models, input_dir):
                     f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
                 )
             logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
-            initial_step *= args.gradient_accumulation_steps
 
             # set epoch to start to make initial_step less than len(train_dataloader)
             epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1045,7 +1044,8 @@ def remove_model(old_ckpt_name):
 
         # For --sample_at_first
         optimizer_eval_fn()
-        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+        if args.sample_at_first and initial_step == 0:
+            self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
         optimizer_train_fn()
         if len(accelerator.trackers) > 0:
             # log empty object to commit the sample images to wandb
@@ -1053,10 +1053,10 @@ def remove_model(old_ckpt_name):
 
         # training loop
         if initial_step > 0:  # only if skip_until_initial_step is specified
-            for skip_epoch in range(epoch_to_start):  # skip epochs
-                logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
-                initial_step -= len(train_dataloader)
             global_step = initial_step
+            for skip_epoch in range(epoch_to_start):  # skip epochs
+                logger.info(f"skipping epoch {skip_epoch+1} because initial_step is {initial_step}")
+                initial_step -= math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
         # log device and dtype for each model
         logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
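
Note for reviewers: below is a minimal standalone sketch of the two behaviors the patch changes, with made-up step counts. The variables (max_train_steps, initial_step, batches_per_epoch, gradient_accumulation_steps) are local stand-ins for the trainer's values, not the trainer itself; only tqdm's real initial/total kwargs are assumed. Passing initial=/total= lets a resumed bar display the true global position instead of restarting from zero over a shrunken range, and keeping initial_step in optimizer-step units (rather than multiplying by gradient_accumulation_steps) makes the epoch-skip division agree with the per-epoch optimizer-step count.

    # Sketch only: illustrates the resume arithmetic with hypothetical numbers.
    import math
    from tqdm import tqdm

    max_train_steps = 1000           # optimizer steps planned for the whole run
    initial_step = 250               # optimizer steps already done (restored state)
    batches_per_epoch = 80           # stand-in for len(train_dataloader)
    gradient_accumulation_steps = 4

    # Old code: tqdm(range(max_train_steps - initial_step)) -> bar restarts at 0/750.
    # New code: the bar resumes at 250/1000, matching the true global step.
    progress_bar = tqdm(initial=initial_step, total=max_train_steps, smoothing=0, desc="steps")

    # With initial_step kept in optimizer-step units, whole epochs to skip
    # fall out of a single division by the optimizer steps per epoch.
    steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 20
    epoch_to_start = initial_step // steps_per_epoch  # 250 // 20 = 12 epochs skipped

    for step in range(initial_step, max_train_steps):
        progress_bar.update(1)  # stand-in for one real optimizer step
    progress_bar.close()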