
New frame selection model #195

Merged
merged 36 commits on Jul 15, 2022
Changes from 28 commits
Commits
36 commits
f57d3e8
new model and kwargs
ejm714 Jun 29, 2022
7370b7e
WIP load new model, remove old exp
ejm714 Jun 29, 2022
81f16fa
support video detection and postprocess as a single array
ejm714 Jun 29, 2022
b69c9f0
black
ejm714 Jun 29, 2022
a4eea90
fix tests
ejm714 Jun 29, 2022
e4858a2
remove old object detection model
ejm714 Jun 29, 2022
972da83
fix tests
ejm714 Jun 29, 2022
e13997a
use image size from config
ejm714 Jun 29, 2022
4b308b1
source for tiny exp
ejm714 Jun 30, 2022
38d44e5
first pass at distributed training
ejm714 Jun 30, 2022
8ddbc7d
use normal array
ejm714 Jul 1, 2022
41ac90d
remove padding and just resize
ejm714 Jul 1, 2022
d46bfb4
resize instead of pad
ejm714 Jul 1, 2022
1e4ebff
flake8
ejm714 Jul 1, 2022
2d6b798
remove extra code
ejm714 Jul 1, 2022
0da2b30
three boxes found for dog image
ejm714 Jul 1, 2022
2bea54f
cleanup
ejm714 Jul 1, 2022
9f6e2cb
fix test
ejm714 Jul 1, 2022
ce4bac8
scale and pad
ejm714 Jul 1, 2022
ffcb842
fix tests
ejm714 Jul 1, 2022
6efe15a
black
ejm714 Jul 1, 2022
15fc91d
remove extra code
ejm714 Jul 1, 2022
2522ba3
torch.no_grad is super duper important
ejm714 Jul 2, 2022
55c11d9
iterate over batches of 64
ejm714 Jul 2, 2022
229f1a8
black
ejm714 Jul 2, 2022
0bf851d
spacing
ejm714 Jul 2, 2022
0d12b9a
decrease batch size so this fits
Jul 2, 2022
ed9a8d6
set crf to default
ejm714 Jul 11, 2022
6b21005
act2 is now part of bn2
ejm714 Jul 11, 2022
b641e0d
update manifest
ejm714 Jul 14, 2022
cad32d9
Merge branch 'new-frame-selection-model' of github.com:drivendataorg/…
ejm714 Jul 14, 2022
19f9541
set frame batch size as param
ejm714 Jul 14, 2022
92b9b07
preallocate
ejm714 Jul 14, 2022
16ff182
add api reference for mdlite
ejm714 Jul 14, 2022
506e7ad
fix all object detection links
ejm714 Jul 14, 2022
9e0d5cd
update manifest
ejm714 Jul 14, 2022
2 changes: 1 addition & 1 deletion tests/test_load_video_frames.py
@@ -448,7 +448,7 @@ def test_megadetector_lite_yolox_dog(tmp_path):
"-vcodec",
"libx264",
"-crf",
"25",
"23",
ejm714 (Collaborator, Author):
set CRF to the libx264 default of 23 (otherwise the test fails due to lossy compression when creating the test video): https://trac.ffmpeg.org/wiki/Encode/H.264 (a sketch of the encoding command follows this diff)

"-pix_fmt",
"yuv420p",
str(tmp_path / "dog.mp4"),
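For context on the review comment above, here is a minimal sketch of encoding a test clip with the same libx264 flags, assuming ffmpeg is available on PATH; the frame pattern and paths are illustrative, not taken from the repo.

```python
import subprocess
from pathlib import Path

def encode_test_video(frame_dir: Path, out_path: Path, crf: int = 23) -> None:
    """Encode numbered PNG frames to H.264, mirroring the test's flags.

    CRF 23 is libx264's default; higher values (e.g. the previous 25)
    compress more aggressively, which shifted pixels enough to change
    the detector's output and break the test.
    """
    subprocess.run(
        [
            "ffmpeg",
            "-i", str(frame_dir / "frame_%03d.png"),  # hypothetical input pattern
            "-vcodec", "libx264",
            "-crf", str(crf),
            "-pix_fmt", "yuv420p",
            str(out_path),
        ],
        check=True,
    )
```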
25 changes: 19 additions & 6 deletions tests/test_megadetector_lite_yolox.py
@@ -1,9 +1,11 @@
import json

import numpy as np
from PIL import Image
import pytest
import torch

from zamba.object_detection import YoloXNano
from zamba.object_detection import YoloXModel, YoloXExp, YoloXArgs
from zamba.object_detection.yolox.megadetector_lite_yolox import (
MegadetectorLiteYoloX,
MegadetectorLiteYoloXConfig,
@@ -19,14 +21,25 @@ def dog():

@pytest.fixture
def dummy_yolox_path(tmp_path):
yolox = YoloXNano(num_classes=1)
checkpoint = {"model": yolox.get_model().state_dict()}
yolox = YoloXModel(exp=YoloXExp(num_classes=1), args=YoloXArgs())
checkpoint = {"model": yolox.exp.get_model().state_dict()}
torch.save(checkpoint, tmp_path / "dummy_yolox.pth")
return tmp_path / "dummy_yolox.pth"


def test_load_megadetector(dummy_yolox_path):
MegadetectorLiteYoloX(dummy_yolox_path, MegadetectorLiteYoloXConfig())
@pytest.fixture
def dummy_yolox_model_kwargs(tmp_path):
kwargs = dict(num_classes=1, image_size=640, backbone="yolox-tiny")
json_path = tmp_path / "dummy_yolox_kwargs.json"
with json_path.open("w+") as f:
json.dump(kwargs, f)
return json_path


def test_load_megadetector(dummy_yolox_path, dummy_yolox_model_kwargs):
MegadetectorLiteYoloX(
dummy_yolox_path, dummy_yolox_model_kwargs, MegadetectorLiteYoloXConfig()
)


def test_scale_and_pad_array():
@@ -51,7 +64,7 @@ def test_detect_image(mdlite, dog):
boxes, scores = mdlite.detect_image(np.array(dog))

assert len(scores) == 1
assert np.allclose([0.09690314, 0.04301501, 0.9931333, 1.0082883], boxes[0])
assert np.allclose([0.65678996, 0.21596366, 0.71104807, 0.277931], boxes[0])


def test_detect_video(mdlite, dog):
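The fixtures above exercise the new loading path end to end. As a usage sketch, assuming the packaged default checkpoint and kwargs JSON (the no-argument constructor falls back to them, per the `megadetector_lite_yolox.py` diff below), with an illustrative input image:

```python
import numpy as np
from PIL import Image

from zamba.object_detection.yolox.megadetector_lite_yolox import (
    MegadetectorLiteYoloX,
    MegadetectorLiteYoloXConfig,
)

# Defaults to the bundled yolox-tiny 640px weights and their kwargs JSON.
mdlite = MegadetectorLiteYoloX(config=MegadetectorLiteYoloXConfig(confidence=0.25))

img = np.array(Image.open("dog.jpg"))  # hypothetical image path
boxes, scores = mdlite.detect_image(img)  # boxes: (n_objects, 4) as (x1, y1, x2, y2)
```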
2 changes: 1 addition & 1 deletion zamba/data/video.py
@@ -466,7 +466,7 @@ def load_video_frames(

if config.megadetector_lite_config is not None:
mdlite = MegadetectorLiteYoloX(config=config.megadetector_lite_config)
detection_probs = mdlite.detect_video(frames=arr)
detection_probs = mdlite.detect_video(video_arr=arr)

arr = mdlite.filter_frames(arr, detection_probs)

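For orientation, the surrounding frame-selection flow in a compact sketch; the array is a stand-in and the config values are illustrative (`fill_mode` and `n_frames` come from `MegadetectorLiteYoloXConfig` in the diff below, and `filter_frames` is the call shown above):

```python
import numpy as np

from zamba.object_detection.yolox.megadetector_lite_yolox import (
    MegadetectorLiteYoloX,
    MegadetectorLiteYoloXConfig,
)

config = MegadetectorLiteYoloXConfig(n_frames=16, fill_mode="score_sorted")
mdlite = MegadetectorLiteYoloX(config=config)

# Stand-in for a decoded video: (frames, height, width, channels), uint8.
arr = np.random.randint(0, 255, size=(100, 360, 640, 3), dtype=np.uint8)

detection_probs = mdlite.detect_video(video_arr=arr)  # per-frame detections
arr = mdlite.filter_frames(arr, detection_probs)      # select frames by score
```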
5 changes: 2 additions & 3 deletions zamba/object_detection/__init__.py
@@ -1,4 +1,3 @@
from zamba.object_detection.yolox.yolox_base import YoloXBase
from zamba.object_detection.yolox.yolox_nano import YoloXNano
from zamba.object_detection.yolox.yolox_model import YoloXArgs, YoloXExp, YoloXModel

__all__ = ["YoloXBase", "YoloXNano"]
__all__ = ["YoloXArgs", "YoloXExp", "YoloXModel"]
Binary file not shown: zamba/object_detection/yolox/assets/yolox_nano_20210901.pth (removed)
Binary file not shown: zamba/object_detection/yolox/assets/yolox_tiny_640_20220528.pth (added)
5 changes: 5 additions & 0 deletions zamba/object_detection/yolox/assets/yolox_tiny_640_20220528_model_kwargs.json
@@ -0,0 +1,5 @@
{
"num_classes": 1,
"backbone": "yolox-tiny",
"image_size": 640
}
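This JSON pairs with the new checkpoint via `YoloXModel.load`, as called in the `megadetector_lite_yolox.py` diff below; a minimal sketch, with the asset paths spelled out rather than resolved from the package:

```python
from pathlib import Path

from zamba.object_detection import YoloXModel

assets = Path("zamba/object_detection/yolox/assets")

yolox = YoloXModel.load(
    checkpoint=assets / "yolox_tiny_640_20220528.pth",
    model_kwargs_path=assets / "yolox_tiny_640_20220528_model_kwargs.json",
)
model = yolox.exp.get_model()  # the Exp object builds the torch module
```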
94 changes: 73 additions & 21 deletions zamba/object_detection/yolox/megadetector_lite_yolox.py
@@ -11,9 +11,12 @@
from tqdm import tqdm
from yolox.utils.boxes import postprocess

from zamba.object_detection import YoloXNano
from zamba.object_detection import YoloXModel

LOCAL_MD_LITE_MODEL = Path(__file__).parent / "assets" / "yolox_nano_20210901.pth"
LOCAL_MD_LITE_MODEL = Path(__file__).parent / "assets" / "yolox_tiny_640_20220528.pth"
LOCAL_MD_LITE_MODEL_KWARGS = (
Path(__file__).parent / "assets" / "yolox_tiny_640_20220528_model_kwargs.json"
)


class FillModeEnum(str, Enum):
@@ -56,8 +59,8 @@ class MegadetectorLiteYoloXConfig(BaseModel):

confidence: float = 0.25
nms_threshold: float = 0.45
image_width: int = 416
image_height: int = 416
image_width: int = 640
image_height: int = 640
device: str = "cuda" if torch.cuda.is_available() else "cpu"
n_frames: Optional[int] = None
fill_mode: Optional[FillModeEnum] = FillModeEnum.score_sorted
@@ -72,6 +75,7 @@ class MegadetectorLiteYoloX:
def __init__(
self,
path: os.PathLike = LOCAL_MD_LITE_MODEL,
kwargs: os.PathLike = LOCAL_MD_LITE_MODEL_KWARGS,
config: Optional[Union[MegadetectorLiteYoloXConfig, dict]] = None,
):
"""MegadetectorLite based on YOLOX.
@@ -85,18 +89,20 @@ def __init__(
elif isinstance(config, dict):
config = MegadetectorLiteYoloXConfig.parse_obj(config)

checkpoint = torch.load(path, map_location=config.device)
num_classes = checkpoint["model"]["head.cls_preds.0.weight"].shape[0]
yolox = YoloXModel.load(
checkpoint=path,
model_kwargs_path=kwargs,
)

yolox = YoloXNano(num_classes=num_classes)
model = yolox.get_model()
model.load_state_dict(checkpoint["model"])
ckpt = torch.load(yolox.args.ckpt, map_location=config.device)
model = yolox.exp.get_model()
model.load_state_dict(ckpt["model"])
model = model.eval().to(config.device)

self.model = model
self.yolox = yolox
self.config = config
self.num_classes = num_classes
self.num_classes = yolox.exp.num_classes

@staticmethod
def scale_and_pad_array(
@@ -116,19 +122,63 @@ def _preprocess(self, frame: np.ndarray) -> np.ndarray:
"""Process an image for the model, including scaling/padding the image, transposing from
(height, width, channel) to (channel, height, width) and casting to float.
"""
return np.ascontiguousarray(
self.scale_and_pad_array(
frame, self.config.image_width, self.config.image_height
).transpose(2, 0, 1),
arr = np.ascontiguousarray(
self.scale_and_pad_array(frame, self.config.image_width, self.config.image_height),
dtype=np.float32,
)
return np.moveaxis(arr, 2, 0)

def _preprocess_video(self, video: np.ndarray) -> np.ndarray:
"""Process a video for the model, including resizing the frames in the video, transposing
from (batch, height, width, channel) to (batch, channel, height, width) and casting to float.
"""
resized_frames = []
for frame_idx in range(video.shape[0]):
frame = video[frame_idx]
resized_frames.append(self._preprocess(frame))

return np.array(resized_frames)

def detect_video(self, video_arr: np.ndarray, pbar: bool = False):
"""Runs object detection on a video.

Args:
video_arr (np.ndarray): A video array with dimensions (frames, height, width, channels).
pbar (bool): Whether to show a progress bar. Defaults to False.

Returns:
list: A list containing a (detections, scores) tuple for each frame. The first element
is an array of bounding box detections with dimensions (object, 4), where object is
the number of objects detected and the 4 columns are (x1, y1, x2, y2). The second
element is an array of detection confidence scores of length (object).
"""

def detect_video(self, frames: np.ndarray, pbar: bool = False):
pbar = tqdm if pbar else lambda x: x

# iterate over batches of 24
batch_size = 24

video_outputs = []
with torch.no_grad():
for i in range(0, len(video_arr), batch_size):
a = video_arr[i : i + batch_size]

outputs = self.model(
torch.from_numpy(self._preprocess_video(a)).to(self.config.device)
)
outputs = postprocess(
outputs, self.num_classes, self.config.confidence, self.config.nms_threshold
)
video_outputs.extend(outputs)

detections = []
for frame in pbar(frames):
detections.append(self.detect_image(frame))
for o in pbar(video_outputs):
detections.append(
self._process_frame_output(
o, original_height=video_arr.shape[1], original_width=video_arr.shape[2]
)
)

return detections

def detect_image(self, img_arr: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
@@ -149,10 +199,13 @@ def detect_image(self, img_arr: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.model(
torch.from_numpy(self._preprocess(img_arr)).unsqueeze(0).to(self.config.device)
)
output = postprocess(
outputs, self.num_classes, self.config.confidence, self.config.nms_threshold
)

return self._process_frame_output(output[0], img_arr.shape[0], img_arr.shape[1])

output = postprocess(
outputs, self.num_classes, self.config.confidence, self.config.nms_threshold
)[0]
def _process_frame_output(self, output, original_height, original_width):
if output is None:
return np.array([]), np.array([])
else:
@@ -162,7 +215,6 @@ def detect_image(self, img_arr: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
).assign(score=lambda row: row.score1 * row.score2)

# Transform bounding box to be in terms of the original image dimensions
original_height, original_width = img_arr.shape[:2]
ratio = min(
self.config.image_width / original_width,
self.config.image_height / original_height,
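The clipped tail of `_process_frame_output` rescales boxes from the 640px letterboxed input back to the original frame. A worked sketch of that arithmetic, assuming the scale-and-pad preprocessing shown earlier and normalized outputs (an inference from the updated test expecting boxes in [0, 1]):

```python
import numpy as np

def rescale_box(
    box_xyxy: np.ndarray,
    original_height: int,
    original_width: int,
    model_width: int = 640,
    model_height: int = 640,
) -> np.ndarray:
    """Map an (x1, y1, x2, y2) box from model-input pixels back to
    normalized coordinates in the original frame."""
    # The frame was shrunk by a single ratio and padded to model size, so
    # dividing by that ratio recovers original-image pixel coordinates.
    ratio = min(model_width / original_width, model_height / original_height)
    box = np.asarray(box_xyxy, dtype=np.float32) / ratio
    # Normalize by the original dimensions.
    return box / np.array(
        [original_width, original_height, original_width, original_height],
        dtype=np.float32,
    )
```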
103 changes: 0 additions & 103 deletions zamba/object_detection/yolox/yolox_base.py

This file was deleted.
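Finally, the batched-inference pattern this PR converged on (see the commits "torch.no_grad is super duper important" and "decrease batch size so this fits") as a standalone sketch; the model, dtype handling, and batch size are placeholders:

```python
import numpy as np
import torch

def batched_inference(
    model: torch.nn.Module,
    frames: np.ndarray,  # preprocessed float32, (n_frames, C, H, W)
    device: str,
    batch_size: int = 24,
) -> list:
    """Run frames through the model in fixed-size chunks.

    torch.no_grad() is the crucial piece: without it, autograd state
    accumulates across frames and memory grows with video length.
    """
    outputs = []
    with torch.no_grad():
        for i in range(0, len(frames), batch_size):
            batch = torch.from_numpy(frames[i : i + batch_size]).to(device)
            outputs.extend(model(batch))
    return outputs
```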
