Skip to content

Commit

Permalink
[fbsync] Use torch.testing.assert_close in test_anchor_utils.py (#3880)
Browse files Browse the repository at this point in the history
Summary: Co-authored-by: Philip Meier <github.pmeier@posteo.de>

Reviewed By: vincentqb, cpuhrsch

Differential Revision: D28679964

fbshipit-source-id: c53eb63031a4b965e870f3334d749e4d9d41ddad
  • Loading branch information
datumbox authored and facebook-github-bot committed May 25, 2021
1 parent f654722 commit b243ede
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions test/test_models_detection_anchor_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from common_utils import TestCase
from _assert_utils import assert_equal
from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
from torchvision.models.detection.image_list import ImageList

Expand Down Expand Up @@ -62,8 +63,8 @@ def test_anchor_generator(self):
self.assertEqual(len(anchors), 2)
self.assertEqual(tuple(anchors[0].shape), (9, 4))
self.assertEqual(tuple(anchors[1].shape), (9, 4))
self.assertEqual(anchors[0], anchors_output)
self.assertEqual(anchors[1], anchors_output)
assert_equal(anchors[0], anchors_output)
assert_equal(anchors[1], anchors_output)

def test_defaultbox_generator(self):
images = torch.zeros(2, 3, 15, 15)
Expand All @@ -85,5 +86,5 @@ def test_defaultbox_generator(self):
self.assertEqual(len(dboxes), 2)
self.assertEqual(tuple(dboxes[0].shape), (4, 4))
self.assertEqual(tuple(dboxes[1].shape), (4, 4))
self.assertTrue(dboxes[0].allclose(dboxes_output))
self.assertTrue(dboxes[1].allclose(dboxes_output))
torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8)
torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8)

0 comments on commit b243ede

Please sign in to comment.