Move some util functions from quantization.utils to torchao.utils (#337)
Summary:

Moved
```
TORCH_VERSION_AFTER_2_(2/3/4)
get_model_size_in_bytes
unwrap_tensor_subclass
```

from `torchao/quantization/utils.py` to `torchao/utils.py`
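
For downstream code this is purely an import-path change. A minimal before/after sketch, assuming the three names are re-exported from `torchao.utils` exactly as listed:

```python
# Old location (before this commit):
#   from torchao.quantization.utils import (
#       TORCH_VERSION_AFTER_2_4,
#       get_model_size_in_bytes,
#       unwrap_tensor_subclass,
#   )

# New location (after this commit):
from torchao.utils import (
    TORCH_VERSION_AFTER_2_4,
    get_model_size_in_bytes,
    unwrap_tensor_subclass,
)
```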

Test Plan:
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Jun 7, 2024
1 parent 335171f commit 000a0fd
Showing 27 changed files with 129 additions and 125 deletions.
2 changes: 1 addition & 1 deletion test/dtypes/test_fp8.py
@@ -7,7 +7,7 @@
parametrize,
run_tests,
)
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4

try:
from torchao.prototype.fp8 import gemm_split_k, to_float8
2 changes: 1 addition & 1 deletion test/integration/test_integration.py
@@ -70,7 +70,7 @@
from parameterized import parameterized
import itertools
import logging
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

logger = logging.getLogger("INFO")

2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_custom_cast.py
@@ -44,7 +44,7 @@
)

from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
3 changes: 2 additions & 1 deletion test/prototype/mx_formats/test_mx_linear.py
@@ -19,7 +19,8 @@
swap_linear_with_mx_linear,
)

-from torchao.quantization.utils import compute_error, TORCH_VERSION_AFTER_2_4
+from torchao.quantization.utils import compute_error
+from torchao.utils import TORCH_VERSION_AFTER_2_4

# trying to outsmart flake8
__has_cuda = torch.cuda.is_available()
3 changes: 2 additions & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
@@ -23,7 +23,8 @@
to_dtype,
)

-from torchao.quantization.utils import compute_error, TORCH_VERSION_AFTER_2_4
+from torchao.quantization.utils import compute_error
+from torchao.utils import TORCH_VERSION_AFTER_2_4

# trying to outsmart flake8
__has_cuda = torch.cuda.is_available()
10 changes: 5 additions & 5 deletions test/prototype/test_bitpacking.py
@@ -2,7 +2,7 @@
from torchao.prototype.common.bitpacking import pack, unpack
import pytest
from torch.utils._triton import has_triton
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -20,15 +20,15 @@ def test_uint3_to_int16_col_wise_cpu():
unpacked = unpack(packed, 3, False, device='cpu')
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint4_to_uint8():
test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda()
packed = pack(test_tensor, 8, 4)
unpacked = unpack(packed, 4)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_uint4_to_uint8_compile():
@@ -40,7 +40,7 @@ def test_uint4_to_uint8_compile():
unpacked = unpack_compiled(packed, 4)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint3_to_int16():
test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda()
@@ -67,4 +67,4 @@ def test_uint3_to_int16_col_wise():
packed = pack(test_tensor,16, 3, False)
unpacked = unpack(packed, 3, False)
unpadded = unpacked[:test_tensor.shape[0], ...]
-assert(unpadded.allclose(test_tensor))
\ No newline at end of file
+assert(unpadded.allclose(test_tensor))
2 changes: 1 addition & 1 deletion test/quantization/model.py
@@ -10,7 +10,7 @@
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
-from torchao.quantization.utils import find_multiple
+from torchao.utils import find_multiple

def prepare_inputs_for_model(inps, max_new_tokens=1):
# this is because input from lm-eval is 2d
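`find_multiple` is the padding helper this model file relies on. A small usage sketch, assuming the usual gpt-fast semantics (smallest multiple of `k` that is `>= n`); the concrete values are illustrative only:

```python
from torchao.utils import find_multiple

# Assumed semantics: round n up to the nearest multiple of k.
assert find_multiple(32003, 256) == 32256  # 256 * 126
assert find_multiple(32000, 256) == 32000  # already a multiple: 256 * 125
```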
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -19,7 +19,7 @@
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
4 changes: 2 additions & 2 deletions test/quantization/test_quant_api.py
@@ -44,7 +44,7 @@
get_apply_int8wo_quant,
get_apply_int8dyn_quant,
)
-from torchao.quantization.utils import (
+from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
)
@@ -556,7 +556,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
self.assertTrue(torch.equal(res, ref))

# workaround for export path
-from torchao.quantization.utils import unwrap_tensor_subclass
+from torchao.utils import unwrap_tensor_subclass
m_unwrapped = unwrap_tensor_subclass(m)

m = torch.export.export(m_unwrapped, example_inputs).module()
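The second hunk above is the one call site where `unwrap_tensor_subclass` matters beyond its import path: tensor-subclass weights are unwrapped before `torch.export`. A minimal sketch of that workflow under the new location; the module and inputs below are stand-ins, not taken from the test:

```python
import torch
from torchao.utils import unwrap_tensor_subclass

# Stand-in module and example inputs; in the test, `m` is a quantized
# model whose weights are tensor subclasses.
m = torch.nn.Sequential(torch.nn.Linear(16, 16))
example_inputs = (torch.randn(1, 16),)

# Workaround for the export path: replace tensor-subclass parameters
# with plain tensors so torch.export can trace the module.
m_unwrapped = unwrap_tensor_subclass(m)
m_exported = torch.export.export(m_unwrapped, example_inputs).module()
```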
2 changes: 1 addition & 1 deletion test/quantization/test_quant_primitives.py
@@ -19,7 +19,7 @@
MappingType,
)

-from torchao.quantization.utils import (
+from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
)
4 changes: 2 additions & 2 deletions test/sparsity/test_fast_sparse_training.py
@@ -12,7 +12,7 @@
swap_semi_sparse_linear_with_linear,
SemiSparseLinear
)
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4

class TestModel(nn.Module):
def __init__(self):
@@ -42,7 +42,7 @@ def test_runtime_weight_sparsification(self):
if isinstance(mod, torch.nn.Linear):
sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense()
mod.weight = nn.Parameter(sparse)

dense_result = model(input)

# map from fqn to replacement linear module
2 changes: 1 addition & 1 deletion test/sparsity/test_sparse_api.py
@@ -11,7 +11,7 @@
_get_subclass_inserter,
_is_linear,
)
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3
from torch.testing._internal.common_utils import TestCase


2 changes: 1 addition & 1 deletion test/test_ops.py
@@ -2,7 +2,7 @@
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.optests import opcheck
import torchao
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4
import unittest
from parameterized import parameterized
import pytest
10 changes: 5 additions & 5 deletions torchao/_executorch_ops.py
@@ -9,7 +9,7 @@ def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs):
torch.ops.quantized_decomposed.quantize_per_channel_group is only available
in PyTorch 2.3+ and recently changed signatures.
"""
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later.")
@@ -23,7 +23,7 @@ def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **k
torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric is only available
in PyTorch 2.3+ and recently changed signatures.
"""
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later.")
@@ -37,7 +37,7 @@ def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs):
torch.ops.quantized_decomposed.dequantize_per_channel_group is only available
in PyTorch 2.3+ and recently changed signatures.
"""
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later.")
@@ -51,7 +51,7 @@ def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs):
torch.ops.quantized_decomposed.quantize_per_token is only available
in PyTorch 2.3+ and recently changed signatures.
"""
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later.")
@@ -65,7 +65,7 @@ def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs):
torch.ops.quantized_decomposed.dequantize_per_token is only available
in PyTorch 2.3+ and recently changed signatures.
"""
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later.")
2 changes: 1 addition & 1 deletion torchao/kernel/intmm.py
@@ -2,7 +2,7 @@
import os
import torch

-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_2
+from torchao.utils import TORCH_VERSION_AFTER_2_2

try:
# Only works for torch2.2 or newer.
2 changes: 1 addition & 1 deletion torchao/ops.py
@@ -1,6 +1,6 @@
import torch
from torch import Tensor
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4


def register_custom_op(name):
2 changes: 1 addition & 1 deletion torchao/prototype/mx_formats/custom_cast.py
@@ -11,7 +11,7 @@
import torch
from torch.utils._triton import has_triton

-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4

# TODO(future): if needed, make the below work on previous PyTorch versions,
# just need to hunt down the previous location of `libdevice`. An assert
4 changes: 3 additions & 1 deletion torchao/quantization/GPTQ.py
@@ -22,9 +22,11 @@
from .utils import (
_lm_eval_available,
_MultiInput,
-TORCH_VERSION_AFTER_2_3,
)
from torchao.utils import (
find_multiple,
)
+from torchao.utils import TORCH_VERSION_AFTER_2_3
from typing import Any, Dict, Optional
from .unified import Quantizer

1 change: 0 additions & 1 deletion torchao/quantization/__init__.py
@@ -44,7 +44,6 @@
"Int8WeightOnlyQuantizedLinearWeight",
"Int4WeightOnlyQuantizedLinearWeight",
"compute_error",
"get_model_size_in_bytes",
"WeightOnlyInt8QuantLinear",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
2 changes: 1 addition & 1 deletion torchao/quantization/autoquant.py
@@ -9,7 +9,7 @@
quantize_activation_per_token_absmax,
safe_int_mm,
)
-from .utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4
import torch.nn.functional as F
try:
from torch._inductor.utils import do_bench
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
@@ -25,7 +25,7 @@
from typing import Any, Callable

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
-from .utils import (
+from torchao.utils import (
TORCH_VERSION_AFTER_2_4,
unwrap_tensor_subclass,
)
2 changes: 1 addition & 1 deletion torchao/quantization/quant_primitives.py
@@ -14,7 +14,7 @@

from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm
-from .utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3


__all__ = [
2 changes: 1 addition & 1 deletion torchao/quantization/subclass.py
@@ -18,7 +18,7 @@
groupwise_affine_quantize_tensor_from_qparams,
MappingType,
)
-from .utils import find_multiple
+from torchao.utils import find_multiple
from typing import Tuple, Optional, Callable, Dict, Any

