Skip to content

Commit

Permalink
(ROIs) Added Zones On Upper Level
Browse files Browse the repository at this point in the history
  • Loading branch information
ziqinyeow committed Jan 22, 2024
1 parent 02f60b4 commit a076e83
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 15 deletions.
41 changes: 29 additions & 12 deletions src/juxtapose/rtm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
DEVICE_TYPES,
)

from juxtapose.utils.polygon import PolygonZone

from juxtapose.utils.core import Detections
from juxtapose.utils.plotting import Annotator
from juxtapose.utils.roi import select_roi
Expand All @@ -34,6 +36,7 @@
colorstr,
ops,
get_time,
DEFAULT_COLOR_PALETTE,
)

from dataclasses import dataclass
Expand Down Expand Up @@ -137,23 +140,27 @@ def get_labels(self, detections: Detections):
]
)

def setup_roi(self, im, win: str = "roi", type: Literal["rect"] = "rect") -> List:
w, h, _ = im.shape
roi_zones, roi_zones_color = select_roi(im, win, type=type)
roi_zones = xyxy2xyxyxyxy(roi_zones)
def setup_zones(self, w, h, zones: List):
print(zones)
zones = [
(
sv.PolygonZone(
polygon=zone,
PolygonZone(
polygon=np.array(zone),
frame_resolution_wh=(w, h),
triggering_position=sv.Position.CENTER,
),
roi_zones_color[idx],
sv.Color(0, 0, 0).from_hex(DEFAULT_COLOR_PALETTE[idx]),
)
for idx, zone in enumerate(roi_zones)
for idx, zone in enumerate(zones)
]
return zones

def setup_roi(self, im, win: str = "roi", type: Literal["rect"] = "rect") -> List:
w, h, _ = im.shape
roi_zones = select_roi(im, win, type=type)
roi_zones = xyxy2xyxyxyxy(roi_zones)
return self.setup_zones(w, h, roi_zones)

@smart_inference_mode
def stream_inference(
self,
Expand All @@ -162,6 +169,7 @@ def stream_inference(
timer: List = [],
poi: Literal["point", "box", "text", ""] = "",
roi: Literal["rect", ""] = "",
zones: List = [],
# panels
show=True,
plot=True,
Expand Down Expand Up @@ -197,10 +205,16 @@ def stream_inference(
# reset tracker when source changed
if current_source is None:
current_source = p
if roi:
if len(zones) > 0:
w, h, _ = im.shape
zones = self.setup_zones(w, h, zones)
if roi and not zones:
zones = self.setup_roi(im)
elif current_source != p:
if roi:
if len(zones) > 0:
w, h, _ = im.shape
zones = self.setup_zones(w, h, zones)
if roi and not zones:
zones = self.setup_roi(im)
if self.tracker_type:
self.setup_tracker()
Expand All @@ -225,7 +239,7 @@ def stream_inference(
detections.track_id = track_id

# filter bboxes based on roi
if roi and zones and detections:
if zones and detections:
masks = set()
for zone, color in zones:
mask = zone.trigger(detections=detections)
Expand All @@ -243,7 +257,7 @@ def stream_inference(
if plot:
labels = self.get_labels(detections)

if roi and zones:
if zones:
for zone, color in zones:
im = sv.draw_polygon(im, zone.polygon, color)

Expand Down Expand Up @@ -352,6 +366,7 @@ def __call__(
timer: List = [],
poi: Literal["point", "box", "text", ""] = "",
roi: Literal["rect", ""] = "",
zones: List = [],
# panels
show=True,
plot=True,
Expand All @@ -367,6 +382,7 @@ def __call__(
timer=timer,
poi=poi,
roi=roi,
zones=zones,
show=show,
plot=plot,
plot_bboxes=plot_bboxes,
Expand All @@ -382,6 +398,7 @@ def __call__(
timer=timer,
poi=poi,
roi=roi,
zones=zones,
show=show,
plot=plot,
plot_bboxes=plot_bboxes,
Expand Down
21 changes: 21 additions & 0 deletions src/juxtapose/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,27 @@ def get_time():
return datetime.now().strftime("%Y%m%d-%H%M%S")


DEFAULT_COLOR_PALETTE = [
"#a351fb",
"#e6194b",
"#3cb44b",
"#ffe119",
"#0082c8",
"#f58231",
"#911eb4",
"#46f0f0",
"#f032e6",
"#d2f53c",
"#fabebe",
"#008080",
"#e6beff",
"#aa6e28",
"#fffac8",
"#800000",
"#aaffc3",
]


# Run below code on utils init ------------------------------------------------------------------------------------

# Check first-install steps
Expand Down
67 changes: 67 additions & 0 deletions src/juxtapose/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,36 @@
from supervision.geometry.core import Position


def calculate_masks_centroids(masks: np.ndarray) -> np.ndarray:
"""
Calculate the centroids of binary masks in a tensor.
Parameters:
masks (np.ndarray): A 3D NumPy array of shape (num_masks, height, width).
Each 2D array in the tensor represents a binary mask.
Returns:
A 2D NumPy array of shape (num_masks, 2), where each row contains the x and y
coordinates (in that order) of the centroid of the corresponding mask.
"""
num_masks, height, width = masks.shape
total_pixels = masks.sum(axis=(1, 2))

# offset for 1-based indexing
vertical_indices, horizontal_indices = np.indices((height, width)) + 0.5
# avoid division by zero for empty masks
total_pixels[total_pixels == 0] = 1

def sum_over_mask(indices: np.ndarray, axis: tuple) -> np.ndarray:
return np.tensordot(masks, indices, axes=axis)

aggregation_axis = ([1, 2], [0, 1])
centroid_x = sum_over_mask(horizontal_indices, aggregation_axis) / total_pixels
centroid_y = sum_over_mask(vertical_indices, aggregation_axis) / total_pixels

return np.column_stack((centroid_x, centroid_y)).astype(int)


def _validate_xyxy(xyxy: Any, n: int) -> None:
is_valid = isinstance(xyxy, np.ndarray) and xyxy.shape == (n, 4)
if not is_valid:
Expand Down Expand Up @@ -291,6 +321,10 @@ def empty(cls) -> Detections:
class_id=np.array([], dtype=int),
)

def gg(self):
print("gg")
return

def get_anchor_coordinates(self, anchor: Position) -> np.ndarray:
"""
Returns the bounding box coordinates for a specific anchor.
Expand All @@ -301,17 +335,50 @@ def get_anchor_coordinates(self, anchor: Position) -> np.ndarray:
Returns:
np.ndarray: An array of shape `(n, 2)` containing the bounding box anchor coordinates in format `[x, y]`.
"""

if anchor == Position.CENTER:
return np.array(
[
(self.xyxy[:, 0] + self.xyxy[:, 2]) / 2,
(self.xyxy[:, 1] + self.xyxy[:, 3]) / 2,
]
).transpose()
elif anchor == Position.CENTER_OF_MASS:
if self.mask is None:
raise ValueError(
"Cannot use `Position.CENTER_OF_MASS` without a detection mask."
)
return calculate_masks_centroids(masks=self.mask)
elif anchor == Position.CENTER_LEFT:
return np.array(
[
self.xyxy[:, 0],
(self.xyxy[:, 1] + self.xyxy[:, 3]) / 2,
]
).transpose()
elif anchor == Position.CENTER_RIGHT:
return np.array(
[
self.xyxy[:, 2],
(self.xyxy[:, 1] + self.xyxy[:, 3]) / 2,
]
).transpose()
elif anchor == Position.BOTTOM_CENTER:
return np.array(
[(self.xyxy[:, 0] + self.xyxy[:, 2]) / 2, self.xyxy[:, 3]]
).transpose()
elif anchor == Position.BOTTOM_LEFT:
return np.array([self.xyxy[:, 0], self.xyxy[:, 3]]).transpose()
elif anchor == Position.BOTTOM_RIGHT:
return np.array([self.xyxy[:, 2], self.xyxy[:, 3]]).transpose()
elif anchor == Position.TOP_CENTER:
return np.array(
[(self.xyxy[:, 0] + self.xyxy[:, 2]) / 2, self.xyxy[:, 1]]
).transpose()
elif anchor == Position.TOP_LEFT:
return np.array([self.xyxy[:, 0], self.xyxy[:, 1]]).transpose()
elif anchor == Position.TOP_RIGHT:
return np.array([self.xyxy[:, 2], self.xyxy[:, 1]]).transpose()

raise ValueError(f"{anchor} is not supported.")

Expand Down
125 changes: 125 additions & 0 deletions src/juxtapose/utils/polygon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import cv2
from dataclasses import replace
import numpy as np
from typing import Tuple
from .core import Detections

from enum import Enum


class Position(Enum):
"""
Enum representing the position of an anchor point.
"""

CENTER = "CENTER"
CENTER_LEFT = "CENTER_LEFT"
CENTER_RIGHT = "CENTER_RIGHT"
TOP_CENTER = "TOP_CENTER"
TOP_LEFT = "TOP_LEFT"
TOP_RIGHT = "TOP_RIGHT"
BOTTOM_LEFT = "BOTTOM_LEFT"
BOTTOM_CENTER = "BOTTOM_CENTER"
BOTTOM_RIGHT = "BOTTOM_RIGHT"
CENTER_OF_MASS = "CENTER_OF_MASS"

@classmethod
def list(cls):
return list(map(lambda c: c.value, cls))


def clip_boxes(xyxy: np.ndarray, resolution_wh: Tuple[int, int]) -> np.ndarray:
"""
Clips bounding boxes coordinates to fit within the frame resolution.
Args:
xyxy (np.ndarray): A numpy array of shape `(N, 4)` where each
row corresponds to a bounding box in
the format `(x_min, y_min, x_max, y_max)`.
resolution_wh (Tuple[int, int]): A tuple of the form `(width, height)`
representing the resolution of the frame.
Returns:
np.ndarray: A numpy array of shape `(N, 4)` where each row
corresponds to a bounding box with coordinates clipped to fit
within the frame resolution.
"""
result = np.copy(xyxy)
width, height = resolution_wh
result[:, [0, 2]] = result[:, [0, 2]].clip(0, width)
result[:, [1, 3]] = result[:, [1, 3]].clip(0, height)
return result


def polygon_to_mask(polygon: np.ndarray, resolution_wh: Tuple[int, int]) -> np.ndarray:
"""Generate a mask from a polygon.
Args:
polygon (np.ndarray): The polygon for which the mask should be generated,
given as a list of vertices.
resolution_wh (Tuple[int, int]): The width and height of the desired resolution.
Returns:
np.ndarray: The generated 2D mask, where the polygon is marked with
`1`'s and the rest is filled with `0`'s.
"""
width, height = resolution_wh
mask = np.zeros((height, width))

cv2.fillPoly(mask, [polygon], color=1)
return mask


class PolygonZone:
"""
A class for defining a polygon-shaped zone within a frame for detecting objects.
Attributes:
polygon (np.ndarray): A polygon represented by a numpy array of shape
`(N, 2)`, containing the `x`, `y` coordinates of the points.
frame_resolution_wh (Tuple[int, int]): The frame resolution (width, height)
triggering_position (Position): The position within the bounding
box that triggers the zone (default: Position.BOTTOM_CENTER)
current_count (int): The current count of detected objects within the zone
mask (np.ndarray): The 2D bool mask for the polygon zone
"""

def __init__(
self,
polygon: np.ndarray,
frame_resolution_wh: Tuple[int, int],
triggering_position: Position = Position.BOTTOM_CENTER,
):
self.polygon = polygon.astype(int)
self.frame_resolution_wh = frame_resolution_wh
self.triggering_position = triggering_position
self.current_count = 0

width, height = frame_resolution_wh
self.mask = polygon_to_mask(
polygon=polygon, resolution_wh=(width + 1, height + 1)
)

def trigger(self, detections: Detections) -> np.ndarray:
"""
Determines if the detections are within the polygon zone.
Parameters:
detections (Detections): The detections
to be checked against the polygon zone
Returns:
np.ndarray: A boolean numpy array indicating
if each detection is within the polygon zone
"""

clipped_xyxy = clip_boxes(
xyxy=detections.xyxy, resolution_wh=self.frame_resolution_wh
)
clipped_detections = replace(detections, xyxy=clipped_xyxy)
clipped_anchors = np.ceil(
clipped_detections.get_anchor_coordinates(anchor=self.triggering_position)
).astype(int)
is_in_zone = self.mask[clipped_anchors[:, 1], clipped_anchors[:, 0]]
self.current_count = int(np.sum(is_in_zone))
return is_in_zone.astype(bool)
5 changes: 2 additions & 3 deletions src/juxtapose/utils/roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def select_roi(
color=sv.ColorPalette.default(),
):
x0, y0 = -1, -1
points, bboxes, bboxes_color = [], [], []
points, bboxes = [], []
img, img4show = im.copy(), im.copy()
col = ()

Expand All @@ -26,7 +26,6 @@ def BOX(event, x, y, flags, param):
elif event == cv2.EVENT_LBUTTONUP:
img = img4show
bboxes.append([x0, y0, x, y])
bboxes_color.append(col)

def POINT(event, x, y, flags, param):
nonlocal x0, y0, img4show, img
Expand Down Expand Up @@ -74,6 +73,6 @@ def POINT(event, x, y, flags, param):
cv2.destroyWindow(win)

if type == "rect":
return bboxes, bboxes_color
return bboxes
else:
return points, [i + 1 for i in range(len(points))]

0 comments on commit a076e83

Please sign in to comment.