diff --git a/py/trtorch/Device.py b/py/trtorch/Device.py index 58e6706d4d..3e408bd951 100644 --- a/py/trtorch/Device.py +++ b/py/trtorch/Device.py @@ -1,7 +1,7 @@ import torch from trtorch import _types -import logging +import trtorch.logging import trtorch._C import warnings @@ -54,11 +54,11 @@ def __init__(self, *args, **kwargs): else: self.dla_core = id self.gpu_id = 0 - logging.log(logging.log.Level.Warning, - "Setting GPU id to 0 for device because device 0 manages DLA on Xavier") + trtorch.logging.log(trtorch.logging.Level.Warning, + "Setting GPU id to 0 for device because device 0 manages DLA on Xavier") elif len(args) == 0: - if not "gpu_id" in kwargs or not "dla_core" in kwargs: + if "gpu_id" in kwargs or "dla_core" in kwargs: if "dla_core" in kwargs: self.device_type = _types.DeviceType.DLA self.dla_core = kwargs["dla_core"] @@ -66,11 +66,15 @@ def __init__(self, *args, **kwargs): self.gpu_id = kwargs["gpu_id"] else: self.gpu_id = 0 - logging.log(logging.log.Level.Warning, - "Setting GPU id to 0 for device because device 0 manages DLA on Xavier") + trtorch.logging.log(trtorch.logging.Level.Warning, + "Setting GPU id to 0 for device because device 0 manages DLA on Xavier") else: self.gpu_id = kwargs["gpu_id"] - self.device_type == _types.DeviceType.GPU + self.device_type = _types.DeviceType.GPU + else: + raise ValueError( + "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg" + ) else: raise ValueError( @@ -80,6 +84,7 @@ def __init__(self, *args, **kwargs): if "allow_gpu_fallback" in kwargs: if not isinstance(kwargs["allow_gpu_fallback"], bool): raise TypeError("allow_gpu_fallback must be a bool") + self.allow_gpu_fallback = kwargs["allow_gpu_fallback"] def __str__(self) -> str: return "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) \ diff --git a/tests/py/test_api.py b/tests/py/test_api.py index 94239b1475..da98ceae29 100644 --- a/tests/py/test_api.py +++ b/tests/py/test_api.py @@ -230,6 +230,53 @@ def test_is_colored_output_on(self): self.assertTrue(color) +class TestDevice(unittest.TestCase): + + def test_from_string_constructor(self): + device = trtorch.Device("cuda:0") + self.assertEqual(device.device_type, trtorch.DeviceType.GPU) + self.assertEqual(device.gpu_id, 0) + + device = trtorch.Device("gpu:1") + self.assertEqual(device.device_type, trtorch.DeviceType.GPU) + self.assertEqual(device.gpu_id, 1) + + def test_from_string_constructor_dla(self): + device = trtorch.Device("dla:0") + self.assertEqual(device.device_type, trtorch.DeviceType.DLA) + self.assertEqual(device.gpu_id, 0) + self.assertEqual(device.dla_core, 0) + + device = trtorch.Device("dla:1", allow_gpu_fallback=True) + self.assertEqual(device.device_type, trtorch.DeviceType.DLA) + self.assertEqual(device.gpu_id, 0) + self.assertEqual(device.dla_core, 1) + self.assertEqual(device.allow_gpu_fallback, True) + + def test_kwargs_gpu(self): + device = trtorch.Device(gpu_id=0) + self.assertEqual(device.device_type, trtorch.DeviceType.GPU) + self.assertEqual(device.gpu_id, 0) + + def test_kwargs_dla_and_settings(self): + device = trtorch.Device(dla_core=1, allow_gpu_fallback=False) + self.assertEqual(device.device_type, trtorch.DeviceType.DLA) + self.assertEqual(device.gpu_id, 0) + self.assertEqual(device.dla_core, 1) + self.assertEqual(device.allow_gpu_fallback, False) + + device = trtorch.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True) + self.assertEqual(device.device_type, trtorch.DeviceType.DLA) + self.assertEqual(device.gpu_id, 1) + self.assertEqual(device.dla_core, 0) + self.assertEqual(device.allow_gpu_fallback, True) + + def test_from_torch(self): + device = trtorch.Device._from_torch_device(torch.device("cuda:0")) + self.assertEqual(device.device_type, trtorch.DeviceType.GPU) + self.assertEqual(device.gpu_id, 0) + + def test_suite(): suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(TestLoggingAPIs)) @@ -242,6 +289,7 @@ def test_suite(): suite.addTest( TestModuleFallbackToTorch.parametrize(TestModuleFallbackToTorch, model=models.resnet18(pretrained=True))) suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport)) + suite.addTest(unittest.makeSuite(TestDevice)) return suite