masks_to_bounding_boxes op #4290

Changes from 7 commits
@@ -0,0 +1,73 @@

"""
=======================
Repurposing annotations
=======================

The following example illustrates the operations available in :ref:`the torchvision.ops module <ops>` for
repurposing object localization annotations for different tasks (e.g. transforming masks used by instance and
panoptic segmentation methods into bounding boxes used by object detection methods).
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import torch
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('assets') / 'astronaut.jpg')
# If you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)


def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2D grid even if there's just one row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

####################################
# Masks
# --------------------------------------
#
# In tasks like instance and panoptic segmentation, masks are commonly defined (and are defined by this package)
# as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape:
#
#     (objects, height, width)
#
# where ``objects`` is the number of annotated objects in the image. Each ``(height, width)`` slice corresponds
# to exactly one object. For example, if your input image has dimensions 224 x 224 and contains four annotated
# objects, your masks annotation has the following shape:
#
#     (4, 224, 224)
#
# A nice property of masks is that they can easily be repurposed for methods that solve a variety of object
# localization tasks.
#
# Masks to bounding boxes
# ~~~~~~~~~~~~~~~~~~~~~~~
#
# For example, the ``masks_to_bounding_boxes`` operation can be used to transform masks into bounding boxes
# that can be used in methods like Faster R-CNN and YOLO.
padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_imgs)
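The gallery snippet above still ends with the padding example as a placeholder and never calls the new op on actual masks. As a self-contained sketch of what the masks-to-boxes conversion computes (the helper name ``boxes_from_masks`` and the synthetic masks below are my own illustration, not part of the PR):

```python
import torch


def boxes_from_masks(masks: torch.Tensor) -> torch.Tensor:
    # Naive per-mask reference: the xyxy box is the min/max of the
    # nonzero pixel coordinates of each (H, W) mask.
    boxes = torch.zeros((masks.shape[0], 4), dtype=torch.float)
    for i, mask in enumerate(masks):
        ys, xs = torch.where(mask != 0)
        boxes[i] = torch.tensor(
            [xs.min().item(), ys.min().item(), xs.max().item(), ys.max().item()],
            dtype=torch.float,
        )
    return boxes


# Two synthetic masks in a 6 x 6 image: a 2 x 2 square and a 1 x 3 strip.
masks = torch.zeros((2, 6, 6), dtype=torch.int64)
masks[0, 1:3, 1:3] = 1  # rows 1-2, cols 1-2
masks[1, 4, 2:5] = 1    # row 4, cols 2-4

print(boxes_from_masks(masks))
# tensor([[1., 1., 2., 2.],
#         [2., 4., 4., 4.]])
```

The loop form trades the vectorized tricks of the actual implementation for readability; both produce boxes in xyxy format.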
@@ -0,0 +1,43 @@

import os.path

import PIL.Image
import numpy
import pytest
import torch

import torchvision.ops

ASSETS_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")


@pytest.fixture
def labeled_image() -> torch.Tensor:
    with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "labeled_image.png")) as image:
        return torch.tensor(numpy.array(image, numpy.int64))


@pytest.fixture
def masks() -> torch.Tensor:
    with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
        frames = numpy.zeros((image.n_frames, image.height, image.width), numpy.int64)

        for index in range(image.n_frames):
            image.seek(index)

            frames[index] = numpy.array(image)

        return torch.tensor(frames)


def test_masks_to_bounding_boxes(masks):
    expected = torch.tensor(
        [[ 127.,    2.,  165.,   40.],  # noqa: E121, E201, E202, E241
         [   4.,  100.,   88.,  184.],  # noqa: E201, E202, E241
         [ 168.,  189.,  294.,  300.],  # noqa: E201, E202, E241
         [ 556.,  272.,  700.,  416.],  # noqa: E201, E202, E241
         [ 800.,  560.,  990.,  725.],  # noqa: E201, E202, E241
         [ 294.,  828.,  594., 1092.],  # noqa: E201, E202, E241
         [ 756., 1036., 1064., 1491.]]  # noqa: E201, E202, E241
    )

    torch.testing.assert_close(torchvision.ops.masks_to_bounding_boxes(masks), expected)

Review comment on the masks fixture: Do you think it would be possible to write a test without the need for new images and hard-coded coordinates? Ideally, we could generate random masks and have a super simple version of …

Reply: Yep. I wrote about this elsewhere in the thread. I'd love to add a generator for various outputs similar to the function @goldsborough and I wrote for scikit-image (…)

@NicolasHug a friendly bump
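Following the reviewer's suggestion of testing without shipped image assets, one approach is to generate masks whose boxes are known analytically. A hypothetical sketch (the generator name, shapes, and seed are my assumptions, not part of the PR): each mask is a filled random rectangle, so the expected xyxy box is known by construction.

```python
import torch


def random_rectangle_masks(num=5, size=64, seed=0):
    # Each mask is a filled axis-aligned rectangle, so the expected
    # xyxy box (x0, y0, x1, y1) is known without hard-coded assets.
    generator = torch.Generator().manual_seed(seed)
    masks = torch.zeros((num, size, size), dtype=torch.int64)
    expected = torch.zeros((num, 4), dtype=torch.float)
    for i in range(num):
        x0 = int(torch.randint(0, size - 1, (1,), generator=generator))
        y0 = int(torch.randint(0, size - 1, (1,), generator=generator))
        x1 = int(torch.randint(x0 + 1, size, (1,), generator=generator))
        y1 = int(torch.randint(y0 + 1, size, (1,), generator=generator))
        masks[i, y0:y1 + 1, x0:x1 + 1] = 1
        expected[i] = torch.tensor([x0, y0, x1, y1], dtype=torch.float)
    return masks, expected


masks, expected = random_rectangle_masks()
# A test would then assert:
# torch.testing.assert_close(torchvision.ops.masks_to_bounding_boxes(masks), expected)
```

Rectangles are the simplest case where the expected box is known by construction; irregular random blobs would instead need a naive reference implementation to compare against.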
@@ -0,0 +1,26 @@

import torch


def masks_to_bounding_boxes(masks: torch.Tensor) -> torch.Tensor:
    """Compute the bounding boxes around the provided masks.

    The masks should be in format [N, H, W] where N is the number of masks and (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor, with the boxes in xyxy format.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = masks * x.unsqueeze(0)
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = masks * y.unsqueeze(0)
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
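One subtlety in the implementation above: multiplying by the coordinate grid leaves zeros on background pixels, which cannot corrupt the max (coordinates are non-negative) but would wrongly win the min. That is what the masked_fill with a large constant guards against. A minimal sketch of the failure mode, with toy values of my own:

```python
import torch

mask = torch.tensor([[0, 1, 1, 0]])     # one 1 x 4 "mask", object at cols 1-2
x = torch.arange(4, dtype=torch.float)  # column coordinate grid: 0, 1, 2, 3

x_mask = mask * x  # -> [[0., 1., 2., 0.]]; background collapses to 0

naive_min = x_mask.flatten(1).min(-1)[0]
guarded_min = x_mask.masked_fill(~mask.bool(), 1e8).flatten(1).min(-1)[0]

print(float(naive_min), float(guarded_min))  # 0.0 1.0 -- only the guarded min is correct
```

The max path needs no such guard because every coordinate is >= 0, so a background zero can never exceed a true object coordinate.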
Review comment: After some debugging I found out the reason for the build_docs CI failure. The problem is that torchvision.ops does not have a nice index on the right side (basically an HTML link to #ops, like transforms has). This causes the CI failure. We need to remove the ref, and it will work fine. This is a slightly hacky fix, but it works. I tried running it locally; I could build the gallery example, and it looks nice.

Reply: Nice! I appreciate the debugging.
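For reference, the fix described above amounts to dropping the cross-reference role from the gallery docstring. A hypothetical sketch of the kind of change (the exact line may differ):

```
- The following example illustrates the operations available in :ref:`the torchvision.ops module <ops>` for
+ The following example illustrates the operations available in the torchvision.ops module for
```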