diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 5503520a37..1cdf2708a0 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -1,10 +1,18 @@ +import copy import logging import unittest from packaging import version import math - +import pytest import torch from torch import nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, + CheckpointWrapper, +) +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( TestCase, instantiate_parametrized_tests, @@ -431,6 +439,170 @@ def test_to_cpu(self): inner_tensor = getattr(nf4_tensor, attr) self.assertEqual(inner_tensor.device.type, "cpu") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_deepcopy(self, input_size: Union[Tuple[int], int]): + nf4_orig = to_nf4(torch.randn(input_size, device="cuda")) + nf4_clone = copy.deepcopy(nf4_orig) + self.assertEqual( + nf4_clone.get_original_weight(), nf4_orig.get_original_weight() + ) + + +class LoRALinear(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + weight: torch.Tensor, + rank: int, + alpha: float, + dropout: float = 0.0, + ): + super().__init__() + self.in_dim = in_dim + self.rank = rank + self.alpha = alpha + self.out_dim = out_dim + self.register_parameter("weight", nn.Parameter(to_nf4(weight))) + self.dropout = nn.Dropout(p=dropout) + self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) + self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_b.weight, a=math.sqrt(5)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = linear_nf4(input=x, weight=self.weight) + lora_out = self.lora_a(self.dropout(x)) + lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) + return out + lora_out + + +class TestQLoRA(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @pytest.mark.skipif( + version.parse(torch.__version__).base_version < "2.4.0", + reason="torch >= 2.4 required", + ) + @skip_if_lt_x_gpu(2) + def test_qlora_fsdp2(self): + from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy + + self.run_subtests( + { + "enable_activation_checkpointing": [False, True], + "offload_policy": [ + OffloadPolicy(), + CPUOffloadPolicy(pin_memory=True), + CPUOffloadPolicy(pin_memory=False), + ], + }, + self._test_qlora_fsdp2, + ) + + def _test_qlora_fsdp2( + self, + enable_activation_checkpointing: bool, + offload_policy: "OffloadPolicy", + ): + from torch.distributed._composable.fsdp import fully_shard + from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, + TransformerBlock, + ) + + batch_size = 3 + lora_r = 8 + lora_alpha = 16 + vocab_size = 1024 + seq_len = 64 + model_args = ModelArgs( + n_layers=3, + n_heads=4, + dim=1024, + vocab_size=vocab_size, + max_seq_len=seq_len, + dropout_p=0, + ) + torch.manual_seed(42) + with torch.device("cuda"): + base_model = Transformer(model_args) + for layer in base_model.layers: + # attention with lora adapters + for attr in ["wq", "wk", "wv", "wo"]: + orig_linear = getattr(layer.attention, attr) + setattr( + layer.attention, + attr, + LoRALinear( + orig_linear.weight.shape[1], + orig_linear.weight.shape[0], + orig_linear.weight, + lora_r, + lora_alpha, + ), + ) + for attr in ["w1", "w2"]: + orig_linear = getattr(layer.feed_forward, attr) + setattr( + layer.feed_forward, + attr, + LoRALinear( + orig_linear.weight.shape[1], + orig_linear.weight.shape[0], + orig_linear.weight, + lora_r, + lora_alpha, + ), + ) + for name, param in base_model.named_parameters(): + param.requires_grad_( + name.endswith("lora_a.weight") or name.endswith("lora_b.weight") + ) + if enable_activation_checkpointing: + apply_activation_checkpointing( + base_model, auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}) + ) + base_optim = torch.optim.AdamW(base_model.parameters(), lr=1e-2) + + fsdp_kwargs = {"offload_policy": offload_policy} + fsdp_model = copy.deepcopy(base_model) + for m in fsdp_model.modules(): + if enable_activation_checkpointing: + if isinstance(m, CheckpointWrapper): + fully_shard(m, **fsdp_kwargs) + else: + if isinstance(m, TransformerBlock): + fully_shard(m, **fsdp_kwargs) + fully_shard(fsdp_model, **fsdp_kwargs) + fsdp_optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2) + + torch.manual_seed(42 + self.rank + 1) + for iter_idx in range(5): + inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + fsdp_loss = fsdp_model(inp).sum() + fsdp_loss.backward() + fsdp_optim.step() + + base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + base_loss = base_model(inp).sum() + base_loss.backward() + for param in base_model.parameters(): + if param.grad is not None: + torch.distributed.all_reduce( + param.grad, op=torch.distributed.ReduceOp.AVG + ) + base_optim.step() + self.assertEqual(fsdp_loss, base_loss) + instantiate_parametrized_tests(TestNF4Linear) instantiate_parametrized_tests(TestFSDPOps) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index f05599f6ef..df2b1f08d9 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -11,10 +11,6 @@ from torch import Tensor from torch.distributed.device_mesh import DeviceMesh from torch._prims_common import make_contiguous_strides_for -from torchao.dtypes.utils import ( - _implements, - _ATEN_OP_OR_TORCH_FN_TABLE, -) aten = torch.ops.aten @@ -23,6 +19,9 @@ from typing import Any, Optional, Tuple, Union, List +NF4_OPS_TABLE: Dict[Any, Any] = {} + + _INNER_TENSOR_NAMES_FOR_SHARDING = ["quantized_scalers", "quantization_factor", "quantized_data"] # Note: Quantize in Chunks @@ -44,6 +43,17 @@ def same_metadata(a: "NF4Tensor", b: "NF4Tensor"): ) +def implements(aten_ops): + """Use this decorator to implement a function for an aten op in __torch_dispatch__""" + + def decorator(func): + for op in aten_ops: + NF4_OPS_TABLE[op] = func + return func + + return decorator + + def construct_nf4_args(nf4tensor: "NF4Tensor", kwargs: Optional[Dict[str, Any]] = None): if kwargs is None: kwargs = {} @@ -121,6 +131,251 @@ def wrapper(aten_op, args, kwargs=None): return decorator +@implements([torch.ops.aten.detach]) +def noop_detach(func, *args, **kwargs): + return args[0][0] + +@implements([torch.ops.aten.clone.default]) +def clone(func, *args, **kwargs): + return to_nf4(args[0][0].get_original_weight()) + +@implements( + [ + aten.detach.default, + ] +) +def nf4_detach(aten_op, args, kwargs=None): + nf4tensor = args[0] + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + + +@implements( + [ + aten.split.Tensor, + ] +) +def nf4_split(aten_op, args, kwargs=None): + if len(args) == 3 and args[2] != 0: + raise NotImplementedError(f"aten.split(NF4Tensor, dim={args[2]})") + nf4tensor = args[0] + num_chunks = nf4tensor.size(0) // args[1] + + attr_to_chunks = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + assert inner_tensor.numel() % num_chunks == 0, f"{attr}.numel() not divisible by {num_chunks}" + chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs) + attr_to_chunks[attr] = chunks + + orig_dim = nf4tensor.dim() + if orig_dim == 1: + chunked_size = (nf4tensor.size(0) // num_chunks, ) + elif orig_dim == 2: + chunked_size = (nf4tensor.size(0) // num_chunks, nf4tensor.size(1)) + else: + chunked_size = () + raise NotImplementedError(f"aten.split(NF4Tensor) wherer NF4Tensor.dim() = {orig_dim}") + + nf4_chunks = [] + for idx in range(num_chunks): + updated_attrs = { + "size": chunked_size + } + for attr, chunks in attr_to_chunks.items(): + updated_attrs[attr] = chunks[idx] + nf4_chunks.append(NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))) + return nf4_chunks + +@implements( + [ + aten.new_zeros.default, + ] +) +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=") +def nf4_new_zeros(aten_op, args, kwargs=None): + nf4tensor = args[0] + new_size = tuple(args[1]) + new_size_dim = len(new_size) + if nf4tensor.numel() % math.prod(new_size) != 0: + raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}") + ratio = nf4tensor.numel() // math.prod(new_size) + + updated_attrs = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + assert inner_tensor.size(0) % ratio == 0, f"{attr}.numel() must be divisible by {ratio}" + inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs) + updated_attrs[attr] = inner_tensor + updated_attrs["size"] = new_size + + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + +@implements( + [ + aten.slice.Tensor, + ] +) +@expect_num_of_args(CompareOp.LT, 5, "aten.slice(NF4Tensor) with customized step") +@expect_arg_value_at_k(1, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with dim=") +@expect_arg_value_at_k(2, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with start=") +def nf4_slice(aten_op, args, kwargs=None): + nf4tensor = args[0] + # for tensor 512 x 512, tensor[:, :512] dispatch to + # aten.slice(dim = 0, end=sys.maxsize) + if not args[3] in [nf4tensor.size(0), sys.maxsize]: + raise NotImplementedError(f"aten.slice(NF4Tensor) with end={args[3]}") + return NF4Tensor(*construct_nf4_args(nf4tensor)) + +@implements( + [ + aten.view.default, + ] +) +@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=") +def nf4_view(aten_op, args, kwargs=None): + nf4tensor = args[0] + size = args[1] + if size[0] != -1: + raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + updated_attrs.update({ + "size": [nf4tensor.numel()], + "stride": (1, ), + }) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + +@implements( + [ + aten.as_strided.default, + ] +) +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=") +def nf4_as_strided(aten_op, args, kwargs=None): + nf4tensor = args[0] + size = args[1] + stride = tuple(args[2]) + storage_offset = args[3] + if math.prod(size) != nf4tensor.numel(): + raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}") + if stride != make_contiguous_strides_for(size): + raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support continuous stride={make_contiguous_strides_for(size)} but got stride={stride}") + if nf4tensor.storage_offset() != storage_offset: + raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support original storage offset {nf4tensor.storage_offset()} but got {storage_offset}") + kwargs = { + "size": torch.Size(size), + "stride": stride, + "storage_offset": storage_offset, + } + return NF4Tensor(*construct_nf4_args(nf4tensor, kwargs)) + + +@implements([torch.ops.aten._to_copy.default]) +def _to_copy(func, *args, **kwargs): + if not args[0][0].is_contiguous(): + assert args[0][0].t().is_contiguous() + return func(args[0][0].t()).t() + out = args[0][0].get_original_weight().to(args[1]["dtype"]) + if "device" in args[1]: + out = out.to(args[1]["device"]) + return out + + +@implements([torch.ops.aten.to.dtype]) +def to_dtype(func, *args, **kwargs): + if not args[0][0].is_contiguous(): + assert args[0][0].t().is_contiguous() + return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() + return args[0][0].get_original_weight().to(args[0][1]) + + +@implements([torch.ops.aten.t.default]) +def t_default(func, *args, **kwargs): + a = args[0][0] + tensor_meta = SubclassTensorArgs( + a.size(), + (a.stride(1), a.stride(0)), + a.storage_offset(), + a.dtype, + a.device, + a.requires_grad, + ) + b = NF4Tensor( + tensor_meta, + a.block_size, + a.n_blocks, + a.scaler_block_size, + a.quantized_scalers, + a.quantization_factor, + a.scaler_mean, + a.quantized_data, + a.nf4, + ) + return b + + +@implements([torch.ops.aten.mm.default]) +def mm_default(func, *args, **kwargs): + return linear_nf4(args[0][0], args[0][1]) + + +@implements( + [ + aten.copy_.default, + ] +) +def copy_(func, *args, **kwargs): + original: NF4Tensor = args[0][0] + copy_in: torch.Tensor = args[0][1] + + # Base Case + + if same_metadata(original, copy_in): + original_tensors = original.__tensor_flatten__()[0] + for tensor_name in original_tensors: + getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name)) + return + + # Convert Non NF4Tensor into NF4 for copy in + if not isinstance(copy_in, NF4Tensor): + copy_in_nf4 = NF4Tensor.from_tensor( + copy_in, original.block_size, original.scaler_block_size + ) + return original.copy_(copy_in_nf4) + + # Other Tensor is not a NF4Tensor + full_precision = copy_in.get_original_weight() + same_meta_nf4 = NF4Tensor.from_tensor( + full_precision, original.block_size, original.scaler_block_size + ) + return original.copy_(same_meta_nf4) + + +@implements( + [ + aten.is_pinned.default, + ] +) +def nf4_is_pinned(aten_op, args, kwargs=None): + nf4tensor = args[0] + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + if not aten_op(inner_tensor, *(args[1:]), **kwargs): + return False + return True + + +@implements( + [ + aten._pin_memory.default, + ] +) +def nf4_pin_memory(aten_op, args, kwargs=None): + nf4tensor = args[0] + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + + @dataclass class SubclassTensorArgs: original_shape: torch.Size @@ -507,7 +762,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): """TODO we are not supporting torch dispatch at the moment instead we have created a Autograd.Function to handle the linear """ - # All ops in the _ATEN_OP_OR_TORCH_FN_TABLE expect NF4 Tensors as inputs + # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs # And don't support mixed tensor subclasses. This will trigger the handler for # the next type in the dispatch list @@ -523,8 +778,8 @@ def allowed_subclasses(type): if not all(allowed_subclasses(t) for t in types): return NotImplemented("Up to the next one to handle") - if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: - return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, args, kwargs) + if func in NF4_OPS_TABLE: + return NF4_OPS_TABLE[func](func, args, kwargs) raise NotImplementedError( f"NF4Tensor dispatch: attempting to run {func}, this is not supported" ) @@ -537,8 +792,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} try: - if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: - return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs) + if func in NF4_TORCH_FUNCTIONS: + return NF4_TORCH_FUNCTIONS[func](*args, **kwargs) except NotImplementedError: pass @@ -634,251 +889,20 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256): return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size) -def implements(aten_ops_or_torch_fn): - return _implements(NF4Tensor, aten_ops_or_torch_fn) - -@implements([torch.ops.aten.detach]) -def noop_detach(func, *args, **kwargs): - return args[0][0] - - -@implements( - [ - aten.detach.default, - ] -) -def nf4_detach(aten_op, args, kwargs=None): - nf4tensor = args[0] - updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) - return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) - - -@implements( - [ - aten.split.Tensor, - ] -) -def nf4_split(aten_op, args, kwargs=None): - if len(args) == 3 and args[2] != 0: - raise NotImplementedError(f"aten.split(NF4Tensor, dim={args[2]})") - nf4tensor = args[0] - num_chunks = nf4tensor.size(0) // args[1] - - attr_to_chunks = {} - for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: - inner_tensor = getattr(nf4tensor, attr) - assert inner_tensor.numel() % num_chunks == 0, f"{attr}.numel() not divisible by {num_chunks}" - chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs) - attr_to_chunks[attr] = chunks - - orig_dim = nf4tensor.dim() - if orig_dim == 1: - chunked_size = (nf4tensor.size(0) // num_chunks, ) - elif orig_dim == 2: - chunked_size = (nf4tensor.size(0) // num_chunks, nf4tensor.size(1)) - else: - chunked_size = () - raise NotImplementedError(f"aten.split(NF4Tensor) wherer NF4Tensor.dim() = {orig_dim}") - - nf4_chunks = [] - for idx in range(num_chunks): - updated_attrs = { - "size": chunked_size - } - for attr, chunks in attr_to_chunks.items(): - updated_attrs[attr] = chunks[idx] - nf4_chunks.append(NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))) - return nf4_chunks - -@implements( - [ - aten.new_zeros.default, - ] -) -@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=") -def nf4_new_zeros(aten_op, args, kwargs=None): - nf4tensor = args[0] - new_size = tuple(args[1]) - new_size_dim = len(new_size) - if nf4tensor.numel() % math.prod(new_size) != 0: - raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}") - ratio = nf4tensor.numel() // math.prod(new_size) - - updated_attrs = {} - for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: - inner_tensor = getattr(nf4tensor, attr) - assert inner_tensor.size(0) % ratio == 0, f"{attr}.numel() must be divisible by {ratio}" - inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs) - updated_attrs[attr] = inner_tensor - updated_attrs["size"] = new_size - - return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) - -@implements( - [ - aten.slice.Tensor, - ] -) -@expect_num_of_args(CompareOp.LT, 5, "aten.slice(NF4Tensor) with customized step") -@expect_arg_value_at_k(1, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with dim=") -@expect_arg_value_at_k(2, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with start=") -def nf4_slice(aten_op, args, kwargs=None): - nf4tensor = args[0] - # for tensor 512 x 512, tensor[:, :512] dispatch to - # aten.slice(dim = 0, end=sys.maxsize) - if not args[3] in [nf4tensor.size(0), sys.maxsize]: - raise NotImplementedError(f"aten.slice(NF4Tensor) with end={args[3]}") - return NF4Tensor(*construct_nf4_args(nf4tensor)) - -@implements( - [ - aten.view.default, - ] -) -@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=") -def nf4_view(aten_op, args, kwargs=None): - nf4tensor = args[0] - size = args[1] - if size[0] != -1: - raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") - updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) - updated_attrs.update({ - "size": [nf4tensor.numel()], - "stride": (1, ), - }) - return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) - -@implements( - [ - aten.as_strided.default, - ] -) -@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=") -def nf4_as_strided(aten_op, args, kwargs=None): - nf4tensor = args[0] - size = args[1] - stride = tuple(args[2]) - storage_offset = args[3] - if math.prod(size) != nf4tensor.numel(): - raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}") - if stride != make_contiguous_strides_for(size): - raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support continuous stride={make_contiguous_strides_for(size)} but got stride={stride}") - if nf4tensor.storage_offset() != storage_offset: - raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support original storage offset {nf4tensor.storage_offset()} but got {storage_offset}") - kwargs = { - "size": torch.Size(size), - "stride": stride, - "storage_offset": storage_offset, - } - return NF4Tensor(*construct_nf4_args(nf4tensor, kwargs)) - - -@implements([torch.ops.aten._to_copy.default]) -def _to_copy(func, *args, **kwargs): - if not args[0][0].is_contiguous(): - assert args[0][0].t().is_contiguous() - return func(args[0][0].t()).t() - out = args[0][0].get_original_weight().to(args[1]["dtype"]) - if "device" in args[1]: - out = out.to(args[1]["device"]) - return out - - -@implements([torch.ops.aten.to.dtype]) -def to_dtype(func, *args, **kwargs): - if not args[0][0].is_contiguous(): - assert args[0][0].t().is_contiguous() - return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() - return args[0][0].get_original_weight().to(args[0][1]) - - -@implements([torch.ops.aten.t.default]) -def t_default(func, *args, **kwargs): - a = args[0][0] - tensor_meta = SubclassTensorArgs( - a.size(), - (a.stride(1), a.stride(0)), - a.storage_offset(), - a.dtype, - a.device, - a.requires_grad, - ) - b = NF4Tensor( - tensor_meta, - a.block_size, - a.n_blocks, - a.scaler_block_size, - a.quantized_scalers, - a.quantization_factor, - a.scaler_mean, - a.quantized_data, - a.nf4, - ) - return b - - -@implements([torch.ops.aten.mm.default]) -def mm_default(func, *args, **kwargs): - return linear_nf4(args[0][0], args[0][1]) +NF4_TORCH_FUNCTIONS = {} -@implements( - [ - aten.copy_.default, - ] -) -def copy_(func, *args, **kwargs): - original: NF4Tensor = args[0][0] - copy_in: torch.Tensor = args[0][1] - # Base Case - - if same_metadata(original, copy_in): - original_tensors = original.__tensor_flatten__()[0] - for tensor_name in original_tensors: - getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name)) - return - - # Convert Non NF4Tensor into NF4 for copy in - if not isinstance(copy_in, NF4Tensor): - copy_in_nf4 = NF4Tensor.from_tensor( - copy_in, original.block_size, original.scaler_block_size - ) - return original.copy_(copy_in_nf4) - - # Other Tensor is not a NF4Tensor - full_precision = copy_in.get_original_weight() - same_meta_nf4 = NF4Tensor.from_tensor( - full_precision, original.block_size, original.scaler_block_size - ) - return original.copy_(same_meta_nf4) - - -@implements( - [ - aten.is_pinned.default, - ] -) -def nf4_is_pinned(aten_op, args, kwargs=None): - nf4tensor = args[0] - for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: - inner_tensor = getattr(nf4tensor, attr) - if not aten_op(inner_tensor, *(args[1:]), **kwargs): - return False - return True +def implements_torch_function(torch_function): + def decorator(func): + functools.update_wrapper(func, torch_function) + NF4_TORCH_FUNCTIONS[torch_function] = func + return func + return decorator -@implements( - [ - aten._pin_memory.default, - ] -) -def nf4_pin_memory(aten_op, args, kwargs=None): - nf4tensor = args[0] - updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) - return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) -@implements(torch.Tensor.to) +@implements_torch_function(torch.Tensor.to) def function_to_dtype(*args, **kwargs): tensor = args[0] if isinstance(args[1], torch.dtype): @@ -904,7 +928,7 @@ def function_to_dtype(*args, **kwargs): ) -@implements(torch.Tensor.cpu) +@implements_torch_function(torch.Tensor.cpu) def function_cpu(*args, **kwargs): nf4tensor = args[0] updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs)