diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 9e4b81ef92..17686bb28e 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -57,24 +57,28 @@ def __init__( conv_layer = DeformConv2d if deform_conv else nn.Conv2d - self.in_branches = nn.ModuleList([ - nn.Sequential( - conv_layer(chans, out_channels, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - for idx, chans in enumerate(in_channels) - ]) + self.in_branches = nn.ModuleList( + [ + nn.Sequential( + conv_layer(chans, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + for idx, chans in enumerate(in_channels) + ] + ) self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) - self.out_branches = nn.ModuleList([ - nn.Sequential( - conv_layer(out_channels, out_chans, 3, padding=1, bias=False), - nn.BatchNorm2d(out_chans), - nn.ReLU(inplace=True), - nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True), - ) - for idx, chans in enumerate(in_channels) - ]) + self.out_branches = nn.ModuleList( + [ + nn.Sequential( + conv_layer(out_channels, out_chans, 3, padding=1, bias=False), + nn.BatchNorm2d(out_chans), + nn.ReLU(inplace=True), + nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True), + ) + for idx, chans in enumerate(in_channels) + ] + ) def forward(self, x: List[torch.Tensor]) -> torch.Tensor: if len(x) != len(self.out_branches): diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 4f64011518..e3fe2c9f8d 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -266,15 +266,17 @@ def main(args): train_set = DetectionDataset( img_folder=os.path.join(args.train_path, "images"), label_path=os.path.join(args.train_path, "labels.json"), - img_transforms=Compose([ - # Augmentations - T.RandomApply(T.ColorInversion(), 0.1), - T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), - T.RandomApply(T.RandomShadow(), 0.1), - T.RandomApply(GaussianBlur(kernel_size=3), 0.1), - RandomPhotometricDistort(p=0.05), - RandomGrayscale(p=0.05), - ]), + img_transforms=Compose( + [ + # Augmentations + T.RandomApply(T.ColorInversion(), 0.1), + T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), + T.RandomApply(T.RandomShadow(), 0.1), + T.RandomApply(GaussianBlur(kernel_size=3), 0.1), + RandomPhotometricDistort(p=0.05), + RandomGrayscale(p=0.05), + ] + ), sample_transforms=T.SampleCompose( ( [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)] @@ -390,12 +392,14 @@ def main(args): print(log_msg) # W&B if args.wb: - wandb.log({ - "val_loss": val_loss, - "recall": recall, - "precision": precision, - "mean_iou": mean_iou, - }) + wandb.log( + { + "val_loss": val_loss, + "recall": recall, + "precision": precision, + "mean_iou": mean_iou, + } + ) if args.early_stop and early_stopper.early_stop(val_loss): print("Training halted early due to reaching patience limit.") break diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 21b27d1bed..fd0271c188 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -221,18 +221,20 @@ def main(args): train_set = DetectionDataset( img_folder=os.path.join(args.train_path, "images"), label_path=os.path.join(args.train_path, "labels.json"), - img_transforms=T.Compose([ - # Augmentations - T.RandomApply(T.ColorInversion(), 0.1), - T.RandomJpegQuality(60), - #T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), - #T.RandomApply(T.RandomShadow(), 0.1), - #T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1), - T.RandomSaturation(0.3), - T.RandomContrast(0.3), - T.RandomBrightness(0.3), - T.RandomApply(T.ToGray(num_output_channels=3), 0.05), - ]), + img_transforms=T.Compose( + [ + # Augmentations + T.RandomApply(T.ColorInversion(), 0.1), + T.RandomJpegQuality(60), + # T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), + # T.RandomApply(T.RandomShadow(), 0.1), + # T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1), + T.RandomSaturation(0.3), + T.RandomContrast(0.3), + T.RandomBrightness(0.3), + T.RandomApply(T.ToGray(num_output_channels=3), 0.05), + ] + ), sample_transforms=T.SampleCompose( ( [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]