From 28a75214cbf200b6b80d33a5f67e0eed29a1ce94 Mon Sep 17 00:00:00 2001 From: charlesmindee Date: Mon, 27 Dec 2021 14:03:06 +0100 Subject: [PATCH 1/7] feat: add rotation in training scripts --- doctr/transforms/functional/tensorflow.py | 4 +++- references/detection/train_pytorch.py | 5 ++++- references/detection/train_tensorflow.py | 5 ++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/doctr/transforms/functional/tensorflow.py b/doctr/transforms/functional/tensorflow.py index 972a752004..485478e4a6 100644 --- a/doctr/transforms/functional/tensorflow.py +++ b/doctr/transforms/functional/tensorflow.py @@ -52,6 +52,8 @@ def rotate( if expand: exp_shape = compute_expanded_shape(img.shape[:-1], angle) h_pad, w_pad = int(math.ceil(exp_shape[0] - img.shape[0])), int(math.ceil(exp_shape[1] - img.shape[1])) + if min(h_pad, w_pad) < 0: + h_pad, w_pad = int(math.ceil(exp_shape[1] - img.shape[0])), int(math.ceil(exp_shape[0] - img.shape[1])) exp_img = tf.pad(img, tf.constant([[h_pad // 2, h_pad - h_pad // 2], [w_pad // 2, w_pad - w_pad // 2], [0, 0]])) else: exp_img = img @@ -72,7 +74,7 @@ def rotate( r_boxes[..., 0] = r_boxes[..., 0] / rotated_img.shape[1] r_boxes[..., 1] = r_boxes[..., 1] / rotated_img.shape[0] - return rotated_img, r_boxes + return rotated_img, np.clip(r_boxes, 0, 1) def crop_detection( diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index d502f0fa05..eefb283a44 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -237,11 +237,14 @@ def main(args): img_folder=os.path.join(args.train_path, 'images'), label_path=os.path.join(args.train_path, 'labels.json'), img_transforms=Compose([ - T.Resize((args.input_size, args.input_size)), # Augmentations T.RandomApply(T.ColorInversion(), .1), ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02), ]), + sample_transforms=T.SampleCompose([ + T.RandomRotate(90, expand=True), + T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]), ) train_loader = DataLoader( diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 2bc9478993..aad6faa308 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -184,7 +184,6 @@ def main(args): img_folder=os.path.join(args.train_path, 'images'), label_path=os.path.join(args.train_path, 'labels.json'), img_transforms=T.Compose([ - T.Resize((args.input_size, args.input_size)), # Augmentations T.RandomApply(T.ColorInversion(), .1), T.RandomJpegQuality(60), @@ -192,6 +191,10 @@ def main(args): T.RandomContrast(.3), T.RandomBrightness(.3), ]), + sample_transforms=T.SampleCompose([ + T.RandomRotate(90, expand=True), + T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]), ) train_loader = DataLoader( train_set, From 0f6756bd1adf2a0704300bb0e4a0c461dba3e988 Mon Sep 17 00:00:00 2001 From: charlesmindee Date: Mon, 27 Dec 2021 16:41:20 +0100 Subject: [PATCH 2/7] fix: transforms bug --- doctr/datasets/detection.py | 16 ++++++++--- .../differentiable_binarization/base.py | 27 ++++++++++++++++--- doctr/transforms/functional/pytorch.py | 11 ++++---- doctr/transforms/functional/tensorflow.py | 13 +++++---- doctr/utils/geometry.py | 9 +++---- references/detection/train_pytorch.py | 4 +-- references/detection/train_tensorflow.py | 4 +-- references/detection/utils.py | 6 ++--- 8 files changed, 57 insertions(+), 33 deletions(-) diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py index bcd50d76ce..e8128c2251 100644 --- a/doctr/datasets/detection.py +++ b/doctr/datasets/detection.py @@ -51,7 +51,6 @@ def __init__( polygons = np.asarray(label['polygons']) geoms = polygons if use_polygons else np.concatenate((polygons.min(axis=1), polygons.max(axis=1)), axis=1) - self.data.append((img_name, np.asarray(geoms, dtype=np.float32))) def __getitem__( @@ -59,8 +58,11 @@ def __getitem__( index: int ) -> Tuple[Any, np.ndarray]: + + img, target = self._read_sample(index) h, w = self._get_img_shape(img) + if self.img_transforms is not None: img = self.img_transforms(img) @@ -69,8 +71,14 @@ def __getitem__( # Boxes target = target.copy() - target[..., 0] /= w - target[..., 1] /= h - target = target.clip(0, 1) + if np.max(target) > 2: + if target.ndim == 3: + target[..., 0] /= w + target[..., 1] /= h + else: + target[..., [0, 2]] /= w + target[..., [1, 3]] /= h + + target = target.clip(0, 1) return img, target diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 5ae007f714..cac65290de 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -312,7 +312,14 @@ def build_target( for box, box_size, poly in zip(abs_boxes, boxes_size, polys): # Mask boxes that are too small if box_size < self.min_size_box: - seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + if abs_boxes.ndim == 3: + seg_mask[ + idx, + int(np.min(box[:, 1])): int(np.max(box[:, 1])) + 1, + int(np.min(box[:, 0])): int(np.max(box[:, 0])) + 1 + ] = False + else: + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False continue # Negative shrink for gt, as described in paper @@ -325,11 +332,25 @@ def build_target( # Draw polygon on gt if it is valid if len(shrinked) == 0: - seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + if abs_boxes.ndim == 3: + seg_mask[ + idx, + int(np.min(box[:, 1])): int(np.max(box[:, 1])) + 1, + int(np.min(box[:, 0])): int(np.max(box[:, 0])) + 1 + ] = False + else: + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False continue shrinked = np.array(shrinked[0]).reshape(-1, 2) if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid: - seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + if abs_boxes.ndim == 3: + seg_mask[ + idx, + int(np.min(box[:, 1])): int(np.max(box[:, 1])) + 1, + int(np.min(box[:, 0])): int(np.max(box[:, 0])) + 1 + ] = False + else: + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False continue cv2.fillPoly(seg_target[idx], [shrinked.astype(np.int32)], 1) diff --git a/doctr/transforms/functional/pytorch.py b/doctr/transforms/functional/pytorch.py index 1ac15410e6..71c2409b03 100644 --- a/doctr/transforms/functional/pytorch.py +++ b/doctr/transforms/functional/pytorch.py @@ -51,17 +51,16 @@ def rotate( # Get absolute coords _boxes = deepcopy(boxes) - if boxes.dtype != int: + if np.max(_boxes) < 2: _boxes[:, [0, 2]] = _boxes[:, [0, 2]] * img.shape[2] _boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[1] # Rotate the boxes: xmin, ymin, xmax, ymax --> (4, 2) polygon - r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[1:], expand) # type: ignore[arg-type] + r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[1:], expand).astype(np.float32) # type: ignore[arg-type] - # Convert them to relative - if boxes.dtype != int: - r_boxes[..., 0] = r_boxes[..., 0] / rotated_img.shape[2] - r_boxes[..., 1] = r_boxes[..., 1] / rotated_img.shape[1] + # Always return relative boxes to avoid label confusions when resizing is performed aferwards + r_boxes[..., 0] = r_boxes[..., 0] / rotated_img.shape[2] + r_boxes[..., 1] = r_boxes[..., 1] / rotated_img.shape[1] return rotated_img, r_boxes diff --git a/doctr/transforms/functional/tensorflow.py b/doctr/transforms/functional/tensorflow.py index 485478e4a6..8719adb845 100644 --- a/doctr/transforms/functional/tensorflow.py +++ b/doctr/transforms/functional/tensorflow.py @@ -62,19 +62,18 @@ def rotate( # Get absolute coords _boxes = deepcopy(boxes) - if boxes.dtype != int: + if np.max(_boxes) < 2: _boxes[:, [0, 2]] = _boxes[:, [0, 2]] * img.shape[1] _boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[0] # Rotate the boxes: xmin, ymin, xmax, ymax --> (4, 2) polygon - r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[:-1], expand) + r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[:-1], expand).astype(np.float32) - # Convert them to relative - if boxes.dtype != int: - r_boxes[..., 0] = r_boxes[..., 0] / rotated_img.shape[1] - r_boxes[..., 1] = r_boxes[..., 1] / rotated_img.shape[0] + # Always return relative boxes to avoid label confusions when resizing is performed aferwards + r_boxes[..., 0] = r_boxes[..., 0] / rotated_img.shape[1] + r_boxes[..., 1] = r_boxes[..., 1] / rotated_img.shape[0] - return rotated_img, np.clip(r_boxes, 0, 1) + return rotated_img, r_boxes def crop_detection( diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py index c5869877ff..36b4fd5a13 100644 --- a/doctr/utils/geometry.py +++ b/doctr/utils/geometry.py @@ -118,15 +118,14 @@ def rotate_abs_boxes(boxes: np.ndarray, angle: float, img_shape: Tuple[int, int] ) # Rotate them around image center, shape (N+1, 4, 2) - stacked_rel_points = rotate_abs_points(stacked_rel_points.reshape((-1, 2))).reshape((-1, 4, 2)) - rot_points = rotate_abs_points(stacked_rel_points, angle) - img_rot_corners, box_rot_corners = rot_points[:1], rot_points[1:] + rot_points = rotate_abs_points(stacked_rel_points.reshape((-1, 2)), angle).reshape(-1, 4, 2) + img_rot_corners, box_rot_corners = rot_points[0], rot_points[1:] # Expand the image to fit all the original info if expand: new_corners = np.abs(img_rot_corners).max(axis=0) - box_rot_corners[..., 0] += new_corners[:, 0] - box_rot_corners[..., 1] = new_corners[:, 1] - box_rot_corners[..., 1] + box_rot_corners[..., 0] += new_corners[0] + box_rot_corners[..., 1] = new_corners[1] - box_rot_corners[..., 1] else: box_rot_corners[..., 0] += img_shape[1] / 2 box_rot_corners[..., 1] = img_shape[0] / 2 - box_rot_corners[..., 1] diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index eefb283a44..f706ce72ff 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -244,7 +244,7 @@ def main(args): sample_transforms=T.SampleCompose([ T.RandomRotate(90, expand=True), T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]), + ]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))), ) train_loader = DataLoader( @@ -372,7 +372,7 @@ def parse_args(): parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='Load pretrained parameters before starting the training') parser.add_argument('--rotation', dest='rotation', action='store_true', - help='train with rotated bbox') + help='train with rotated documents') parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use') parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index aad6faa308..5eeb04284e 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -194,7 +194,7 @@ def main(args): sample_transforms=T.SampleCompose([ T.RandomRotate(90, expand=True), T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]), + ]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))), ) train_loader = DataLoader( train_set, @@ -322,7 +322,7 @@ def parse_args(): parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='Load pretrained parameters before starting the training') parser.add_argument('--rotation', dest='rotation', action='store_true', - help='train with rotated bbox') + help='train with rotated documents') parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') args = parser.parse_args() diff --git a/references/detection/utils.py b/references/detection/utils.py index c81dfcf168..7a5a355caa 100644 --- a/references/detection/utils.py +++ b/references/detection/utils.py @@ -26,10 +26,8 @@ def plot_samples(images, targets: List[Dict[str, np.ndarray]]) -> None: boxes[:, :4] = boxes[:, :4].round().astype(int) for box in boxes: - if boxes.shape[1] == 5: - box = cv2.boxPoints(((int(box[0]), int(box[1])), (int(box[2]), int(box[3])), -box[4])) - box = np.int0(box) - cv2.fillPoly(target, [box], 1) + if boxes.ndim == 3: + cv2.fillPoly(target, [np.int0(box)], 1) else: target[int(box[1]): int(box[3]) + 1, int(box[0]): int(box[2]) + 1] = 1 From f763c32498ca641da5bcd2c1c42b1512a923dcc3 Mon Sep 17 00:00:00 2001 From: charlesmindee Date: Tue, 28 Dec 2021 10:16:19 +0100 Subject: [PATCH 3/7] feat: integrate rotation to training scripts --- doctr/datasets/detection.py | 6 +++--- doctr/utils/metrics.py | 2 +- references/detection/train_pytorch.py | 6 +++++- references/detection/train_tensorflow.py | 6 +++++- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py index e8128c2251..b0927a21d0 100644 --- a/doctr/datasets/detection.py +++ b/doctr/datasets/detection.py @@ -58,8 +58,6 @@ def __getitem__( index: int ) -> Tuple[Any, np.ndarray]: - - img, target = self._read_sample(index) h, w = self._get_img_shape(img) @@ -67,6 +65,8 @@ def __getitem__( img = self.img_transforms(img) if self.sample_transforms is not None: + # Here we may modify coordinates, each transformation must accept/return relative coordinates + # Otherwise, if we use the resize operation afterwards it will not only modify images but coordinates img, target = self.sample_transforms(img, target) # Boxes @@ -79,6 +79,6 @@ def __getitem__( target[..., [0, 2]] /= w target[..., [1, 3]] /= h - target = target.clip(0, 1) + target = target.clip(0, 1) return img, target diff --git a/doctr/utils/metrics.py b/doctr/utils/metrics.py index 3f166ed967..42b52e2f68 100644 --- a/doctr/utils/metrics.py +++ b/doctr/utils/metrics.py @@ -203,7 +203,7 @@ def polygon_iou( polys_1: np.ndarray, polys_2: np.ndarray, mask_shape: Tuple[int, int], - use_broadcasting: bool = True + use_broadcasting: bool = False ) -> np.ndarray: """Computes the IoU between two sets of rotated bounding boxes diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index f706ce72ff..d65a60b73e 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -156,7 +156,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): loc_preds = out['preds'] for boxes_gt, boxes_pred in zip(targets, loc_preds): # Remove scores - val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) val_loss += out['loss'].item() batch_cnt += 1 @@ -180,6 +180,10 @@ def main(args): img_folder=os.path.join(args.val_path, 'images'), label_path=os.path.join(args.val_path, 'labels.json'), img_transforms=T.Resize((args.input_size, args.input_size)), + sample_transforms=T.SampleCompose([ + T.RandomRotate(0, expand=True), + T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))), ) val_loader = DataLoader( val_set, diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 5eeb04284e..832aeb309c 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -114,7 +114,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): loc_preds = out['preds'] for boxes_gt, boxes_pred in zip(targets, loc_preds): # Remove scores - val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) val_loss += out['loss'].numpy() batch_cnt += 1 @@ -140,6 +140,10 @@ def main(args): img_folder=os.path.join(args.val_path, 'images'), label_path=os.path.join(args.val_path, 'labels.json'), img_transforms=T.Resize((args.input_size, args.input_size)), + sample_transforms=T.SampleCompose([ + T.RandomRotate(0, expand=True), + T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))), ) val_loader = DataLoader( val_set, From a11ad5e36e288d58f5ffe2e6a621bacadaaba54c Mon Sep 17 00:00:00 2001 From: charlesmindee Date: Tue, 28 Dec 2021 10:23:13 +0100 Subject: [PATCH 4/7] fix: add empty line --- doctr/datasets/detection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py index b0927a21d0..0da2fad99b 100644 --- a/doctr/datasets/detection.py +++ b/doctr/datasets/detection.py @@ -51,6 +51,7 @@ def __init__( polygons = np.asarray(label['polygons']) geoms = polygons if use_polygons else np.concatenate((polygons.min(axis=1), polygons.max(axis=1)), axis=1) + self.data.append((img_name, np.asarray(geoms, dtype=np.float32))) def __getitem__( From 3ea1500689f171ff283c65ca63d925fbd02c6b87 Mon Sep 17 00:00:00 2001 From: charlesmindee Date: Wed, 29 Dec 2021 12:02:22 +0100 Subject: [PATCH 5/7] fix: training scripts --- .../differentiable_binarization/base.py | 28 +++---------------- references/detection/train_pytorch.py | 19 +++++++------ references/detection/train_tensorflow.py | 25 +++++++++-------- 3 files changed, 29 insertions(+), 43 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index cac65290de..fc21676cb2 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -297,6 +297,7 @@ def build_target( abs_boxes[:, :, 1] *= output_shape[-2] polys = abs_boxes boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) + abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1) else: abs_boxes[:, [0, 2]] *= output_shape[-1] abs_boxes[:, [1, 3]] *= output_shape[-2] @@ -312,14 +313,7 @@ def build_target( for box, box_size, poly in zip(abs_boxes, boxes_size, polys): # Mask boxes that are too small if box_size < self.min_size_box: - if abs_boxes.ndim == 3: - seg_mask[ - idx, - int(np.min(box[:, 1])): int(np.max(box[:, 1])) + 1, - int(np.min(box[:, 0])): int(np.max(box[:, 0])) + 1 - ] = False - else: - seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False continue # Negative shrink for gt, as described in paper @@ -332,25 +326,11 @@ def build_target( # Draw polygon on gt if it is valid if len(shrinked) == 0: - if abs_boxes.ndim == 3: - seg_mask[ - idx, - int(np.min(box[:, 1])): int(np.max(box[:, 1])) + 1, - int(np.min(box[:, 0])): int(np.max(box[:, 0])) + 1 - ] = False - else: - seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False continue shrinked = np.array(shrinked[0]).reshape(-1, 2) if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid: - if abs_boxes.ndim == 3: - seg_mask[ - idx, - int(np.min(box[:, 1])): int(np.max(box[:, 1])) + 1, - int(np.min(box[:, 0])): int(np.max(box[:, 0])) + 1 - ] = False - else: - seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False continue cv2.fillPoly(seg_target[idx], [shrinked.astype(np.int32)], 1) diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index d65a60b73e..3cb659d0d9 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -179,11 +179,11 @@ def main(args): val_set = DetectionDataset( img_folder=os.path.join(args.val_path, 'images'), label_path=os.path.join(args.val_path, 'labels.json'), - img_transforms=T.Resize((args.input_size, args.input_size)), + img_transforms=T.Resize((args.input_size, args.input_size)) if not args.rotation else None, sample_transforms=T.SampleCompose([ T.RandomRotate(0, expand=True), T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]) if args.rotation else None ) val_loader = DataLoader( val_set, @@ -240,15 +240,18 @@ def main(args): train_set = DetectionDataset( img_folder=os.path.join(args.train_path, 'images'), label_path=os.path.join(args.train_path, 'labels.json'), - img_transforms=Compose([ - # Augmentations - T.RandomApply(T.ColorInversion(), .1), - ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02), - ]), + img_transforms=Compose( + ([T.Resize((args.input_size, args.input_size))] if not args.rotation else []) + + [ + # Augmentations + T.RandomApply(T.ColorInversion(), .1), + ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02), + ] + ), sample_transforms=T.SampleCompose([ T.RandomRotate(90, expand=True), T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]) if args.rotation else None ) train_loader = DataLoader( diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 832aeb309c..c9f85f6191 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -139,11 +139,11 @@ def main(args): val_set = DetectionDataset( img_folder=os.path.join(args.val_path, 'images'), label_path=os.path.join(args.val_path, 'labels.json'), - img_transforms=T.Resize((args.input_size, args.input_size)), + img_transforms=T.Resize((args.input_size, args.input_size)) if not args.rotation else None, sample_transforms=T.SampleCompose([ T.RandomRotate(0, expand=True), T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]) if args.rotation else None ) val_loader = DataLoader( val_set, @@ -187,18 +187,21 @@ def main(args): train_set = DetectionDataset( img_folder=os.path.join(args.train_path, 'images'), label_path=os.path.join(args.train_path, 'labels.json'), - img_transforms=T.Compose([ - # Augmentations - T.RandomApply(T.ColorInversion(), .1), - T.RandomJpegQuality(60), - T.RandomSaturation(.3), - T.RandomContrast(.3), - T.RandomBrightness(.3), - ]), + img_transforms=T.Compose( + ([T.Resize((args.input_size, args.input_size))] if not args.rotation else []) + + [ + # Augmentations + T.RandomApply(T.ColorInversion(), .1), + T.RandomJpegQuality(60), + T.RandomSaturation(.3), + T.RandomContrast(.3), + T.RandomBrightness(.3), + ] + ), sample_transforms=T.SampleCompose([ T.RandomRotate(90, expand=True), T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]) if args.rotation else None ) train_loader = DataLoader( train_set, From 631b2b292a3759b9315ed374d31172c105cd1d29 Mon Sep 17 00:00:00 2001 From: charlesmindee Date: Wed, 29 Dec 2021 15:06:29 +0100 Subject: [PATCH 6/7] fix: requested changes --- doctr/datasets/detection.py | 2 +- doctr/transforms/functional/pytorch.py | 2 +- doctr/transforms/functional/tensorflow.py | 2 +- references/detection/train_pytorch.py | 7 ++----- references/detection/train_tensorflow.py | 7 ++----- 5 files changed, 7 insertions(+), 13 deletions(-) diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py index 4b21891797..abe37952eb 100644 --- a/doctr/datasets/detection.py +++ b/doctr/datasets/detection.py @@ -72,7 +72,7 @@ def __getitem__( # Boxes target = target.copy() - if np.max(target) > 2: + if np.max(target) > 1: # Absolute coords if target.ndim == 3: target[..., 0] /= w target[..., 1] /= h diff --git a/doctr/transforms/functional/pytorch.py b/doctr/transforms/functional/pytorch.py index 71c2409b03..c2c498ce81 100644 --- a/doctr/transforms/functional/pytorch.py +++ b/doctr/transforms/functional/pytorch.py @@ -51,7 +51,7 @@ def rotate( # Get absolute coords _boxes = deepcopy(boxes) - if np.max(_boxes) < 2: + if np.max(_boxes) <= 1: _boxes[:, [0, 2]] = _boxes[:, [0, 2]] * img.shape[2] _boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[1] diff --git a/doctr/transforms/functional/tensorflow.py b/doctr/transforms/functional/tensorflow.py index 8719adb845..0ab5eb3b32 100644 --- a/doctr/transforms/functional/tensorflow.py +++ b/doctr/transforms/functional/tensorflow.py @@ -62,7 +62,7 @@ def rotate( # Get absolute coords _boxes = deepcopy(boxes) - if np.max(_boxes) < 2: + if np.max(_boxes) <= 1: _boxes[:, [0, 2]] = _boxes[:, [0, 2]] * img.shape[1] _boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[0] diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 3cb659d0d9..b79d17e1ef 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -179,11 +179,8 @@ def main(args): val_set = DetectionDataset( img_folder=os.path.join(args.val_path, 'images'), label_path=os.path.join(args.val_path, 'labels.json'), - img_transforms=T.Resize((args.input_size, args.input_size)) if not args.rotation else None, - sample_transforms=T.SampleCompose([ - T.RandomRotate(0, expand=True), - T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]) if args.rotation else None + img_transforms=T.Resize((args.input_size, args.input_size)), + use_polygons=True if args.rotation else False ) val_loader = DataLoader( val_set, diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index c9f85f6191..3bfa193368 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -139,11 +139,8 @@ def main(args): val_set = DetectionDataset( img_folder=os.path.join(args.val_path, 'images'), label_path=os.path.join(args.val_path, 'labels.json'), - img_transforms=T.Resize((args.input_size, args.input_size)) if not args.rotation else None, - sample_transforms=T.SampleCompose([ - T.RandomRotate(0, expand=True), - T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]) if args.rotation else None + img_transforms=T.Resize((args.input_size, args.input_size)), + use_polygons=True if args.rotation else False ) val_loader = DataLoader( val_set, From 16c98a8deb750b579505f4ce76dfa24f6a1fb251 Mon Sep 17 00:00:00 2001 From: charlesmindee Date: Wed, 29 Dec 2021 16:39:12 +0100 Subject: [PATCH 7/7] fix: requested changes --- doctr/transforms/functional/pytorch.py | 2 +- doctr/transforms/functional/tensorflow.py | 2 +- doctr/utils/geometry.py | 30 ++++++++++++----------- references/detection/train_pytorch.py | 5 ++-- references/detection/train_tensorflow.py | 5 ++-- 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/doctr/transforms/functional/pytorch.py b/doctr/transforms/functional/pytorch.py index c2c498ce81..aa3a962b4c 100644 --- a/doctr/transforms/functional/pytorch.py +++ b/doctr/transforms/functional/pytorch.py @@ -55,7 +55,7 @@ def rotate( _boxes[:, [0, 2]] = _boxes[:, [0, 2]] * img.shape[2] _boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[1] - # Rotate the boxes: xmin, ymin, xmax, ymax --> (4, 2) polygon + # Rotate the boxes: xmin, ymin, xmax, ymax or polygons --> (4, 2) polygon r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[1:], expand).astype(np.float32) # type: ignore[arg-type] # Always return relative boxes to avoid label confusions when resizing is performed aferwards diff --git a/doctr/transforms/functional/tensorflow.py b/doctr/transforms/functional/tensorflow.py index 0ab5eb3b32..bf06b183c3 100644 --- a/doctr/transforms/functional/tensorflow.py +++ b/doctr/transforms/functional/tensorflow.py @@ -66,7 +66,7 @@ def rotate( _boxes[:, [0, 2]] = _boxes[:, [0, 2]] * img.shape[1] _boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[0] - # Rotate the boxes: xmin, ymin, xmax, ymax --> (4, 2) polygon + # Rotate the boxes: xmin, ymin, xmax, ymax or polygons --> (4, 2) polygon r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[:-1], expand).astype(np.float32) # Always return relative boxes to avoid label confusions when resizing is performed aferwards diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py index 92f8e07287..61045b4819 100644 --- a/doctr/utils/geometry.py +++ b/doctr/utils/geometry.py @@ -86,10 +86,11 @@ def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[in def rotate_abs_boxes(boxes: np.ndarray, angle: float, img_shape: Tuple[int, int], expand: bool = True) -> np.ndarray: - """Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax) by an angle around the image center. + """Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax) or polygons (N, 4, 2) + by an angle around the image center. Args: - boxes: (N, 4) array of absolute coordinate boxes + boxes: (N, 4) or (N, 4, 2) array of ABSOLUTE coordinate boxes angle: angle between -90 and +90 degrees img_shape: the height and width of the image expand: whether the image should be padded to avoid information loss @@ -99,15 +100,18 @@ def rotate_abs_boxes(boxes: np.ndarray, angle: float, img_shape: Tuple[int, int] """ # Get box corners - box_corners = np.stack( - [ - boxes[:, [0, 1]], - boxes[:, [2, 1]], - boxes[:, [2, 3]], - boxes[:, [0, 3]], - ], - axis=1 - ) + if boxes.ndim == 2: + box_corners = np.stack( + [ + boxes[:, [0, 1]], + boxes[:, [2, 1]], + boxes[:, [2, 3]], + boxes[:, [0, 3]], + ], + axis=1 + ) + else: + box_corners = boxes img_corners = np.array([[0, 0], [0, img_shape[0]], [*img_shape[::-1]], [img_shape[1], 0]], dtype=boxes.dtype) stacked_points = np.concatenate((img_corners[None, ...], box_corners), axis=0) @@ -149,7 +153,6 @@ def rotate_boxes( angle: angle between -90 and +90 degrees orig_shape: shape of the origin image min_angle: minimum angle to rotate boxes - target_shape: shape of the target image Returns: A batch of rotated boxes (N, 4, 2): or a batch of straight bounding boxes @@ -157,7 +160,7 @@ def rotate_boxes( # Change format of the boxes to rotated boxes _boxes = loc_preds.copy() - if _boxes.shape[1] == 5: + if _boxes.ndim == 2: _boxes = np.stack( [ _boxes[:, [0, 1]], @@ -183,7 +186,6 @@ def rotate_boxes( rotated_boxes = np.stack( (rotated_points[:, :, 0] / orig_shape[1], rotated_points[:, :, 1] / orig_shape[0]), axis=-1 ) - return rotated_boxes diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index b79d17e1ef..f2874f343f 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -180,7 +180,7 @@ def main(args): img_folder=os.path.join(args.val_path, 'images'), label_path=os.path.join(args.val_path, 'labels.json'), img_transforms=T.Resize((args.input_size, args.input_size)), - use_polygons=True if args.rotation else False + use_polygons=args.rotation, ) val_loader = DataLoader( val_set, @@ -248,7 +248,8 @@ def main(args): sample_transforms=T.SampleCompose([ T.RandomRotate(90, expand=True), T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]) if args.rotation else None + ]) if args.rotation else None, + use_polygons=args.rotation, ) train_loader = DataLoader( diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 3bfa193368..00736773eb 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -140,7 +140,7 @@ def main(args): img_folder=os.path.join(args.val_path, 'images'), label_path=os.path.join(args.val_path, 'labels.json'), img_transforms=T.Resize((args.input_size, args.input_size)), - use_polygons=True if args.rotation else False + use_polygons=args.rotation, ) val_loader = DataLoader( val_set, @@ -198,7 +198,8 @@ def main(args): sample_transforms=T.SampleCompose([ T.RandomRotate(90, expand=True), T.ImageTransform(T.Resize((args.input_size, args.input_size))), - ]) if args.rotation else None + ]) if args.rotation else None, + use_polygons=args.rotation, ) train_loader = DataLoader( train_set,