From c29a0a35d29b8c8b56150a1184381ef1a3f3126e Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Wed, 17 Nov 2021 00:18:26 +0800 Subject: [PATCH] Update train.py --- train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index a90e720..c653d4f 100644 --- a/train.py +++ b/train.py @@ -228,15 +228,15 @@ lr = Freeze_lr start_epoch = Init_Epoch end_epoch = Freeze_Epoch - - train_dataloader = PSPnetDataset(train_lines, input_shape, batch_size, num_classes, aux_branch, True, VOCdevkit_path) - val_dataloader = PSPnetDataset(val_lines, input_shape, batch_size, num_classes, aux_branch, False, VOCdevkit_path) epoch_step = len(train_lines) // batch_size epoch_step_val = len(val_lines) // batch_size if epoch_step == 0 or epoch_step_val == 0: raise ValueError("数据集过小,无法进行训练,请扩充数据集。") + + train_dataloader = PSPnetDataset(train_lines, input_shape, batch_size, num_classes, aux_branch, True, VOCdevkit_path) + val_dataloader = PSPnetDataset(val_lines, input_shape, batch_size, num_classes, aux_branch, False, VOCdevkit_path) print('Train on {} samples, val on {} samples, with batch size {}.'.format(len(train_lines), len(val_lines), batch_size)) if eager: @@ -280,15 +280,15 @@ lr = Unfreeze_lr start_epoch = Freeze_Epoch end_epoch = UnFreeze_Epoch - - train_dataloader = PSPnetDataset(train_lines, input_shape, batch_size, num_classes, aux_branch, True, VOCdevkit_path) - val_dataloader = PSPnetDataset(val_lines, input_shape, batch_size, num_classes, aux_branch, False, VOCdevkit_path) epoch_step = len(train_lines) // batch_size epoch_step_val = len(val_lines) // batch_size if epoch_step == 0 or epoch_step_val == 0: raise ValueError("数据集过小,无法进行训练,请扩充数据集。") + + train_dataloader = PSPnetDataset(train_lines, input_shape, batch_size, num_classes, aux_branch, True, VOCdevkit_path) + val_dataloader = PSPnetDataset(val_lines, input_shape, batch_size, num_classes, aux_branch, False, VOCdevkit_path) print('Train on {} samples, val on {} samples, with batch size {}.'.format(len(train_lines), len(val_lines), batch_size)) if eager: