From 8442f0310e3d36ea49b17d9440cbd23442b46317 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 25 Jul 2024 17:43:06 -0700 Subject: [PATCH] Fixing cuda device check (#536) Summary: The previous cuda device check was not general enough; this adds a better check that works for more cases, such as "cuda:0". Test Plan: python test/quantization/test_quant_api.py -k test_int4wo_quantized_model_to_device Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_quant_api.py | 24 ++++++++++++----------- torchao/dtypes/affine_quantized_tensor.py | 3 ++- torchao/dtypes/utils.py | 5 ++++- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index ab24fc981c..c19dc2660b 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -642,17 +642,19 @@ def test_int8wo_quantized_model_to_device(self): @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+") def test_int4wo_quantized_model_to_device(self): # TODO: change initial model to "cpu" - m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda") - m_copy = copy.deepcopy(m) - example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") - - quantize_(m, int4_weight_only()) - ref = m(*example_inputs) - - example_inputs_cuda = (example_inputs[0].to("cuda"),) - m.to(device="cuda") - cuda_res = m(*example_inputs_cuda) - self.assertEqual(cuda_res.cpu(), ref) + devices = ["cuda", "cuda:0"] + for device in devices: + m = ToyLinearModel().eval().to(torch.bfloat16).to(device) + m_copy = copy.deepcopy(m) + example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) + + quantize_(m, int4_weight_only()) + ref = m(*example_inputs) + + example_inputs_cuda = (example_inputs[0].to(device),) + m.to(device=device) + cuda_res = m(*example_inputs_cuda) + self.assertEqual(cuda_res.cpu(), ref) @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not 
torch.cuda.is_available(), "Need CUDA available") diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index da5cc7d28b..4142937905 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -21,6 +21,7 @@ _register_layout_cls, _get_layout_tensor_constructor, LayoutType, + is_device, ) from typing import ClassVar from dataclasses import dataclass @@ -544,7 +545,7 @@ def from_plain( def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs["device"] - if device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda"): + if not is_device("cuda", device): raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}") return self.__class__( self.packed_weight.to(device), diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 3a437b4745..656c4873ab 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,5 +1,5 @@ import torch -from typing import Dict, Callable +from typing import Dict, Callable, Union from collections import defaultdict import functools from dataclasses import dataclass @@ -89,3 +89,6 @@ def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(Layout raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}") return _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class] + +def is_device(target_device_str: str, device: Union[str, torch.device]): + return torch.device(device).type == target_device_str