Skip to content

Commit

Permalink
[references] Update detection augmentations (#1577)
Browse files (browse the repository at this point in the history)
  • Loading branch information
felixdittrich92 authored May 3, 2024
1 parent 2940d9d commit 56db176
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 71 deletions.
10 changes: 6 additions & 4 deletions doctr/transforms/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,11 @@ class OneOf(NestedObject):
def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
self.transforms = transforms

def __call__(self, img: Any) -> Any:
def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
# Pick transformation
transfo = self.transforms[int(random.random() * len(self.transforms))]
# Apply
return transfo(img)
return transfo(img) if target is None else transfo(img, target) # type: ignore[call-arg]


class RandomApply(NestedObject):
Expand Down Expand Up @@ -286,10 +286,12 @@ def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
if target.shape[1:] == (4, 2):
min_xy = np.min(target, axis=1)
max_xy = np.max(target, axis=1)
target = np.concatenate((min_xy, max_xy), axis=1)
_target = np.concatenate((min_xy, max_xy), axis=1)
else:
_target = target

# Crop image and targets
croped_img, crop_boxes = F.crop_detection(img, target, crop_box)
croped_img, crop_boxes = F.crop_detection(img, _target, crop_box)
# hard fallback if no box is kept
if crop_boxes.shape[0] == 0:
return img, target
Expand Down
76 changes: 44 additions & 32 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,42 +256,54 @@ def main(args):
return

st = time.time()
# Augmentations
# Image augmentations
img_transforms = T.OneOf([
Compose([
T.RandomApply(T.ColorInversion(), 0.3),
T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.2),
]),
Compose([
T.RandomApply(T.RandomShadow(), 0.3),
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.3),
RandomGrayscale(p=0.15),
]),
RandomPhotometricDistort(p=0.3),
lambda x: x, # Identity no transformation
])
# Image + target augmentations
sample_transforms = T.SampleCompose(
(
[
T.RandomHorizontalFlip(0.15),
T.OneOf([
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), p=0.25),
]),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if not args.rotation
else [
T.RandomHorizontalFlip(0.15),
T.OneOf([
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), p=0.25),
]),
# Rotation augmentation
T.Resize(args.input_size, preserve_aspect_ratio=True),
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
)
)

# Load both train and val data generators
train_set = DetectionDataset(
img_folder=os.path.join(args.train_path, "images"),
label_path=os.path.join(args.train_path, "labels.json"),
img_transforms=Compose([
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
T.RandomApply(T.RandomShadow(), 0.1),
T.RandomApply(GaussianBlur(kernel_size=3), 0.1),
RandomPhotometricDistort(p=0.05),
RandomGrayscale(p=0.05),
]),
sample_transforms=T.SampleCompose(
(
[
T.RandomHorizontalFlip(0.1),
T.RandomApply(T.RandomCrop(), 0.2),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if not args.rotation
else [
T.RandomHorizontalFlip(0.1),
T.RandomApply(T.RandomCrop(), 0.2),
]
)
+ (
[
T.Resize(args.input_size, preserve_aspect_ratio=True),
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if args.rotation
else []
)
),
img_transforms=img_transforms,
sample_transforms=sample_transforms,
use_polygons=args.rotation,
)

Expand Down
83 changes: 48 additions & 35 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,45 +209,58 @@ def main(args):
return

st = time.time()
# Augmentations
# Image augmentations
img_transforms = T.OneOf([
T.Compose([
T.RandomApply(T.ColorInversion(), 0.3),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.2),
]),
T.Compose([
T.RandomApply(T.RandomJpegQuality(60), 0.15),
# T.RandomApply(T.RandomShadow(), 0.2), # Broken atm on GPU
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.15),
]),
T.Compose([
T.RandomApply(T.RandomSaturation(0.3), 0.3),
T.RandomApply(T.RandomContrast(0.3), 0.3),
T.RandomApply(T.RandomBrightness(0.3), 0.3),
]),
lambda x: x, # Identity no transformation
])
# Image + target augmentations
sample_transforms = T.SampleCompose(
(
[
T.RandomHorizontalFlip(0.15),
T.OneOf([
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), p=0.25),
]),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if not args.rotation
else [
T.RandomHorizontalFlip(0.15),
T.OneOf([
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), p=0.25),
]),
# Rotation augmentation
T.Resize(args.input_size, preserve_aspect_ratio=True),
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
)
)
# Load both train and val data generators
train_set = DetectionDataset(
img_folder=os.path.join(args.train_path, "images"),
label_path=os.path.join(args.train_path, "labels.json"),
img_transforms=T.Compose([
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomJpegQuality(60),
T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
# T.RandomApply(T.RandomShadow(), 0.1), # Broken atm on GPU
T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1),
T.RandomSaturation(0.3),
T.RandomContrast(0.3),
T.RandomBrightness(0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.05),
]),
sample_transforms=T.SampleCompose(
(
[
T.RandomHorizontalFlip(0.1),
T.RandomApply(T.RandomCrop(), 0.2),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if not args.rotation
else [
T.RandomHorizontalFlip(0.1),
T.RandomApply(T.RandomCrop(), 0.2),
]
)
+ (
[
T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if args.rotation
else []
)
),
img_transforms=img_transforms,
sample_transforms=sample_transforms,
use_polygons=args.rotation,
)
train_loader = DataLoader(
Expand Down
5 changes: 5 additions & 0 deletions tests/common/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def test_oneof():
transfo = T.OneOf(transfos)
out = transfo(1)
assert out == 0 or out == 11
# test with target
transfos = [lambda x, y: (1 - x, y), lambda x, y: (x + 10, y)]
transfo = T.OneOf(transfos)
out = transfo(1, np.array([2]))
assert out == (0, 2) or out == (11, 2) and isinstance(out[1], np.ndarray)


def test_randomapply():
Expand Down

0 comments on commit 56db176

Please sign in to comment.