diff --git a/test/dtypes/test_fp8.py b/test/dtypes/test_fp8.py index 811de3a4c3..ae008fc91e 100644 --- a/test/dtypes/test_fp8.py +++ b/test/dtypes/test_fp8.py @@ -7,7 +7,7 @@ parametrize, run_tests, ) -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4 try: from torchao.prototype.fp8 import gemm_split_k, to_float8 diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index c770f455fe..dc30e39b8f 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -70,7 +70,7 @@ from parameterized import parameterized import itertools import logging -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 logger = logging.getLogger("INFO") diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index 892d5b57f7..d247c70881 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -44,7 +44,7 @@ ) from torchao.prototype.mx_formats.mx_tensor import MXTensor -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4 if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 65f6002dbf..c453b0fe38 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -19,7 +19,8 @@ swap_linear_with_mx_linear, ) -from torchao.quantization.utils import compute_error, TORCH_VERSION_AFTER_2_4 +from torchao.quantization.utils import compute_error +from torchao.utils import TORCH_VERSION_AFTER_2_4 # trying to outsmart flake8 __has_cuda = torch.cuda.is_available() diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index f1b82e376b..a311f0f050 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -23,7 +23,8 @@ to_dtype, ) -from torchao.quantization.utils import compute_error, TORCH_VERSION_AFTER_2_4 +from torchao.quantization.utils import compute_error +from torchao.utils import TORCH_VERSION_AFTER_2_4 # trying to outsmart flake8 __has_cuda = torch.cuda.is_available() diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index c1b60e07f8..d1c1d261d1 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -2,7 +2,7 @@ from torchao.prototype.common.bitpacking import pack, unpack import pytest from torch.utils._triton import has_triton -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4 if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -20,7 +20,7 @@ def test_uint3_to_int16_col_wise_cpu(): unpacked = unpack(packed, 3, False, device='cpu') unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor)) - + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_uint4_to_uint8(): test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() @@ -28,7 +28,7 @@ def test_uint4_to_uint8(): unpacked = unpack(packed, 4) unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor)) - + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") def test_uint4_to_uint8_compile(): @@ -40,7 +40,7 @@ def test_uint4_to_uint8_compile(): unpacked = unpack_compiled(packed, 4) unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor)) - + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_uint3_to_int16(): test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() @@ -67,4 +67,4 @@ def test_uint3_to_int16_col_wise(): packed = pack(test_tensor,16, 3, False) unpacked = unpack(packed, 3, False) unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) \ No newline at end of file + assert(unpadded.allclose(test_tensor)) diff --git a/test/quantization/model.py b/test/quantization/model.py index e851901c41..94835fc0c3 100644 --- a/test/quantization/model.py +++ b/test/quantization/model.py @@ -10,7 +10,7 @@ import torch.nn as nn from torch import Tensor from torch.nn import functional as F -from torchao.quantization.utils import find_multiple +from torchao.utils import find_multiple def prepare_inputs_for_model(inps, max_new_tokens=1): # this is because input from lm-eval is 2d diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 93323df0f1..f5be66f50a 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -19,7 +19,7 @@ fake_quantize_per_token, ) from torchao.quantization.quant_primitives import get_group_qparams_symmetric -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4 # TODO: put this in a common test utils file diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 68df4f29fe..8a96124f1c 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -44,7 +44,7 @@ get_apply_int8wo_quant, get_apply_int8dyn_quant, ) -from torchao.quantization.utils import ( +from torchao.utils import ( TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, ) @@ -556,7 +556,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): self.assertTrue(torch.equal(res, ref)) # workaround for export path - from torchao.quantization.utils import unwrap_tensor_subclass + from torchao.utils import unwrap_tensor_subclass m_unwrapped = unwrap_tensor_subclass(m) m = torch.export.export(m_unwrapped, example_inputs).module() diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 6054c6e66f..8cecdf32ea 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -19,7 +19,7 @@ MappingType, ) -from torchao.quantization.utils import ( +from torchao.utils import ( TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, ) diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index b195534664..081f0e4d2f 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -12,7 +12,7 @@ swap_semi_sparse_linear_with_linear, SemiSparseLinear ) -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4 class TestModel(nn.Module): def __init__(self): @@ -42,7 +42,7 @@ def test_runtime_weight_sparsification(self): if isinstance(mod, torch.nn.Linear): sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense() mod.weight = nn.Parameter(sparse) - + dense_result = model(input) # map from fqn to replacement linear module diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 83c0544f6e..c7bc2700df 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -11,7 +11,7 @@ _get_subclass_inserter, _is_linear, ) -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import TORCH_VERSION_AFTER_2_3 from torch.testing._internal.common_utils import TestCase diff --git a/test/test_ops.py b/test/test_ops.py index b20e029380..cd833359eb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,7 +2,7 @@ from torch.testing._internal.common_utils import TestCase, IS_FBCODE from torch.testing._internal.optests import opcheck import torchao -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4 import unittest from parameterized import parameterized import pytest diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 33444b55fd..3ec2506ea6 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -9,7 +9,7 @@ def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): torch.ops.quantized_decomposed.quantize_per_channel_group is only available in PyTorch 2.3+ and recently changed signatures. """ - from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 + from torchao.utils import TORCH_VERSION_AFTER_2_3 if TORCH_VERSION_AFTER_2_3: return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs) raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later.") @@ -23,7 +23,7 @@ def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **k torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric is only available in PyTorch 2.3+ and recently changed signatures. """ - from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 + from torchao.utils import TORCH_VERSION_AFTER_2_3 if TORCH_VERSION_AFTER_2_3: return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(*args, **kwargs) raise ImportError("Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later.") @@ -37,7 +37,7 @@ def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): torch.ops.quantized_decomposed.dequantize_per_channel_group is only available in PyTorch 2.3+ and recently changed signatures. """ - from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 + from torchao.utils import TORCH_VERSION_AFTER_2_3 if TORCH_VERSION_AFTER_2_3: return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs) raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later.") @@ -51,7 +51,7 @@ def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): torch.ops.quantized_decomposed.quantize_per_token is only available in PyTorch 2.3+ and recently changed signatures. """ - from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 + from torchao.utils import TORCH_VERSION_AFTER_2_3 if TORCH_VERSION_AFTER_2_3: return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later.") @@ -65,7 +65,7 @@ def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): torch.ops.quantized_decomposed.dequantize_per_token is only available in PyTorch 2.3+ and recently changed signatures. """ - from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 + from torchao.utils import TORCH_VERSION_AFTER_2_3 if TORCH_VERSION_AFTER_2_3: return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later.") diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 8491a2ba6c..28827c543d 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -2,7 +2,7 @@ import os import torch -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_2 +from torchao.utils import TORCH_VERSION_AFTER_2_2 try: # Only works for torch2.2 or newer. diff --git a/torchao/ops.py b/torchao/ops.py index 7fce2de22f..51adb24100 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,6 +1,6 @@ import torch from torch import Tensor -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4 def register_custom_op(name): diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 60aaa336ba..91aea9275a 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -11,7 +11,7 @@ import torch from torch.utils._triton import has_triton -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4 # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 2e23767a84..6c3f41b834 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -22,9 +22,11 @@ from .utils import ( _lm_eval_available, _MultiInput, - TORCH_VERSION_AFTER_2_3, +) +from torchao.utils import ( find_multiple, ) +from torchao.utils import TORCH_VERSION_AFTER_2_3 from typing import Any, Dict, Optional from .unified import Quantizer diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index ab51dbb3a5..aa265daaf5 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -44,7 +44,6 @@ "Int8WeightOnlyQuantizedLinearWeight", "Int4WeightOnlyQuantizedLinearWeight", "compute_error", - "get_model_size_in_bytes", "WeightOnlyInt8QuantLinear", "Int4WeightOnlyGPTQQuantizer", "Int4WeightOnlyQuantizer", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 9eeb146f55..6ad60e042f 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -9,7 +9,7 @@ quantize_activation_per_token_absmax, safe_int_mm, ) -from .utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4 import torch.nn.functional as F try: from torch._inductor.utils import do_bench diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 84e8c151c2..510db85512 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -25,7 +25,7 @@ from typing import Any, Callable from .dynamic_quant import DynamicallyPerAxisQuantizedLinear -from .utils import ( +from torchao.utils import ( TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass, ) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index d1ad1e7403..5f5ba39d66 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -14,7 +14,7 @@ from torchao.kernel.intmm import int_scaled_matmul from torchao.kernel.intmm import safe_int_mm -from .utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import TORCH_VERSION_AFTER_2_3 __all__ = [ diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 972699f0bf..75c68cdf82 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -18,7 +18,7 @@ groupwise_affine_quantize_tensor_from_qparams, MappingType, ) -from .utils import find_multiple +from torchao.utils import find_multiple from typing import Tuple, Optional, Callable, Dict, Any diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 1a08f9901a..355be5045e 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -7,20 +7,10 @@ import torch from torch.utils._python_dispatch import TorchDispatchMode -from packaging import version -import torch.nn.utils.parametrize as parametrize -from torchao.utils import find_multiple - __all__ = [ - "find_multiple", "compute_error", "_apply_logging_hook", - "get_model_size_in_bytes", - "unwrap_tensor_subclass", - "TORCH_VERSION_AFTER_2_2", - "TORCH_VERSION_AFTER_2_3", - "TORCH_VERSION_AFTER_2_4", ] try: @@ -87,67 +77,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return rs - -class UnwrapTensorSubclass(torch.nn.Module): - def forward(self, *tensors): - todo = list(tensors) - for tp, meta, inner_tensors in reversed(self.rebuild_stack): - nb_tensor = len(inner_tensors) - inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])} - todo = todo[nb_tensor:] - rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None) - todo.append(rebuilt) - - assert len(todo) == 1 - return todo[0] - - def right_inverse(self, tensor): - assert type(tensor) is not torch.Tensor - rebuild_stack = [] - plain_tensors = [] - todo = [tensor] - while todo: - obj = todo.pop() - inner_tensors, metadata = obj.__tensor_flatten__() - rebuild_stack.append((type(obj), metadata, inner_tensors)) - for attr_name in inner_tensors: - val = getattr(obj, attr_name) - if type(val) is torch.Tensor: - plain_tensors.append(val) - else: - assert isinstance(val, torch.Tensor) - todo.append(val) - - self.rebuild_stack = rebuild_stack - - return plain_tensors - -def unwrap_tensor_subclass(model, filter_fn=None): - for name, child in model.named_children(): - # make sure child.weight is a tensor subclass - if ( - isinstance(child, torch.nn.Linear) and - hasattr(child, "weight") and - type(child.weight) is not torch.Tensor and - type(child.weight) is not torch.nn.Parameter and - isinstance(child.weight, torch.Tensor) and - issubclass(type(child.weight), torch.Tensor) - ): - parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass()) - unwrap_tensor_subclass(child) - return model - - -# https://discuss.pytorch.org/t/finding-model-size/130275 -def get_model_size_in_bytes(model): - s = 0 - for p in model.parameters(): - s += p.nelement() * p.element_size() - for b in model.buffers(): - s += b.nelement() * b.element_size() - return s - - class _MultiInput: def __init__(self, inputs): @@ -165,20 +94,3 @@ def cuda(self): self.values = [ val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values ] - - -# TODO: quantization namespace is not the right place ot have this -if version.parse(torch.__version__) >= version.parse("2.4.0.dev"): - TORCH_VERSION_AFTER_2_4 = True -else: - TORCH_VERSION_AFTER_2_4 = False - -if version.parse(torch.__version__) >= version.parse("2.3.0.dev"): - TORCH_VERSION_AFTER_2_3 = True -else: - TORCH_VERSION_AFTER_2_3 = False - -if version.parse(torch.__version__) >= version.parse("2.2.0.dev"): - TORCH_VERSION_AFTER_2_2 = True -else: - TORCH_VERSION_AFTER_2_2 = False diff --git a/torchao/sparsity/training/__init__.py b/torchao/sparsity/training/__init__.py index 16035fe62b..044f6d7515 100644 --- a/torchao/sparsity/training/__init__.py +++ b/torchao/sparsity/training/__init__.py @@ -7,7 +7,7 @@ from torchao.sparsity.training.autograd import semi_structured_sparsify from torchao.sparsity.training.pointwise_ops import CUTLASS_POINTWISE_OP_DISPATCH_TABLE -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import TORCH_VERSION_AFTER_2_3 # load pointwise op support, which exists only for CUTLASS if TORCH_VERSION_AFTER_2_3: diff --git a/torchao/sparsity/training/autograd.py b/torchao/sparsity/training/autograd.py index 8e22cad9fb..e920b72859 100644 --- a/torchao/sparsity/training/autograd.py +++ b/torchao/sparsity/training/autograd.py @@ -2,7 +2,7 @@ import torch from torch.sparse import SparseSemiStructuredTensor -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import TORCH_VERSION_AFTER_2_3 if TORCH_VERSION_AFTER_2_3: from torch.sparse import SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT @@ -120,7 +120,7 @@ def semi_structured_sparsify( backend: str = "cutlass", ) -> SparseSemiStructuredTensor: """ - Sparsifies a dense tensor into a semi-structured tensor, according to the algo and backend passed. + Sparsifies a dense tensor into a semi-structured tensor, according to the algo and backend passed. """ return _SparsifyFunc.apply(x, algo, backend) @@ -131,6 +131,6 @@ def semi_structured_sparsify_like( gradient: GRADIENT_TYPE = GRADIENT_TYPE.SPARSE, ) -> SparseSemiStructuredTensor: """ - Sparsifies a dense tensor into a semi-structured tensor, using the mask of the provided pattern. + Sparsifies a dense tensor into a semi-structured tensor, using the mask of the provided pattern. """ return _SparsifyLikeFunc.apply(x, pattern, gradient) diff --git a/torchao/utils.py b/torchao/utils.py index 0a3fe5ba97..27650dae1c 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -3,6 +3,22 @@ from typing import Tuple from functools import reduce from math import gcd +from packaging import version +import torch.nn.utils.parametrize as parametrize + +__all__ = [ + "benchmark_model", + "profiler_runner", + "get_compute_capability", + "skip_if_compute_capability_less_than", + "benchmark_torch_function_in_microseconds", + "find_multiple", + "get_model_size_in_bytes", + "unwrap_tensor_subclass", + "TORCH_VERSION_AFTER_2_2", + "TORCH_VERSION_AFTER_2_3", + "TORCH_VERSION_AFTER_2_4", +] def benchmark_model(model, num_runs, input_tensor): @@ -65,3 +81,76 @@ def find_multiple(n: int, *args: Tuple[int]) -> int: if n % k == 0: return n return n + k - (n % k) + +# https://discuss.pytorch.org/t/finding-model-size/130275 +def get_model_size_in_bytes(model): + s = 0 + for p in model.parameters(): + s += p.nelement() * p.element_size() + for b in model.buffers(): + s += b.nelement() * b.element_size() + return s + +class UnwrapTensorSubclass(torch.nn.Module): + def forward(self, *tensors): + todo = list(tensors) + for tp, meta, inner_tensors in reversed(self.rebuild_stack): + nb_tensor = len(inner_tensors) + inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])} + todo = todo[nb_tensor:] + rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None) + todo.append(rebuilt) + + assert len(todo) == 1 + return todo[0] + + def right_inverse(self, tensor): + assert type(tensor) is not torch.Tensor + rebuild_stack = [] + plain_tensors = [] + todo = [tensor] + while todo: + obj = todo.pop() + inner_tensors, metadata = obj.__tensor_flatten__() + rebuild_stack.append((type(obj), metadata, inner_tensors)) + for attr_name in inner_tensors: + val = getattr(obj, attr_name) + if type(val) is torch.Tensor: + plain_tensors.append(val) + else: + assert isinstance(val, torch.Tensor) + todo.append(val) + + self.rebuild_stack = rebuild_stack + + return plain_tensors + +def unwrap_tensor_subclass(model, filter_fn=None): + for name, child in model.named_children(): + # make sure child.weight is a tensor subclass + if ( + isinstance(child, torch.nn.Linear) and + hasattr(child, "weight") and + type(child.weight) is not torch.Tensor and + type(child.weight) is not torch.nn.Parameter and + isinstance(child.weight, torch.Tensor) and + issubclass(type(child.weight), torch.Tensor) + ): + parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass()) + unwrap_tensor_subclass(child) + return model + +if version.parse(torch.__version__) >= version.parse("2.4.0.dev"): + TORCH_VERSION_AFTER_2_4 = True +else: + TORCH_VERSION_AFTER_2_4 = False + +if version.parse(torch.__version__) >= version.parse("2.3.0.dev"): + TORCH_VERSION_AFTER_2_3 = True +else: + TORCH_VERSION_AFTER_2_3 = False + +if version.parse(torch.__version__) >= version.parse("2.2.0.dev"): + TORCH_VERSION_AFTER_2_2 = True +else: + TORCH_VERSION_AFTER_2_2 = False