Skip to content

Commit

Permalink
Adds torchscript Compatibility to box_convert (pytorch#2737)
Browse files Browse the repository at this point in the history
* fixies small bug in box_convert

* activates jit test

* Passes JIT test

* fixes typo

* adds error tests, removes assert

* implements to proposal2
  • Loading branch information
oke-aditya authored and bryant1410 committed Nov 22, 2020
1 parent f658c5d commit 386028d
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 40 deletions.
32 changes: 21 additions & 11 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,20 +727,30 @@ def test_bbox_xywh_cxcywh(self):
self.assertEqual(box_xywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xywh, box_tensor)).item()

# def test_bbox_convert_jit(self):
# box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
# [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
def test_bbox_invalid(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
[10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

invalid_infmts = ["xwyh", "cxwyh"]
invalid_outfmts = ["xwcx", "xhwcy"]
for inv_infmt in invalid_infmts:
for inv_outfmt in invalid_outfmts:
self.assertRaises(ValueError, ops.box_convert, box_tensor, inv_infmt, inv_outfmt)

def test_bbox_convert_jit(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

# scripted_fn = torch.jit.script(ops.box_convert)
# TOLERANCE = 1e-3
scripted_fn = torch.jit.script(ops.box_convert)
TOLERANCE = 1e-3

# box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
# scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh')
# self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE)
box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh')
self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE)

# box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
# scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh')
# self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE)
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh')
self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE)


class BoxAreaTester(unittest.TestCase):
Expand Down
6 changes: 3 additions & 3 deletions torchvision/ops/_box_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor:
boxes (Tensor[N, 4]): boxes in (x, y, w, h) format.
"""
x1, y1, x2, y2 = boxes.unbind(-1)
x2 = x2 - x1 # x2 - x1
y2 = y2 - y1 # y2 - y1
boxes = torch.stack((x1, y1, x2, y2), dim=-1)
w = x2 - x1 # x2 - x1
h = y2 - y1 # y2 - y1
boxes = torch.stack((x1, y1, w, h), dim=-1)
return boxes
46 changes: 20 additions & 26 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,39 +154,33 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
Returns:
boxes (Tensor[N, 4]): Boxes into converted format.
"""

allowed_fmts = ("xyxy", "xywh", "cxcywh")
assert in_fmt in allowed_fmts
assert out_fmt in allowed_fmts
if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")

if in_fmt == out_fmt:
boxes_converted = boxes.clone()
return boxes_converted
return boxes.clone()

if in_fmt != 'xyxy' and out_fmt != 'xyxy':
# convert to xyxy and change in_fmt xyxy
if in_fmt == "xywh":
boxes_xyxy = _box_xywh_to_xyxy(boxes)
if out_fmt == "cxcywh":
boxes_converted = _box_xyxy_to_cxcywh(boxes_xyxy)

boxes = _box_xywh_to_xyxy(boxes)
elif in_fmt == "cxcywh":
boxes_xyxy = _box_cxcywh_to_xyxy(boxes)
if out_fmt == "xywh":
boxes_converted = _box_xyxy_to_xywh(boxes_xyxy)

# convert one to xyxy and change either in_fmt or out_fmt to xyxy
else:
if in_fmt == "xyxy":
if out_fmt == "xywh":
boxes_converted = _box_xyxy_to_xywh(boxes)
elif out_fmt == "cxcywh":
boxes_converted = _box_xyxy_to_cxcywh(boxes)
elif out_fmt == "xyxy":
if in_fmt == "xywh":
boxes_converted = _box_xywh_to_xyxy(boxes)
elif in_fmt == "cxcywh":
boxes_converted = _box_cxcywh_to_xyxy(boxes)

return boxes_converted
boxes = _box_cxcywh_to_xyxy(boxes)
in_fmt = 'xyxy'

if in_fmt == "xyxy":
if out_fmt == "xywh":
boxes = _box_xyxy_to_xywh(boxes)
elif out_fmt == "cxcywh":
boxes = _box_xyxy_to_cxcywh(boxes)
elif out_fmt == "xyxy":
if in_fmt == "xywh":
boxes = _box_xywh_to_xyxy(boxes)
elif in_fmt == "cxcywh":
boxes = _box_cxcywh_to_xyxy(boxes)
return boxes


def box_area(boxes: Tensor) -> Tensor:
Expand Down

0 comments on commit 386028d

Please sign in to comment.