Skip to content

Commit

Permalink
Port test_onnx.py to pytest (#4047)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang authored Jun 11, 2021
1 parent 552a406 commit fb2598b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ jobs:
pip install --user --progress-bar off --editable .
pip install --user onnx
pip install --user onnxruntime
pip install --user pytest
python test/test_onnx.py
binary_linux_wheel:
Expand Down
1 change: 1 addition & 0 deletions .circleci/config.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ jobs:
pip install --user --progress-bar off --editable .
pip install --user onnx
pip install --user onnxruntime
pip install --user pytest
python test/test_onnx.py

binary_linux_wheel:
Expand Down
16 changes: 7 additions & 9 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,19 @@
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.mask_rcnn import MaskRCNNHeads, MaskRCNNPredictor

from collections import OrderedDict

import unittest
import pytest
from torchvision.ops._register_onnx_ops import _onnx_opset_version


@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
class ONNXExporterTester(unittest.TestCase):
@pytest.mark.skipif(onnxruntime is None, reason='ONNX Runtime unavailable')
class TestONNXExporter:
@classmethod
def setUpClass(cls):
def setup_class(cls):
torch.manual_seed(123)

def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
Expand Down Expand Up @@ -80,7 +78,7 @@ def to_numpy(tensor):
torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
except AssertionError as error:
if tolerate_small_mismatch:
self.assertIn("(0.00%)", str(error), str(error))
assert "(0.00%)" in str(error), str(error)
else:
raise

Expand Down Expand Up @@ -161,7 +159,7 @@ def test_roi_align_aligned(self):
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
self.run_model(model, [(x, single_roi)])

@unittest.skip # Issue in exporting ROIAlign with aligned = True for malformed boxes
@pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
def test_roi_align_malformed_boxes(self):
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
Expand Down Expand Up @@ -527,4 +525,4 @@ def test_shufflenet_v2_dynamic_axes(self):


if __name__ == '__main__':
unittest.main()
pytest.main([__file__])

0 comments on commit fb2598b

Please sign in to comment.