Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee committed Mar 11, 2024
1 parent c6e70fb commit 8963aad
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 44 deletions.
38 changes: 21 additions & 17 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,28 @@ def __init__(

conv_layer = DeformConv2d if deform_conv else nn.Conv2d

self.in_branches = nn.ModuleList([
nn.Sequential(
conv_layer(chans, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
for idx, chans in enumerate(in_channels)
])
self.in_branches = nn.ModuleList(
[
nn.Sequential(
conv_layer(chans, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
for idx, chans in enumerate(in_channels)
]
)
self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.out_branches = nn.ModuleList([
nn.Sequential(
conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
nn.BatchNorm2d(out_chans),
nn.ReLU(inplace=True),
nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
)
for idx, chans in enumerate(in_channels)
])
self.out_branches = nn.ModuleList(
[
nn.Sequential(
conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
nn.BatchNorm2d(out_chans),
nn.ReLU(inplace=True),
nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
)
for idx, chans in enumerate(in_channels)
]
)

def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
if len(x) != len(self.out_branches):
Expand Down
34 changes: 19 additions & 15 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,17 @@ def main(args):
train_set = DetectionDataset(
img_folder=os.path.join(args.train_path, "images"),
label_path=os.path.join(args.train_path, "labels.json"),
img_transforms=Compose([
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
T.RandomApply(T.RandomShadow(), 0.1),
T.RandomApply(GaussianBlur(kernel_size=3), 0.1),
RandomPhotometricDistort(p=0.05),
RandomGrayscale(p=0.05),
]),
img_transforms=Compose(
[
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
T.RandomApply(T.RandomShadow(), 0.1),
T.RandomApply(GaussianBlur(kernel_size=3), 0.1),
RandomPhotometricDistort(p=0.05),
RandomGrayscale(p=0.05),
]
),
sample_transforms=T.SampleCompose(
(
[T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
Expand Down Expand Up @@ -390,12 +392,14 @@ def main(args):
print(log_msg)
# W&B
if args.wb:
wandb.log({
"val_loss": val_loss,
"recall": recall,
"precision": precision,
"mean_iou": mean_iou,
})
wandb.log(
{
"val_loss": val_loss,
"recall": recall,
"precision": precision,
"mean_iou": mean_iou,
}
)
if args.early_stop and early_stopper.early_stop(val_loss):
print("Training halted early due to reaching patience limit.")
break
Expand Down
26 changes: 14 additions & 12 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,18 +221,20 @@ def main(args):
train_set = DetectionDataset(
img_folder=os.path.join(args.train_path, "images"),
label_path=os.path.join(args.train_path, "labels.json"),
img_transforms=T.Compose([
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomJpegQuality(60),
#T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
#T.RandomApply(T.RandomShadow(), 0.1),
#T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1),
T.RandomSaturation(0.3),
T.RandomContrast(0.3),
T.RandomBrightness(0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.05),
]),
img_transforms=T.Compose(
[
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomJpegQuality(60),
# T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
# T.RandomApply(T.RandomShadow(), 0.1),
# T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1),
T.RandomSaturation(0.3),
T.RandomContrast(0.3),
T.RandomBrightness(0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.05),
]
),
sample_transforms=T.SampleCompose(
(
[T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
Expand Down

0 comments on commit 8963aad

Please sign in to comment.