Skip to content

Commit

Permalink
Fixing cuda device check (pytorch#536)
Browse files Browse the repository at this point in the history
Summary:
The previous CUDA device check was not general enough; this adds a better check that also
works for indexed device strings such as "cuda:0".

Test Plan:
python test/quantization/test_quant_api.py -k test_int4wo_quantized_model_to_device

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored Jul 26, 2024
1 parent 2a01aaa commit 8442f03
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
24 changes: 13 additions & 11 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,17 +642,19 @@ def test_int8wo_quantized_model_to_device(self):
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
def test_int4wo_quantized_model_to_device(self):
# TODO: change initial model to "cpu"
m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

quantize_(m, int4_weight_only())
ref = m(*example_inputs)

example_inputs_cuda = (example_inputs[0].to("cuda"),)
m.to(device="cuda")
cuda_res = m(*example_inputs_cuda)
self.assertEqual(cuda_res.cpu(), ref)
devices = ["cuda", "cuda:0"]
for device in devices:
m = ToyLinearModel().eval().to(torch.bfloat16).to(device)
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)

quantize_(m, int4_weight_only())
ref = m(*example_inputs)

example_inputs_cuda = (example_inputs[0].to(device),)
m.to(device=device)
cuda_res = m(*example_inputs_cuda)
self.assertEqual(cuda_res.cpu(), ref)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Expand Down
3 changes: 2 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_register_layout_cls,
_get_layout_tensor_constructor,
LayoutType,
is_device,
)
from typing import ClassVar
from dataclasses import dataclass
Expand Down Expand Up @@ -544,7 +545,7 @@ def from_plain(
def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
device = kwargs["device"]
if device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda"):
if not is_device("cuda", device):
raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}")
return self.__class__(
self.packed_weight.to(device),
Expand Down
5 changes: 4 additions & 1 deletion torchao/dtypes/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from typing import Dict, Callable
from typing import Dict, Callable, Union
from collections import defaultdict
import functools
from dataclasses import dataclass
Expand Down Expand Up @@ -89,3 +89,6 @@ def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(Layout
raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}")

return _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class]

def is_device(target_device_str: str, device: Union[str, torch.device]) -> bool:
    """Return whether ``device`` is of the device type ``target_device_str``.

    ``device`` is normalized through ``torch.device`` first, so a plain string
    ("cuda"), an indexed string ("cuda:0") and a ``torch.device`` instance are
    all handled the same way. Only the device *type* is compared; any device
    index is ignored.

    Args:
        target_device_str: Device type to test for, e.g. "cuda" or "cpu".
        device: The device to check, as a string or a ``torch.device``.

    Returns:
        True if the type of ``device`` equals ``target_device_str``.
    """
    # torch.device(...) accepts both strings and torch.device objects, which
    # is what lets "cuda:0" match the "cuda" target.
    return torch.device(device).type == target_device_str

0 comments on commit 8442f03

Please sign in to comment.