"""
Transforms
Fred Zhang <frederic.zhang@anu.edu.au>
The Australian National University
Australian Centre for Robotic Vision
"""
from torch import nn
from torchvision.models.detection import transform
class HOINetworkTransform(transform.GeneralizedRCNNTransform):
    """
    Transformations for input image and target (box pairs)

    Arguments(Positional):
        min_size(int)
        max_size(int)
        image_mean(list[float] or tuple[float])
        image_std(list[float] or tuple[float])

    Refer to torchvision.models.detection for more details
    """
    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as well: the original dropped them,
        # which silently discarded options such as torchvision's
        # `size_divisible`. Positional-only callers are unaffected.
        super().__init__(*args, **kwargs)

    def resize(self, image, target):
        """
        Override method to resize box pairs

        Scales the image so its short side matches self.min_size[0],
        capped so the long side does not exceed self.max_size, then
        rescales the box pairs stored under target['boxes_h'] and
        target['boxes_o'] to the new image size.

        Parameters:
            image: Tensor with spatial dims in the last two axes
                (H, W taken from image.shape[-2:]).
            target: dict with 'boxes_h' and 'boxes_o' box tensors,
                or None (e.g. at inference time).

        Returns:
            (image, target) with image resized and, when target is not
            None, both box fields resized to match.

        NOTE(review): always uses self.min_size[0] — no per-iteration
        size sampling happens here; presumed intentional for
        deterministic scaling. Confirm against training pipeline.
        """
        h, w = image.shape[-2:]
        min_size = float(min(image.shape[-2:]))
        max_size = float(max(image.shape[-2:]))
        # Shrink factor wins: short side reaches min_size[0] unless that
        # would push the long side past max_size.
        scale_factor = min(
            self.min_size[0] / min_size,
            self.max_size / max_size
        )
        image = nn.functional.interpolate(
            image[None], scale_factor=scale_factor,
            mode='bilinear', align_corners=False,
            recompute_scale_factor=True
        )[0]
        if target is None:
            return image, target
        target['boxes_h'] = transform.resize_boxes(target['boxes_h'],
            (h, w), image.shape[-2:])
        target['boxes_o'] = transform.resize_boxes(target['boxes_o'],
            (h, w), image.shape[-2:])
        return image, target

    def postprocess(self, results, image_shapes, original_image_sizes):
        """
        Map predicted box pairs back to the original image sizes.

        Parameters:
            results: list of per-image dicts with 'boxes_h'/'boxes_o';
                in training mode the LAST element is assumed to be the
                loss (it is popped before the loop and re-appended
                afterwards so it is not treated as a prediction).
            image_shapes: per-image (H, W) after the transform.
            original_image_sizes: per-image (H, W) before the transform.

        Returns:
            The same `results` list, with box fields resized in place
            (and the trailing loss entry restored in training mode).
        """
        if self.training:
            loss = results.pop()
        for pred, im_s, o_im_s in zip(results, image_shapes, original_image_sizes):
            boxes_h, boxes_o = pred['boxes_h'], pred['boxes_o']
            boxes_h = transform.resize_boxes(boxes_h, im_s, o_im_s)
            boxes_o = transform.resize_boxes(boxes_o, im_s, o_im_s)
            pred['boxes_h'], pred['boxes_o'] = boxes_h, boxes_o
        if self.training:
            results.append(loss)
        return results