Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add 2 slicing utils #736

Merged
merged 5 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion sahi/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import os
import time
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
from PIL import Image
from shapely.errors import TopologicalError
from tqdm import tqdm

from sahi.annotation import BoundingBox, Mask
from sahi.utils.coco import Coco, CocoAnnotation, CocoImage, create_coco_dict
from sahi.utils.cv import read_image_as_pil
from sahi.utils.file import load_json, save_json
Expand Down Expand Up @@ -642,3 +643,60 @@ def get_auto_slice_params(height: int, width: int):
return get_resolution_selector("high", height=height, width=width)
else:
return get_resolution_selector("ultra-high", height=height, width=width)


def shift_bboxes(bboxes, offset: Sequence[int]):
    """
    Shift bboxes w.r.t. offset.

    Supports torch.Tensor, np.ndarray and list inputs; the returned container
    type matches the input type.

    Args:
        bboxes (Tensor, np.ndarray, list): The bboxes that need to be translated.
            Its shape can be (n, 4), which means (x_min, y_min, x_max, y_max).
        offset (Sequence[int]): The (x, y) translation offsets with shape of (2, ).
    Returns:
        Tensor, np.ndarray, list: Shifted bboxes, same container type as the input.
    """
    shifted_bboxes = []

    # torch is an optional dependency; detect tensors via the module name
    # of the container type instead of importing torch directly.
    bboxes_is_torch_tensor = type(bboxes).__module__ == "torch"

    for bbox in bboxes:
        # BoundingBox expects a plain python list of coordinates
        if bboxes_is_torch_tensor or isinstance(bbox, np.ndarray):
            bbox = bbox.tolist()
        bbox = BoundingBox(bbox, shift_amount=offset)
        bbox = bbox.get_shifted_box()
        shifted_bboxes.append(bbox.to_xyxy())

    # return the same container type that was passed in
    if isinstance(bboxes, np.ndarray):
        return np.stack(shifted_bboxes, axis=0)
    elif bboxes_is_torch_tensor:
        # new_tensor keeps the dtype/device of the input tensor
        return bboxes.new_tensor(shifted_bboxes)
    else:
        return shifted_bboxes


def shift_masks(masks: np.ndarray, offset: Sequence[int], full_shape: Sequence[int]) -> np.ndarray:
    """Shift masks to the original image.

    Args:
        masks (np.ndarray): Masks that need to be shifted, shape (n, h, w).
        offset (Sequence[int]): The (x, y) offset to translate with shape of (2, ).
        full_shape (Sequence[int]): A (height, width) tuple of the full image's shape.
    Returns:
        np.ndarray: Shifted masks, or the input unchanged when it is None or empty.
    """
    # nothing to shift for missing or empty masks
    # (np.stack would raise ValueError on an empty list)
    if masks is None or len(masks) == 0:
        return masks

    shifted_masks = []
    for mask in masks:
        mask = Mask(bool_mask=mask, shift_amount=offset, full_shape=full_shape)
        mask = mask.get_shifted_mask()
        shifted_masks.append(mask.bool_mask)

    return np.stack(shifted_masks, axis=0)
15 changes: 5 additions & 10 deletions sahi/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@

from sahi.utils.import_utils import is_available

if is_available("torch"):
import torch
else:
torch = None


def empty_cuda_cache():
if is_torch_cuda_available():
import torch

return torch.cuda.empty_cache()


Expand All @@ -24,8 +27,6 @@ def to_float_tensor(img):
torch.tensor
"""

import torch

img = img.transpose((2, 0, 1))
img = torch.from_numpy(img).float()
if img.max() > 1:
Expand All @@ -35,8 +36,6 @@ def to_float_tensor(img):


def torch_to_numpy(img):
import torch

img = img.numpy()
if img.max() > 1:
img /= 255
Expand All @@ -45,8 +44,6 @@ def torch_to_numpy(img):

def is_torch_cuda_available():
if is_available("torch"):
import torch

return torch.cuda.is_available()
else:
return False
Expand All @@ -65,8 +62,6 @@ def select_device(device: str):

Inspired by https://github.com/ultralytics/yolov5/blob/6371de8879e7ad7ec5283e8b95cc6dd85d6a5e72/utils/torch_utils.py#L107
"""
import torch

if device == "cuda":
device = "cuda:0"
device = str(device).strip().lower().replace("none", "") # to string, 'cuda:0' to '0'
Expand Down
31 changes: 30 additions & 1 deletion tests/test_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from PIL import Image

from sahi.slicing import slice_coco, slice_image
from sahi.slicing import shift_bboxes, shift_masks, slice_coco, slice_image
from sahi.utils.coco import Coco
from sahi.utils.cv import read_image

Expand Down Expand Up @@ -170,6 +170,35 @@ def test_slice_coco(self):

shutil.rmtree(output_dir, ignore_errors=True)

def test_shift_bboxes(self):
import torch

bboxes = [[1, 2, 3, 4]]
shift_x = 10
shift_y = 20
shifted_bboxes = shift_bboxes(bboxes=bboxes, offset=[shift_x, shift_y])
self.assertEqual(shifted_bboxes, [[11, 22, 13, 24]])
self.assertEqual(type(shifted_bboxes), list)

bboxes = np.array([[1, 2, 3, 4]])
shifted_bboxes = shift_bboxes(bboxes=bboxes, offset=[shift_x, shift_y])
self.assertEqual(shifted_bboxes.tolist(), [[11, 22, 13, 24]])
self.assertEqual(type(shifted_bboxes), np.ndarray)

bboxes = torch.tensor([[1, 2, 3, 4]])
shifted_bboxes = shift_bboxes(bboxes=bboxes, offset=[shift_x, shift_y])
self.assertEqual(shifted_bboxes.tolist(), [[11, 22, 13, 24]])
self.assertEqual(type(shifted_bboxes), torch.Tensor)

def test_shift_masks(self):
masks = np.zeros((3, 30, 30), dtype=np.bool)
shift_x = 10
shift_y = 20
full_shape = [720, 1280]
shifted_masks = shift_masks(masks=masks, offset=[shift_x, shift_y], full_shape=full_shape)
self.assertEqual(shifted_masks.shape, (3, 720, 1280))
self.assertEqual(type(shifted_masks), np.ndarray)


if __name__ == "__main__":
unittest.main()