# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

# ---------------------------------------------------------------------------
# doctr/models/detection/_utils/pytorch.py  (new file)
# ---------------------------------------------------------------------------
# PyTorch backend of the detection morphology helpers.
# The package ``__init__`` (doctr/models/detection/_utils/__init__.py, also added by
# this patch) star-imports the ``.tensorflow`` sibling when
# ``doctr.file_utils.is_tf_available()`` is true and star-imports this module
# otherwise; both backends export the same two names.

import torch
from torch import Tensor
from torch.nn.functional import max_pool2d

__all__ = ['erode', 'dilate']


def _sliding_max(x: Tensor, kernel_size: int) -> Tensor:
    """Takes the maximum over every (kernel_size x kernel_size) window with stride 1

    Args:
        x: numeric tensor of shape (N, C, H, W)
        kernel_size: the side of the square pooling window
    Returns:
        the pooled tensor. With padding (kernel_size - 1) // 2 the spatial size is kept for odd
        kernel sizes and shrinks by one pixel per spatial dimension for even ones.
    """
    _pad = (kernel_size - 1) // 2

    return max_pool2d(x, kernel_size, stride=1, padding=_pad)


def erode(x: Tensor, kernel_size: int) -> Tensor:
    """Performs erosion on a given tensor

    Args:
        x: binary mask of shape (N, C, H, W), either a boolean tensor or a tensor holding 0/1 values
        kernel_size: the size of the kernel to use for erosion
    Returns:
        the eroded tensor, with the same dtype as the input
    """
    if x.dtype == torch.bool:
        # Arithmetic complement (1 - x) is rejected for boolean tensors: erode by duality,
        # i.e. dilate the logical complement and flip the result back
        return ~dilate(~x, kernel_size)

    # Erosion is the complement of the dilation of the complement
    return 1 - _sliding_max(1 - x, kernel_size)


def dilate(x: Tensor, kernel_size: int) -> Tensor:
    """Performs dilation on a given tensor

    Args:
        x: binary mask of shape (N, C, H, W), either a boolean tensor or a tensor holding 0/1 values
        kernel_size: the size of the kernel to use for dilation
    Returns:
        the dilated tensor, with the same dtype as the input
    """
    if x.dtype == torch.bool:
        # Pool a float copy (boolean input is not poolable directly), then threshold the
        # 0/1 result to get a boolean mask back
        return _sliding_max(x.to(torch.float32), kernel_size) > 0.5

    return _sliding_max(x, kernel_size)
# ---------------------------------------------------------------------------
# doctr/models/detection/_utils/tensorflow.py  (new file, body)
# ---------------------------------------------------------------------------
# TensorFlow backend of the detection morphology helpers (channels-last layout).

import tensorflow as tf

__all__ = ['erode', 'dilate']


def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
    """Performs erosion on a given tensor

    Args:
        x: binary mask of shape (N, H, W, C), either a boolean tensor or a tensor holding 0/1 values
        kernel_size: the size of the kernel to use for erosion
    Returns:
        the eroded tensor, with the same dtype as the input
    """
    if x.dtype == tf.bool:
        # Arithmetic complement (1 - x) is not defined for boolean tensors: erode by duality,
        # i.e. dilate the logical complement and flip the result back
        return tf.math.logical_not(dilate(tf.math.logical_not(x), kernel_size))

    # Erosion is the complement of the dilation of the complement
    return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")


def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
    """Performs dilation on a given tensor

    Args:
        x: binary mask of shape (N, H, W, C), either a boolean tensor or a tensor holding 0/1 values
        kernel_size: the size of the kernel to use for dilation
    Returns:
        the dilated tensor, with the same dtype as the input
    """
    if x.dtype == tf.bool:
        # Pool a float copy, then threshold the 0/1 result to get a boolean mask back
        pooled = tf.nn.max_pool2d(tf.cast(x, tf.float32), kernel_size, strides=1, padding="SAME")
        return pooled > 0.5

    # "SAME" padding with stride 1 keeps the spatial size of the input
    return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")


# ---------------------------------------------------------------------------
# tests/pytorch/test_models_detection_pt.py  (additions)
# ---------------------------------------------------------------------------
# The test module already imports `torch`; the patch adds the import below and
# appends the two tests after `test_detection_zoo`.

from doctr.models.detection._utils import dilate, erode


def test_erode():
    # A lone foreground pixel cannot survive a 3x3 erosion
    single_pixel = torch.zeros((1, 1, 3, 3))
    single_pixel[..., 1, 1] = 1
    all_background = torch.zeros((1, 1, 3, 3))
    assert torch.equal(erode(single_pixel, 3), all_background)


def test_dilate():
    # A lone centre pixel spreads over the whole 3x3 map after a 3x3 dilation
    single_pixel = torch.zeros((1, 1, 3, 3))
    single_pixel[..., 1, 1] = 1
    all_foreground = torch.ones((1, 1, 3, 3))
    assert torch.equal(dilate(single_pixel, 3), all_foreground)
# ---------------------------------------------------------------------------
# tests/tensorflow/test_models_detection_tf.py  (additions)
# ---------------------------------------------------------------------------
# NOTE(review): `np` and `tf` are used below but their imports sit outside the
# visible hunk - presumably at the top of the test module; confirm.
# The patch adds the import below and appends the two tests after
# `test_linknet_focal_loss`.

from doctr.models.detection._utils import dilate, erode


def test_erode():
    # Channels-last mask holding a single foreground pixel at the centre
    mask = np.zeros((1, 3, 3, 1), dtype=np.float32)
    mask[:, 1, 1] = 1
    eroded = erode(tf.convert_to_tensor(mask), 3)
    # A lone pixel cannot survive a 3x3 erosion
    all_background = tf.zeros((1, 3, 3, 1))
    assert tf.math.reduce_all(eroded == all_background)


def test_dilate():
    # Channels-last mask holding a single foreground pixel at the centre
    mask = np.zeros((1, 3, 3, 1), dtype=np.float32)
    mask[:, 1, 1] = 1
    dilated = dilate(tf.convert_to_tensor(mask), 3)
    # The centre pixel spreads over the whole 3x3 map
    all_foreground = tf.ones((1, 3, 3, 1))
    assert tf.math.reduce_all(dilated == all_foreground)