feat: add rotation option to both detection training scripts #765

Merged · 10 commits · Dec 29, 2021
13 changes: 11 additions & 2 deletions doctr/datasets/detection.py
@@ -61,16 +61,25 @@ def __getitem__(

img, target = self._read_sample(index)
h, w = self._get_img_shape(img)

if self.img_transforms is not None:
img = self.img_transforms(img)

if self.sample_transforms is not None:
# Here we may modify coordinates, each transformation must accept/return relative coordinates
# Otherwise, if we use the resize operation afterwards it will not only modify images but coordinates
img, target = self.sample_transforms(img, target)

# Boxes
target = target.copy()
target[..., 0] /= w
target[..., 1] /= h
if np.max(target) > 2:
if target.ndim == 3:
target[..., 0] /= w
target[..., 1] /= h
else:
target[..., [0, 2]] /= w
target[..., [1, 3]] /= h

target = target.clip(0, 1)

return img, target
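
The hunk above normalizes targets to relative coordinates whether they arrive as straight boxes of shape (N, 4) or rotated polygons of shape (N, 4, 2), using `np.max(target) > 2` as a heuristic for absolute pixel values. A minimal standalone sketch of that branching (the helper name and example values are illustrative, not part of the dataset API):

```python
import numpy as np

def to_relative(target: np.ndarray, h: int, w: int) -> np.ndarray:
    """Scale absolute pixel coordinates into the [0, 1] range.

    Straight boxes come as (N, 4) arrays of (xmin, ymin, xmax, ymax);
    rotated targets come as (N, 4, 2) arrays of polygon corners (x, y).
    """
    target = target.astype(np.float32).copy()
    if np.max(target) > 2:  # heuristic: values above 2 are absolute pixels
        if target.ndim == 3:  # polygons: last axis is (x, y)
            target[..., 0] /= w
            target[..., 1] /= h
        else:  # straight boxes: columns 0/2 are x, 1/3 are y
            target[..., [0, 2]] /= w
            target[..., [1, 3]] /= h
    return target.clip(0, 1)

# One straight box on a 200 x 100 (h x w) image
boxes = np.array([[10, 20, 60, 80]], dtype=np.float32)
print(to_relative(boxes, h=200, w=100))  # [[0.1 0.1 0.6 0.4]]
```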
27 changes: 24 additions & 3 deletions doctr/models/detection/differentiable_binarization/base.py
@@ -312,7 +312,14 @@ def build_target(
for box, box_size, poly in zip(abs_boxes, boxes_size, polys):
# Mask boxes that are too small
if box_size < self.min_size_box:
seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
if abs_boxes.ndim == 3:
seg_mask[
idx,
int(np.min(box[:, 1])): int(np.max(box[:, 1])) + 1,
int(np.min(box[:, 0])): int(np.max(box[:, 0])) + 1
] = False
else:
seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
continue

# Negative shrink for gt, as described in paper
@@ -325,11 +332,25 @@

# Draw polygon on gt if it is valid
if len(shrinked) == 0:
seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
if abs_boxes.ndim == 3:
seg_mask[
idx,
int(np.min(box[:, 1])): int(np.max(box[:, 1])) + 1,
int(np.min(box[:, 0])): int(np.max(box[:, 0])) + 1
] = False
else:
seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
continue
shrinked = np.array(shrinked[0]).reshape(-1, 2)
if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
if abs_boxes.ndim == 3:
seg_mask[
idx,
int(np.min(box[:, 1])): int(np.max(box[:, 1])) + 1,
int(np.min(box[:, 0])): int(np.max(box[:, 0])) + 1
] = False
else:
seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
continue
cv2.fillPoly(seg_target[idx], [shrinked.astype(np.int32)], 1)
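
For rotated targets the segmentation mask can no longer be cleared with a plain `box[1]: box[3], box[0]: box[2]` slice, so each of the three early-exit branches above falls back to the polygon's axis-aligned extents. A hedged sketch of that logic factored into a helper (the function name is mine, not part of the PR):

```python
import numpy as np

def mask_out(seg_mask: np.ndarray, idx: int, box: np.ndarray, is_poly: bool) -> None:
    """Zero the mask region covered by a box that should be ignored."""
    if is_poly:  # box has shape (4, 2): take the extents over the corners
        ymin, ymax = int(np.min(box[:, 1])), int(np.max(box[:, 1]))
        xmin, xmax = int(np.min(box[:, 0])), int(np.max(box[:, 0]))
    else:  # box has shape (4,): (xmin, ymin, xmax, ymax)
        xmin, ymin, xmax, ymax = box.astype(int)
    seg_mask[idx, ymin: ymax + 1, xmin: xmax + 1] = False
```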

11 changes: 5 additions & 6 deletions doctr/transforms/functional/pytorch.py
@@ -51,17 +51,16 @@ def rotate(

# Get absolute coords
_boxes = deepcopy(boxes)
if boxes.dtype != int:
if np.max(_boxes) < 2:
_boxes[:, [0, 2]] = _boxes[:, [0, 2]] * img.shape[2]
_boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[1]

# Rotate the boxes: xmin, ymin, xmax, ymax --> (4, 2) polygon
r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[1:], expand) # type: ignore[arg-type]
r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[1:], expand).astype(np.float32) # type: ignore[arg-type]

# Convert them to relative
if boxes.dtype != int:
r_boxes[..., 0] = r_boxes[..., 0] / rotated_img.shape[2]
r_boxes[..., 1] = r_boxes[..., 1] / rotated_img.shape[1]
# Always return relative boxes to avoid label confusion when resizing is performed afterwards
r_boxes[..., 0] = r_boxes[..., 0] / rotated_img.shape[2]
r_boxes[..., 1] = r_boxes[..., 1] / rotated_img.shape[1]

return rotated_img, r_boxes
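
The PyTorch `rotate` now detects relative coordinates by value range (`np.max(_boxes) < 2`) rather than by dtype, scales them to absolute pixels before the rotation, and always rescales the rotated polygons back to relative coordinates on the way out. A small sketch of just that bookkeeping, with the rotation step elided (a CHW image is assumed, so index 1 is the height and index 2 the width):

```python
import numpy as np

def to_abs_boxes(boxes: np.ndarray, h: int, w: int) -> np.ndarray:
    """(N, 4) boxes -> absolute pixel coordinates, if they look relative."""
    out = boxes.astype(np.float32).copy()
    if np.max(out) < 2:  # heuristic used by the PR for relative coordinates
        out[:, [0, 2]] *= w
        out[:, [1, 3]] *= h
    return out

def to_rel_polys(polys: np.ndarray, h: int, w: int) -> np.ndarray:
    """(N, 4, 2) absolute polygons -> relative coordinates, unconditionally."""
    out = polys.astype(np.float32).copy()
    out[..., 0] /= w
    out[..., 1] /= h
    return out
```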

13 changes: 7 additions & 6 deletions doctr/transforms/functional/tensorflow.py
@@ -52,6 +52,8 @@ def rotate(
if expand:
exp_shape = compute_expanded_shape(img.shape[:-1], angle)
h_pad, w_pad = int(math.ceil(exp_shape[0] - img.shape[0])), int(math.ceil(exp_shape[1] - img.shape[1]))
if min(h_pad, w_pad) < 0:
h_pad, w_pad = int(math.ceil(exp_shape[1] - img.shape[0])), int(math.ceil(exp_shape[0] - img.shape[1]))
exp_img = tf.pad(img, tf.constant([[h_pad // 2, h_pad - h_pad // 2], [w_pad // 2, w_pad - w_pad // 2], [0, 0]]))
else:
exp_img = img
@@ -60,17 +62,16 @@

# Get absolute coords
_boxes = deepcopy(boxes)
if boxes.dtype != int:
if np.max(_boxes) < 2:
_boxes[:, [0, 2]] = _boxes[:, [0, 2]] * img.shape[1]
_boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[0]

# Rotate the boxes: xmin, ymin, xmax, ymax --> (4, 2) polygon
r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[:-1], expand)
r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[:-1], expand).astype(np.float32)

# Convert them to relative
if boxes.dtype != int:
r_boxes[..., 0] = r_boxes[..., 0] / rotated_img.shape[1]
r_boxes[..., 1] = r_boxes[..., 1] / rotated_img.shape[0]
# Always return relative boxes to avoid label confusion when resizing is performed afterwards
r_boxes[..., 0] = r_boxes[..., 0] / rotated_img.shape[1]
r_boxes[..., 1] = r_boxes[..., 1] / rotated_img.shape[0]

return rotated_img, r_boxes
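
In the TensorFlow version, `compute_expanded_shape` can come back with its two dimensions swapped for some angles, which made one of the paddings negative; the extra check above recomputes the padding with the dimensions exchanged. A hedged sketch of that padding step in isolation (an HWC image tensor and an `(h, w)`-like `exp_shape` are assumed):

```python
import math
import tensorflow as tf

def pad_to_expanded(img: tf.Tensor, exp_shape) -> tf.Tensor:
    """Symmetrically pad an HWC image up to an expanded (h, w) shape."""
    h_pad = int(math.ceil(exp_shape[0] - img.shape[0]))
    w_pad = int(math.ceil(exp_shape[1] - img.shape[1]))
    if min(h_pad, w_pad) < 0:  # expanded dims came back swapped
        h_pad = int(math.ceil(exp_shape[1] - img.shape[0]))
        w_pad = int(math.ceil(exp_shape[0] - img.shape[1]))
    paddings = [[h_pad // 2, h_pad - h_pad // 2], [w_pad // 2, w_pad - w_pad // 2], [0, 0]]
    return tf.pad(img, tf.constant(paddings))
```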

9 changes: 4 additions & 5 deletions doctr/utils/geometry.py
@@ -118,15 +118,14 @@ def rotate_abs_boxes(boxes: np.ndarray, angle: float, img_shape: Tuple[int, int]
)

# Rotate them around image center, shape (N+1, 4, 2)
stacked_rel_points = rotate_abs_points(stacked_rel_points.reshape((-1, 2))).reshape((-1, 4, 2))
rot_points = rotate_abs_points(stacked_rel_points, angle)
img_rot_corners, box_rot_corners = rot_points[:1], rot_points[1:]
rot_points = rotate_abs_points(stacked_rel_points.reshape((-1, 2)), angle).reshape(-1, 4, 2)
img_rot_corners, box_rot_corners = rot_points[0], rot_points[1:]

# Expand the image to fit all the original info
if expand:
new_corners = np.abs(img_rot_corners).max(axis=0)
box_rot_corners[..., 0] += new_corners[:, 0]
box_rot_corners[..., 1] = new_corners[:, 1] - box_rot_corners[..., 1]
box_rot_corners[..., 0] += new_corners[0]
box_rot_corners[..., 1] = new_corners[1] - box_rot_corners[..., 1]
else:
box_rot_corners[..., 0] += img_shape[1] / 2
box_rot_corners[..., 1] = img_shape[0] / 2 - box_rot_corners[..., 1]
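
geometry.py now flattens the stacked corners to (M, 2) before calling `rotate_abs_points`, restores the (N, 4, 2) layout afterwards, and treats the first entry (`rot_points[0]`, the rotated image corners) as a (4, 2) array, which is why `new_corners` is indexed with scalars instead of slices. A standalone sketch of that flatten/rotate/reshape pattern with an illustrative point-rotation helper (the real `rotate_abs_points` lives in this module and its exact convention may differ):

```python
import numpy as np

def rotate_points(points: np.ndarray, angle: float) -> np.ndarray:
    """Rotate (M, 2) points counter-clockwise about the origin; angle in degrees."""
    theta = np.deg2rad(angle)
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    return points @ rot.T

# Flatten a batch of 4-corner boxes, rotate, then restore the (N, 4, 2) layout
corners = np.random.rand(5, 4, 2)
rotated = rotate_points(corners.reshape(-1, 2), 30.).reshape(-1, 4, 2)
```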
2 changes: 1 addition & 1 deletion doctr/utils/metrics.py
@@ -203,7 +203,7 @@
polys_1: np.ndarray,
polys_2: np.ndarray,
mask_shape: Tuple[int, int],
use_broadcasting: bool = True
use_broadcasting: bool = False
) -> np.ndarray:
"""Computes the IoU between two sets of rotated bounding boxes

13 changes: 10 additions & 3 deletions references/detection/train_pytorch.py
@@ -156,7 +156,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
loc_preds = out['preds']
for boxes_gt, boxes_pred in zip(targets, loc_preds):
# Remove scores
val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])
val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4])

val_loss += out['loss'].item()
batch_cnt += 1
@@ -180,6 +180,10 @@ def main(args):
img_folder=os.path.join(args.val_path, 'images'),
label_path=os.path.join(args.val_path, 'labels.json'),
img_transforms=T.Resize((args.input_size, args.input_size)),
sample_transforms=T.SampleCompose([
T.RandomRotate(0, expand=True),
T.ImageTransform(T.Resize((args.input_size, args.input_size))),
]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))),
)
val_loader = DataLoader(
val_set,
@@ -237,11 +241,14 @@ def main(args):
img_folder=os.path.join(args.train_path, 'images'),
label_path=os.path.join(args.train_path, 'labels.json'),
img_transforms=Compose([
T.Resize((args.input_size, args.input_size)),
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
]),
sample_transforms=T.SampleCompose([
T.RandomRotate(90, expand=True),
T.ImageTransform(T.Resize((args.input_size, args.input_size))),
]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))),
)

train_loader = DataLoader(
@@ -369,7 +376,7 @@ def parse_args():
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='Load pretrained parameters before starting the training')
parser.add_argument('--rotation', dest='rotation', action='store_true',
help='train with rotated bbox')
help='train with rotated documents')
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR')
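
Both training scripts build the same conditional: with `--rotation`, samples go through a `SampleCompose` that rotates image and boxes together (max angle 90 for training, 0 for validation so the targets still end up in the rotated polygon format) before the resize; without it, only the image resize is applied. A hedged sketch of that conditional pulled out of the script (argument names mirror the CLI flags; the validation set would call it with `max_angle=0`):

```python
from doctr import transforms as T  # as in the training scripts

def build_sample_transforms(rotation: bool, input_size: int, max_angle: int = 90):
    """Return the sample_transforms passed to the detection dataset."""
    resize = T.ImageTransform(T.Resize((input_size, input_size)))
    if rotation:
        # Rotate image and boxes jointly (expanding the canvas), then resize
        return T.SampleCompose([T.RandomRotate(max_angle, expand=True), resize])
    return resize
```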
13 changes: 10 additions & 3 deletions references/detection/train_tensorflow.py
@@ -114,7 +114,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric):
loc_preds = out['preds']
for boxes_gt, boxes_pred in zip(targets, loc_preds):
# Remove scores
val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])
val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4])

val_loss += out['loss'].numpy()
batch_cnt += 1
@@ -140,6 +140,10 @@ def main(args):
img_folder=os.path.join(args.val_path, 'images'),
label_path=os.path.join(args.val_path, 'labels.json'),
img_transforms=T.Resize((args.input_size, args.input_size)),
sample_transforms=T.SampleCompose([
T.RandomRotate(0, expand=True),
T.ImageTransform(T.Resize((args.input_size, args.input_size))),
]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))),
)
val_loader = DataLoader(
val_set,
@@ -184,14 +188,17 @@ def main(args):
img_folder=os.path.join(args.train_path, 'images'),
label_path=os.path.join(args.train_path, 'labels.json'),
img_transforms=T.Compose([
T.Resize((args.input_size, args.input_size)),
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
T.RandomJpegQuality(60),
T.RandomSaturation(.3),
T.RandomContrast(.3),
T.RandomBrightness(.3),
]),
sample_transforms=T.SampleCompose([
T.RandomRotate(90, expand=True),
T.ImageTransform(T.Resize((args.input_size, args.input_size))),
]) if args.rotation else T.ImageTransform(T.Resize((args.input_size, args.input_size))),
)
train_loader = DataLoader(
train_set,
@@ -319,7 +326,7 @@ def parse_args():
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='Load pretrained parameters before starting the training')
parser.add_argument('--rotation', dest='rotation', action='store_true',
help='train with rotated bbox')
help='train with rotated documents')
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR')
args = parser.parse_args()
6 changes: 2 additions & 4 deletions references/detection/utils.py
@@ -26,10 +26,8 @@ def plot_samples(images, targets: List[Dict[str, np.ndarray]]) -> None:
boxes[:, :4] = boxes[:, :4].round().astype(int)

for box in boxes:
if boxes.shape[1] == 5:
box = cv2.boxPoints(((int(box[0]), int(box[1])), (int(box[2]), int(box[3])), -box[4]))
box = np.int0(box)
cv2.fillPoly(target, [box], 1)
if boxes.ndim == 3:
cv2.fillPoly(target, [np.int0(box)], 1)
else:
target[int(box[1]): int(box[3]) + 1, int(box[0]): int(box[2]) + 1] = 1
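
plot_samples previously decided by `boxes.shape[1] == 5` and rebuilt a polygon from an (x, y, w, h, alpha) box with `cv2.boxPoints`; it now checks `boxes.ndim == 3` and fills the stored polygon directly. A compact sketch of the two drawing paths (shapes taken from the diff, absolute pixel coordinates assumed):

```python
import cv2
import numpy as np

def draw_target(target: np.ndarray, boxes: np.ndarray) -> None:
    """Rasterise ground-truth boxes into a binary target map, in place."""
    if boxes.ndim == 3:  # (N, 4, 2) rotated polygons
        for box in boxes:
            cv2.fillPoly(target, [box.astype(np.int32)], 1)  # int32 points for OpenCV
    else:  # (N, 4) straight boxes: (xmin, ymin, xmax, ymax)
        for box in boxes:
            target[int(box[1]): int(box[3]) + 1, int(box[0]): int(box[2]) + 1] = 1
```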
