Skip to content

Commit

Permalink
Update RandomResizedCrop for masks
Browse files Browse the repository at this point in the history
  • Loading branch information
sungchul2 committed Jul 12, 2024
1 parent 4711dbb commit 61e8550
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions src/otx/core/data/transform_libs/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,7 @@ class RandomResizedCrop(tvt_v2.Transform, NumpytoTVTensorMixin):
interpolation (str): Interpolation method, accepted values are
'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
'bilinear'.
backend (str): The image resize backend type, accepted values are
'cv2' and 'pillow'. Defaults to 'cv2'.
transform_mask (bool): Whether to transform masks. Defaults to False.
is_numpy_to_tvtensor (bool): Whether to convert outputs to tensors. Defaults to False.
"""

Expand All @@ -642,7 +641,7 @@ def __init__(
aspect_ratio_range: tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
max_attempts: int = 10,
interpolation: str = "bilinear",
backend: str = "cv2",
transform_mask: bool = False,
is_numpy_to_tvtensor: bool = False,
) -> None:
super().__init__()
Expand Down Expand Up @@ -675,7 +674,7 @@ def __init__(
self.aspect_ratio_range = aspect_ratio_range
self.max_attempts = max_attempts
self.interpolation = interpolation
self.backend = backend
self.transform_mask = transform_mask
self.is_numpy_to_tvtensor = is_numpy_to_tvtensor

@cache_randomness
Expand Down Expand Up @@ -817,15 +816,7 @@ def _crop_img(
return patches

def forward(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None:
"""Transform function to randomly resized crop images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Randomly resized cropped results, 'img_shape'
key in result dict is updated according to crop size.
"""
"""Transform function to randomly resized crop images and masks."""
inputs = _inputs[0]
if (img := getattr(inputs, "image", None)) is not None:
img = to_np_image(img)
Expand All @@ -840,26 +831,28 @@ def forward(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None:
)
img = self._crop_img(img, bboxes=bboxes)
inputs.img_info = _crop_image_info(inputs.img_info, *img.shape[:2])

img = cv2.resize(
img,
tuple(self.scale[::-1]),
dst=None,
interpolation=CV2_INTERP_CODES[self.interpolation],
)
if (masks := getattr(inputs, "gt_seg_map", None)) is not None:
masks = masks.numpy()
inputs.image = img
inputs.img_info = _resize_image_info(inputs.img_info, img.shape[:2])

if self.transform_mask and (masks := getattr(inputs, "masks", None)) is not None:
masks = to_np_image(masks)
masks = self._crop_img(masks, bboxes=bboxes)
masks = cv2.resize(
masks,
tuple(self.scale[::-1]),
dst=None,
interpolation=CV2_INTERP_CODES["nearest"],
)
inputs.gt_seg_map = torch.from_numpy(masks) # type: ignore[attr-defined]
if masks.ndim == 2:
masks = masks[None]
inputs.masks = tv_tensors.Mask(masks) # type: ignore[attr-defined]

inputs.image = img
inputs.img_info = _resize_image_info(inputs.img_info, img.shape[:2])
return self.convert(inputs)

def __repr__(self):
Expand All @@ -875,7 +868,8 @@ def __repr__(self):
repr_str += f"{tuple(round(r, 4) for r in self.aspect_ratio_range)}"
repr_str += f", max_attempts={self.max_attempts}"
repr_str += f", interpolation={self.interpolation}"
repr_str += f", backend={self.backend})"
repr_str += f", transform_mask={self.transform_mask}"
repr_str += f", is_numpy_to_tvtensor={self.is_numpy_to_tvtensor})"
return repr_str


Expand Down Expand Up @@ -2082,6 +2076,7 @@ class Pad(tvt_v2.Transform, NumpytoTVTensorMixin):
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
transform_mask (bool): Whether to transform masks. Defaults to False.
is_numpy_to_tvtensor (bool): Whether to convert outputs to tensors. Defaults to False.
"""

Expand Down

0 comments on commit 61e8550

Please sign in to comment.