Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add target to resize transform for aspect ratio training (detection task) #823

Merged
merged 17 commits into from
Mar 9, 2022
Merged
56 changes: 48 additions & 8 deletions doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import math
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from PIL.Image import Image
from torch.nn.functional import pad
Expand All @@ -22,15 +23,24 @@ def __init__(
interpolation=F.InterpolationMode.BILINEAR,
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = False,
pad: bool = True,
) -> None:
super().__init__(size, interpolation)
self.preserve_aspect_ratio = preserve_aspect_ratio
self.symmetric_pad = symmetric_pad
self.pad = pad

def forward(
self,
img: torch.Tensor,
target: Optional[np.ndarray] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, np.ndarray]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it, I think we should have:

  • one function to resize the image
  • one function to resize the target
  • use them in the module implementation
    And I'm not sure what would be the best, but probably having Resize (image only), and SampleResize (image + target), would be relevant. What do you think?

Copy link
Collaborator Author

@charlesmindee charlesmindee Feb 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we can split the function with 1 for the targets and 1 for the image, but I am not sure it is relevant to have a Resize and a SampleResize, because if you want to keep the aspect ratio of the image while resizing, you will almost always pad (except if you give only 1 target size), and this will modify the targets. The Resize function won't be able to preserve the aspect ratio of the images, and this would require to differentiate the function we are using in the training scripts and in the preprocessor regarding the option selected by the user (preserve_aspect_ratio or not). This can even be dangerous if someone tries to modify the aspect ratio in Resize without changing the targets I think, do yo agree ?

Copy link
Collaborator Author

@charlesmindee charlesmindee Feb 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And for a differentiation of the target and image function, I think it will complexify the code because we use many attributes of the Resize class to resize both the image and the target (self.symmetric_pad, self.preserve_aspect_ratio, self.output_size) which would be passed as arguments and we have also many shared computations for the image and the target that would need to be done twice (or added in the signature of the functions but I think it will make the code less understandable): offset, raw_shape, and img.shape is also used in the target computation, what do you think @fg-mindee ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About splitting the two transforms, the reason behind this is that the target, depending on the task, will not necessarily be modified. So either we can add a lot of cases (for each target type) in the Resize transform, or we could have a Resize that only transforms the image, and SampleResize that inherits from it and transforms the target.

I agree this would be dangerous for something to add this without changing the target, but transforms will only be played with by people willing to train models. So I would argue it's safe to assume they either use our default training script or have a good understanding of what is going on under the hood 🤷‍♂️

However, I fully agree about the duplication of symmetric pad / computation duplication for some aspects. I don't have an ideal suggestion for this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the SampleResize transform, you want it to inherit from Resize, but since they won't have the same signature so I don't really see why we do that ? @fg-mindee

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can always refactor this later on anyway! Feel free to implement it the way you prefer (the optional target passing you suggested is probably the best one)


def forward(self, img: torch.Tensor) -> torch.Tensor:
target_ratio = self.size[0] / self.size[1]
actual_ratio = img.shape[-2] / img.shape[-1]
if not self.preserve_aspect_ratio or (target_ratio == actual_ratio):
if target is not None:
return super().forward(img), target
return super().forward(img)
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
else:
# Resize
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -41,12 +51,42 @@ def forward(self, img: torch.Tensor) -> torch.Tensor:

# Scale image
img = F.resize(img, tmp_size, self.interpolation)
# Pad (inverted in pytorch)
_pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2])
if self.symmetric_pad:
half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2))
_pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1])
return pad(img, _pad)
raw_shape = img.shape[-2:]
if self.pad:
# Pad (inverted in pytorch)
_pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2])
if self.symmetric_pad:
half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2))
_pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1])
img = pad(img, _pad)

# In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
if target is not None:
if self.preserve_aspect_ratio:
# Get absolute coords
if target.shape[1:] == (4,):
if self.pad and self.symmetric_pad:
if np.max(target) <= 1:
offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
else:
target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1]
target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
elif target.shape[1:] == (4, 2):
if self.pad and self.symmetric_pad:
if np.max(target) <= 1:
offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
else:
target[..., 0] *= raw_shape[-1] / img.shape[-1]
target[..., 1] *= raw_shape[-2] / img.shape[-2]
else:
raise AssertionError
return img, target

return img

def __repr__(self) -> str:
interpolate_str = self.interpolation.value
Expand Down
56 changes: 46 additions & 10 deletions doctr/transforms/modules/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import random
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -63,30 +63,66 @@ def __init__(
method: str = 'bilinear',
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = False,
pad: bool = True
) -> None:
self.output_size = output_size
self.method = method
self.preserve_aspect_ratio = preserve_aspect_ratio
self.symmetric_pad = symmetric_pad
self.pad = pad

def extra_repr(self) -> str:
_repr = f"output_size={self.output_size}, method='{self.method}'"
if self.preserve_aspect_ratio:
_repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}"
return _repr

def __call__(self, img: tf.Tensor) -> tf.Tensor:
def __call__(
self,
img: tf.Tensor,
target: Optional[np.ndarray] = None,
) -> Union[tf.Tensor, Tuple[tf.Tensor, np.ndarray]]:

