Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing authored Nov 16, 2021
1 parent a8d9129 commit c29a0a3
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,15 @@
lr = Freeze_lr
start_epoch = Init_Epoch
end_epoch = Freeze_Epoch

train_dataloader = PSPnetDataset(train_lines, input_shape, batch_size, num_classes, aux_branch, True, VOCdevkit_path)
val_dataloader = PSPnetDataset(val_lines, input_shape, batch_size, num_classes, aux_branch, False, VOCdevkit_path)

epoch_step = len(train_lines) // batch_size
epoch_step_val = len(val_lines) // batch_size

if epoch_step == 0 or epoch_step_val == 0:
raise ValueError("数据集过小,无法进行训练,请扩充数据集。")

train_dataloader = PSPnetDataset(train_lines, input_shape, batch_size, num_classes, aux_branch, True, VOCdevkit_path)
val_dataloader = PSPnetDataset(val_lines, input_shape, batch_size, num_classes, aux_branch, False, VOCdevkit_path)

print('Train on {} samples, val on {} samples, with batch size {}.'.format(len(train_lines), len(val_lines), batch_size))
if eager:
Expand Down Expand Up @@ -280,15 +280,15 @@
lr = Unfreeze_lr
start_epoch = Freeze_Epoch
end_epoch = UnFreeze_Epoch

train_dataloader = PSPnetDataset(train_lines, input_shape, batch_size, num_classes, aux_branch, True, VOCdevkit_path)
val_dataloader = PSPnetDataset(val_lines, input_shape, batch_size, num_classes, aux_branch, False, VOCdevkit_path)

epoch_step = len(train_lines) // batch_size
epoch_step_val = len(val_lines) // batch_size

if epoch_step == 0 or epoch_step_val == 0:
raise ValueError("数据集过小,无法进行训练,请扩充数据集。")

train_dataloader = PSPnetDataset(train_lines, input_shape, batch_size, num_classes, aux_branch, True, VOCdevkit_path)
val_dataloader = PSPnetDataset(val_lines, input_shape, batch_size, num_classes, aux_branch, False, VOCdevkit_path)

print('Train on {} samples, val on {} samples, with batch size {}.'.format(len(train_lines), len(val_lines), batch_size))
if eager:
Expand Down

0 comments on commit c29a0a3

Please sign in to comment.