Skip to content

Commit

Permalink
update for corner cases
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Jun 4, 2024
1 parent 321e24d commit 4dbcef4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
3 changes: 3 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ def set_current_epoch(self, epoch):
for _ in range(num_epochs):
self.current_epoch += 1
self.shuffle_buckets()
# self.current_epoch seems to be reset to 0 in the next epoch; this may be caused by skipped_dataloader?
else:
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
self.current_epoch = epoch
Expand Down Expand Up @@ -5560,6 +5561,8 @@ def add(self, *, epoch: int, step: int, loss: float) -> None:
if epoch == 0:
self.loss_list.append(loss)
else:
while len(self.loss_list) <= step:
self.loss_list.append(0.0)
self.loss_total -= self.loss_list[step]
self.loss_list[step] = loss
self.loss_total += loss
Expand Down
23 changes: 14 additions & 9 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,13 +493,15 @@ def train(self, args):
# before resuming make hook for saving/loading to save/load the network weights only
def save_model_hook(models, weights, output_dir):
# pop weights of other models than network to save only network weights
if accelerator.is_main_process:
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
if accelerator.is_main_process or args.deepspeed:
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
weights.pop(i)
if len(weights) > i:
weights.pop(i)
# print(f"save model hook: {len(weights)} weights will be saved")

# save current epoch and step
Expand Down Expand Up @@ -813,11 +815,12 @@ def load_model_hook(models, input_dir):
)
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
initial_step *= args.gradient_accumulation_steps

# set epoch to start to make initial_step less than len(train_dataloader)
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
else:
# if not, only the epoch number is skipped, for informational purposes
epoch_to_start = initial_step // math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
initial_step = 0 # do not skip

global_step = 0
Expand Down Expand Up @@ -878,9 +881,11 @@ def remove_model(old_ckpt_name):
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

# training loop
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
global_step = initial_step

for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
Expand All @@ -892,7 +897,7 @@ def remove_model(old_ckpt_name):

skipped_dataloader = None
if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1)
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
initial_step = 1

for step, batch in enumerate(skipped_dataloader or train_dataloader):
Expand Down

0 comments on commit 4dbcef4

Please sign in to comment.