input_dtype = img.dtype
img = tf.image.resize(img, self.output_size, self.method, self.preserve_aspect_ratio)
raw_shape = img.shape[:2]
if self.preserve_aspect_ratio:
# pad width
if not self.symmetric_pad:
offset = (0, 0)
elif self.output_size[0] == img.shape[0]:
offset = (0, int((self.output_size[1] - img.shape[1]) / 2))
else:
offset = (int((self.output_size[0] - img.shape[0]) / 2), 0)
img = tf.image.pad_to_bounding_box(img, *offset, *self.output_size)
if self.pad:
# pad width
if not self.symmetric_pad:
offset = (0, 0)
elif self.output_size[0] == img.shape[0]:
offset = (0, int((self.output_size[1] - img.shape[1]) / 2))
else:
offset = (int((self.output_size[0] - img.shape[0]) / 2), 0)
img = tf.image.pad_to_bounding_box(img, *offset, *self.output_size)

# In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
if target is not None:
if self.preserve_aspect_ratio:
# Get absolute coords
if target.shape[1:] == (4,):
if self.pad and self.symmetric_pad:
if np.max(target) <= 1:
offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
target[:, [0, 2]] = offset[1] + target[:, [0, 2]] * raw_shape[1] / img.shape[1]
target[:, [1, 3]] = offset[0] + target[:, [1, 3]] * raw_shape[0] / img.shape[0]
else:
target[:, [0, 2]] *= raw_shape[1] / img.shape[1]
target[:, [1, 3]] *= raw_shape[0] / img.shape[0]
elif target.shape[1:] == (4, 2):
if self.pad and self.symmetric_pad:
if np.max(target) <= 1:
offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
target[..., 0] = offset[1] + target[..., 0] * raw_shape[1] / img.shape[1]
target[..., 1] = offset[0] + target[..., 1] * raw_shape[0] / img.shape[0]
else:
target[..., 0] *= raw_shape[1] / img.shape[1]
target[..., 1] *= raw_shape[0] / img.shape[0]
else:
raise AssertionError
return tf.cast(img, dtype=input_dtype), target

return tf.cast(img, dtype=input_dtype)


Expand Down
17 changes: 10 additions & 7 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def main(args):
val_set = DetectionDataset(
img_folder=os.path.join(args.val_path, 'images'),
label_path=os.path.join(args.val_path, 'labels.json'),
img_transforms=T.Resize((args.input_size, args.input_size)),
img_transforms=T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True),
use_polygons=args.rotation and not args.eval_straight,
)
val_loader = DataLoader(
Expand Down Expand Up @@ -243,17 +243,20 @@ def main(args):
img_folder=os.path.join(args.train_path, 'images'),
label_path=os.path.join(args.train_path, 'labels.json'),
img_transforms=Compose(
([T.Resize((args.input_size, args.input_size))] if not args.rotation else [])
+ [
[
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
]
),
sample_transforms=T.SampleCompose([
T.RandomRotate(90, expand=True),
T.ImageTransform(T.Resize((args.input_size, args.input_size))),
]) if args.rotation else None,
sample_transforms=T.SampleCompose(
([T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True, pad=True)
] if not args.rotation else [])
+ ([T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, pad=False),
T.RandomRotate(90, expand=True),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True, pad=True)
] if args.rotation else [])
),
use_polygons=args.rotation,
)

Expand Down
17 changes: 10 additions & 7 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def main(args):
val_set = DetectionDataset(
img_folder=os.path.join(args.val_path, 'images'),
label_path=os.path.join(args.val_path, 'labels.json'),
img_transforms=T.Resize((args.input_size, args.input_size)),
img_transforms=T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True),
use_polygons=args.rotation and not args.eval_straight,
)
val_loader = DataLoader(
Expand Down Expand Up @@ -189,8 +189,7 @@ def main(args):
img_folder=os.path.join(args.train_path, 'images'),
label_path=os.path.join(args.train_path, 'labels.json'),
img_transforms=T.Compose(
([T.Resize((args.input_size, args.input_size))] if not args.rotation else [])
+ [
[
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
T.RandomJpegQuality(60),
Expand All @@ -199,10 +198,14 @@ def main(args):
T.RandomBrightness(.3),
]
),
sample_transforms=T.SampleCompose([
T.RandomRotate(90, expand=True),
T.ImageTransform(T.Resize((args.input_size, args.input_size))),
]) if args.rotation else None,
sample_transforms=T.SampleCompose(
([T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True, pad=True)
] if not args.rotation else [])
+ ([T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, pad=False),
T.RandomRotate(90, expand=True),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True, pad=True)
] if args.rotation else [])
),
use_polygons=args.rotation,
)
train_loader = DataLoader(
Expand Down
2 changes: 1 addition & 1 deletion tests/tensorflow/test_models_detection_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob):
np.array([[.75, .75, .5, .5, 0], [.65, .7, .3, .4, 0]], dtype=np.float32),
]
loss = model(input_tensor, target, training=True)['loss']
assert isinstance(loss, tf.Tensor) and ((loss - out['loss']) / loss).numpy() < 21e-2
assert isinstance(loss, tf.Tensor) and ((loss - out['loss']) / loss).numpy() < 25e-2


@pytest.fixture(scope="session")
Expand Down