Skip to content

Commit

Permalink
Add Translate augmentation. (#538)
Browse files Browse the repository at this point in the history
* add imtranslate

* add imtranslate

* update comments

* reformat
  • Loading branch information
v-qjqs authored Sep 4, 2020
1 parent c698793 commit 9769024
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mmcv/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
from .geometric import (imcrop, imflip, imflip_, impad, impad_to_multiple,
imrescale, imresize, imresize_like, imrotate, imshear,
rescale_size)
imtranslate, rescale_size)
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
from .misc import tensor2imgs
from .photometric import (imdenormalize, iminvert, imnormalize, imnormalize_,
Expand All @@ -18,5 +18,5 @@
'imwrite', 'supported_backends', 'use_backend', 'imdenormalize',
'imnormalize', 'imnormalize_', 'iminvert', 'posterize', 'solarize',
'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr', 'tensor2imgs',
'imshear'
'imshear', 'imtranslate'
]
69 changes: 69 additions & 0 deletions mmcv/image/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,72 @@ def imshear(img,
borderValue=border_value[:3],
flags=cv2_interp_codes[interpolation])
return sheared


def _get_translate_matrix(offset, direction='horizontal'):
    """Generate the translate matrix.

    Args:
        offset (int | float): The offset used for translate.
        direction (str): The translate direction, either
            "horizontal" or "vertical".

    Returns:
        ndarray: The 2x3 translate matrix with dtype float32.

    Raises:
        ValueError: If ``direction`` is neither "horizontal" nor
            "vertical".
    """
    if direction == 'horizontal':
        # The offset goes into the x-translation slot of the affine matrix.
        translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]])
    elif direction == 'vertical':
        # The offset goes into the y-translation slot of the affine matrix.
        translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]])
    else:
        # Fail with a clear message instead of hitting an unbound local
        # at the return statement below.
        raise ValueError(f'Invalid direction: {direction}')
    return translate_matrix


def imtranslate(img,
                offset,
                direction='horizontal',
                border_value=0,
                interpolation='bilinear'):
    """Translate an image.

    Args:
        img (ndarray): Image to be translated with format
            (h, w) or (h, w, c).
        offset (int | float): The offset used for translate.
        direction (str): The translate direction, either "horizontal"
            or "vertical".
        border_value (int | tuple[int]): Value used in case of a
            constant border. An int is repeated once per channel; a
            tuple must contain exactly one value per channel.
        interpolation (str): Same as :func:`resize`.

    Returns:
        ndarray: The translated image.

    Raises:
        AssertionError: If ``direction`` is invalid, or if a tuple
            ``border_value`` does not have one element per channel.
        ValueError: If ``img`` is not 2-D or 3-D, or if ``border_value``
            is neither an int nor a tuple.
    """
    assert direction in ['horizontal',
                         'vertical'], f'Invalid direction: {direction}'
    if img.ndim not in (2, 3):
        # Without this guard `channels` would be left unbound below.
        raise ValueError(
            'Expected `img` with format (h, w) or (h, w, c), but got '
            f'an array with {img.ndim} dimension(s).')
    height, width = img.shape[:2]
    channels = 1 if img.ndim == 2 else img.shape[-1]
    if isinstance(border_value, int):
        border_value = tuple([border_value] * channels)
    elif isinstance(border_value, tuple):
        assert len(border_value) == channels, \
            'Expected the num of elements in tuple equals the channels ' \
            f'of input image. Found {len(border_value)} vs {channels}'
    else:
        raise ValueError(
            f'Invalid type {type(border_value)} for `border_value`.')
    translate_matrix = _get_translate_matrix(offset, direction)
    translated = cv2.warpAffine(
        img,
        translate_matrix,
        (width, height),
        # Note: when `border_value` has more than 3 elements (e.g. when
        # translating masks with more than 3 channels), `cv2.warpAffine`
        # raises a TypeError. Here we simply pass the first 3 values
        # of `border_value`.
        borderValue=border_value[:3],
        flags=cv2_interp_codes[interpolation])
    return translated
41 changes: 41 additions & 0 deletions tests/test_image/test_geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,44 @@ def test_imshear(self):
# test invalid value of direction
with pytest.raises(AssertionError):
mmcv.imshear(img, 0.5, 'diagonal')

def test_imtranslate(self):
    """Exercise ``mmcv.imtranslate`` on gray and 3-channel inputs."""
    gray = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8)

    # A zero offset must return the image unchanged.
    assert_array_equal(mmcv.imtranslate(gray, 0), gray)

    # offset=1, horizontal
    expected = np.array([[128, 1, 2], [128, 4, 5], [128, 7, 8]],
                        dtype=np.uint8)
    shifted = mmcv.imtranslate(gray, 1, border_value=128)
    assert_array_equal(shifted, expected)

    # offset=-1, vertical
    expected = np.array([[4, 5, 6], [7, 8, 9], [0, 0, 0]], dtype=np.uint8)
    shifted = mmcv.imtranslate(gray, -1, 'vertical')
    assert_array_equal(shifted, expected)

    # offset=-2, horizontal
    plane = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.uint8)
    color = np.stack([plane] * 3, axis=-1)
    expected_rows = [[3, 4, 128, 128], [7, 8, 128, 128]]
    expected = np.stack([expected_rows] * 3, axis=-1)
    shifted = mmcv.imtranslate(color, -2, border_value=128)
    assert_array_equal(shifted, expected)

    # offset=2, vertical
    fill = (110, 120, 130)
    expected = np.stack([np.full((2, 4), value) for value in fill],
                        axis=-1).astype(np.uint8)
    shifted = mmcv.imtranslate(color, 2, 'vertical', fill)
    assert_array_equal(shifted, expected)

    # test invalid number elements in border_value
    with pytest.raises(AssertionError):
        mmcv.imtranslate(color, 1, border_value=(1, ))

    # test invalid type of border_value
    with pytest.raises(ValueError):
        mmcv.imtranslate(color, 1, border_value=[1, 2, 3])

    # test invalid value of direction
    with pytest.raises(AssertionError):
        mmcv.imtranslate(color, 1, 'diagonal')

0 comments on commit 9769024

Please sign in to comment.