forked from Project-MONAI/MONAI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_anchor_box.py
88 lines (72 loc) · 3.83 KB
/
test_anchor_box.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
import torch
from parameterized import parameterized
from monai.apps.detection.utils.anchor_utils import AnchorGenerator, AnchorGeneratorWithAnchorShape
from monai.utils import optional_import
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save
# Only the second value returned by ``optional_import`` is kept; the test class uses it
# as its skip condition when torchvision is unavailable.
_, has_torchvision = optional_import("torchvision")

# Each 2D case: [AnchorGenerator kwargs, input image shape, per-level feature-map shapes].
TEST_CASES_2D = [
    [
        {
            "sizes": ((10, 12, 14, 16), (20, 24, 28, 32)),
            "aspect_ratios": ((1.0, 0.5, 2.0), (1.0, 0.5, 2.0)),
        },
        (5, 3, 128, 128),
        ((5, 7, 64, 32), (5, 7, 32, 16)),
    ],
]

# Each 3D case: [AnchorGeneratorWithAnchorShape kwargs, input image shape, per-level feature-map shapes].
TEST_CASES_SHAPE_3D = [
    [
        {
            "feature_map_scales": (1, 2),
            "base_anchor_shapes": ((4, 3, 6), (8, 2, 4)),
        },
        (5, 3, 128, 128, 128),
        ((5, 7, 64, 32, 32), (5, 7, 32, 16, 16)),
    ],
]
@SkipIfBeforePyTorchVersion((1, 11))
@unittest.skipUnless(has_torchvision, "Requires torchvision")
class TestAnchorGenerator(unittest.TestCase):
    """Tests for MONAI's detection anchor generators.

    The 2D ``AnchorGenerator`` is compared against torchvision's reference implementation,
    and both the 2D ``AnchorGenerator`` and the 3D ``AnchorGeneratorWithAnchorShape`` are
    checked for TorchScript support via ``test_script_save``.
    """

    @parameterized.expand(TEST_CASES_2D)
    def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
        """Check that the 2D ``AnchorGenerator`` (``indexing="xy"``) behaves like torchvision's.

        Args:
            input_param: keyword arguments shared by both generators.
            image_shape: shape of the random input image batch.
            feature_maps_shapes: shapes of the random per-level feature maps.
        """
        torch_anchor_utils, _ = optional_import("torchvision.models.detection.anchor_utils")
        image_list, _ = optional_import("torchvision.models.detection.image_list")

        # test it behaves the same with torchvision for 2d
        anchor = AnchorGenerator(**input_param, indexing="xy")
        anchor_ref = torch_anchor_utils.AnchorGenerator(**input_param)

        # Base (cell) anchors, one tensor per level.
        for a, a_f in zip(anchor.cell_anchors, anchor_ref.cell_anchors):
            assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)
        # Number of anchors per spatial location, per level.
        for a, a_f in zip(anchor.num_anchors_per_location(), anchor_ref.num_anchors_per_location()):
            assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)

        # Anchors tiled over explicitly given grid sizes and strides.
        grid_sizes = [[2, 2], [1, 1]]
        strides = [[torch.tensor(1), torch.tensor(2)], [torch.tensor(2), torch.tensor(4)]]
        for a, a_f in zip(anchor.grid_anchors(grid_sizes, strides), anchor_ref.grid_anchors(grid_sizes, strides)):
            assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)

        # End-to-end forward pass on random inputs.
        images = torch.rand(image_shape)
        feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)
        result = anchor(images, feature_maps)
        # torchvision returns one anchor tensor per entry of ``image_sizes``; provide one entry per
        # image in the batch. With a single entry, the ``zip`` below would silently stop after the
        # first image and the remaining outputs would never be compared.
        image_sizes = [(123, 122)] * image_shape[0]
        result_ref = anchor_ref(image_list.ImageList(images, image_sizes), feature_maps)
        self.assertEqual(len(result), len(result_ref))
        for a, a_f in zip(result, result_ref):
            assert_allclose(a, a_f, type_test=True, device_test=False, atol=0.1)

    @parameterized.expand(TEST_CASES_2D)
    def test_script_2d(self, input_param, image_shape, feature_maps_shapes):
        """Check that the 2D ``AnchorGenerator`` can be scripted and saved with TorchScript."""
        anchor = AnchorGenerator(**input_param, indexing="xy")
        images = torch.rand(image_shape)
        feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)
        test_script_save(anchor, images, feature_maps)

    @parameterized.expand(TEST_CASES_SHAPE_3D)
    def test_script_3d(self, input_param, image_shape, feature_maps_shapes):
        """Check that the 3D ``AnchorGeneratorWithAnchorShape`` (``indexing="ij"``) supports TorchScript."""
        anchor = AnchorGeneratorWithAnchorShape(**input_param, indexing="ij")
        images = torch.rand(image_shape)
        feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)
        test_script_save(anchor, images, feature_maps)
if __name__ == "__main__":
    # Allow running this test module directly with the standard unittest runner.
    unittest.main()