diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 4b24c11059e4..174c7313a796 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -205,9 +205,11 @@ def test_global_runtime_device_attributes(self): results = pjrt.run_multiprocess(self._global_runtime_device_attributes) for result in results.values(): for device in result: - self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys())) + self.assertCountEqual(['coords', 'core_on_chip', 'name'], + list(device.keys())) self.assertIsInstance(device['coords'], list) self.assertIsInstance(device['core_on_chip'], int) + self.assertIsInstance(device['name'], str) @staticmethod def _execute_time_metric():