diff --git a/test/modules/test_nf4_linear.py b/test/modules/test_nf4_linear.py
new file mode 100644
index 0000000000..7f50c60216
--- /dev/null
+++ b/test/modules/test_nf4_linear.py
@@ -0,0 +1,129 @@
+import logging
+import unittest
+
+import torch
+from torch import nn
+from torch.testing._internal.common_utils import TestCase
+from torchao.modules import FrozenNF4Linear
+from torchao.dtypes.nf4tensor import NF4Tensor
+
+bnb_available = False
+
+try:
+    import bitsandbytes as bnb
+    bnb_available = True
+except ImportError:
+    pass
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+
+
+class TestNF4Linear(TestCase):
+    """
+    Test torchao.modules.FrozenNF4Linear
+    """
+    def test_bias_unsupported(self):
+        with self.assertRaisesRegex(RuntimeError, "does not currently support biases"):
+            _ = FrozenNF4Linear(1, 1, bias=True)
+
+    def test_non_bf16_unsupported(self):
+        with self.assertRaisesRegex(RuntimeError, "only supported with bf16"):
+            _ = FrozenNF4Linear(1, 1)
+
+    def test_frozen_nf4_linear(self):
+        nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
+        self.assertTrue(isinstance(nf4_linear.weight, NF4Tensor))
+        self.assertEqual(torch.bfloat16, nf4_linear.weight.get_original_weight().dtype)
+
+    def test_output_bf16(self):
+        # Test to ensure W4 A16 produces A16
+        nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
+        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
+        out = nf4_linear(inp)
+        assert out.dtype == torch.bfloat16
+
+    def test_backward_bf16(self):
+        # Test to ensure backward pass gives activation a bf16 gradient and no gradient
+        # to the linear's weight, as it is frozen.
+        nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
+        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
+        nf4_linear(inp).sum().backward()
+        assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
+        assert nf4_linear.weight.grad is None
+
+    def _build_bnb_linear(self, input_weight):
+        assert bnb_available, "Needs bitsandbytes support"
+        param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4")
+        bnb_linear = bnb.nn.LinearNF4(input_weight.size(0), input_weight.size(1), bias=False)
+        bnb_linear.weight = param
+        bnb_linear.cuda()
+        return bnb_linear
+
+    @unittest.skipIf(not bnb_available, "Need bnb available")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_fwd_bnb_parity(self):
+        """
+        Ensures the forward output is at parity w/ bnb.
+        """
+        nf4_linear = FrozenNF4Linear(512, 512, device='cuda', dtype=torch.bfloat16)
+        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
+        bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight)
+
+        inp = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda', requires_grad=True)
+        with torch.no_grad():
+            inp_bnb = inp.clone()
+            inp_bnb.requires_grad_(True)
+        out_native = nf4_linear(inp).sum()
+        out_bnb = bnb_nf4_linear(inp_bnb).sum()
+        self.assertEqual(out_native, out_bnb)
+
+    @unittest.skipIf(not bnb_available, "Need bnb available")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_nf4_reconstruction_vs_bnb(self):
+        """
+        Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when
+        reconstructing the respective original weights.
+ """ + dim = 512 + nf4_linear = FrozenNF4Linear(dim, dim, device='cuda', dtype=torch.bfloat16) + orig_weight = nf4_linear.weight.get_original_weight().clone().detach() + bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight) + + # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65 + bnb_reconstruction = bnb_nf4_linear( + torch.eye(dim, dim, dtype=torch.bfloat16, device='cuda') + ) + # Ensure nf4_linear and bnb reconstructions are close to each other. + diff = (bnb_reconstruction.T - nf4_linear.weight.get_original_weight()).abs().max() + assert diff.item() < 1e-2 + + @unittest.skipIf(not bnb_available, "Need bnb availble") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_nf4_bnb_linear(self): + """ + This test ensures that nf4_linear is "no worse" than BNB by ensuring the + error compared to a bf16 linear is not more than BNB's implementation. + """ + dim = 512 + nf4_linear = FrozenNF4Linear(dim, dim, device='cuda', dtype=torch.bfloat16) + orig_weight = nf4_linear.weight.get_original_weight().clone().detach() + bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight) + bf16_linear = torch.nn.Linear(dim, dim, device='cuda', dtype=torch.bfloat16) + + inp = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda') + + out_nf4 = nf4_linear(inp).sum() + out_bnb = bnb_nf4_linear(inp).sum() + out_ref = bf16_linear(inp).sum() + + err_bnb = (out_bnb - out_ref).abs().max() + err_native = (out_nf4 - out_ref).abs().max() + assert err_native.item() <= err_bnb + + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index b7ad0c6a33..5b420cc0f7 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -2,9 +2,33 @@ from typing import Dict, Tuple import torch +from torch import Tensor import torch.nn.functional as F +aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional + +from typing import Any +NF4_OPS_TABLE: Dict[Any, Any] = {} + + + +def implements(aten_ops): + """Use this decorator to implement a function for an aten op in __torch_dispatch__""" + + def decorator(func): + for op in aten_ops: + NF4_OPS_TABLE[op] = func + return func + + return decorator + +@implements([torch.ops.aten.detach.default, torch.ops.aten.detach]) +def noop_detach(func, *args, **kwargs): + return args[0][0] + + @dataclass class SubclassTensorArgs: original_shape: torch.Size @@ -110,7 +134,7 @@ def from_tensor( assert inpt_tensor.dtype == torch.bfloat16 assert ( inpt_tensor.numel() % block_size == 0 - ), "Input tensor must be divisible by block size" + ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}" assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16" assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!" 
         # I think I want do this
@@ -204,7 +228,7 @@ def double_quantize_scalers(
         # Second round of quantization
         assert (
             scalers_1.numel() % scaler_block_size == 0
-        ), "Number of scalers must be divisible by scaler block size"
+        ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size}"
         n_scaler_blocks = scalers_1.numel() // scaler_block_size
         scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size)
@@ -397,12 +421,28 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         """TODO we are not supporting torch dispatch at the moment
         instead we have created a Autograd.Function to handle the linear
         """
-        raise NotImplementedError("NF4Tensor does not support torch dispatch")
+        # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs and don't
+        # support mixed tensor subclasses. Returning NotImplemented triggers the
+        # handler for the next type in the dispatch list.
+        def allowed_subclasses(type):
+            return (
+                issubclass(cls, type)
+                or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
+                or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
+            )
+
+        if not all(allowed_subclasses(t) for t in types):
+            return NotImplemented
+
+        if func in NF4_OPS_TABLE:
+            return NF4_OPS_TABLE[func](func, args, kwargs)
+        raise NotImplementedError(
+            f"NF4Tensor dispatch: attempting to run {func}, this is not supported"
+        )
 
     # Do not force the Float8Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl
-
 class LinearNF4(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
diff --git a/torchao/modules/__init__.py b/torchao/modules/__init__.py
new file mode 100644
index 0000000000..1ae2ca43b5
--- /dev/null
+++ b/torchao/modules/__init__.py
@@ -0,0 +1 @@
+from .nf4_linear import FrozenNF4Linear
diff --git a/torchao/modules/nf4_linear.py b/torchao/modules/nf4_linear.py
new file mode 100644
index 0000000000..0136316c16
--- /dev/null
+++ b/torchao/modules/nf4_linear.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4
+
+
+class FrozenNF4Linear(nn.Linear):
+    """
+    A linear layer similar to ``torch.nn.Linear`` but uses a quantized
+    NF4Tensor as its weight. This class also freezes its ``weight`` parameter
+    and is meant to be used as the base Linear layer for modeling
+    use cases such as QLoRA where base model parameters are frozen.
+
+    NOTE: biases are currently not supported.
+    """
+    def __init__(self, in_dim: int, out_dim: int, bias: bool = False, device=None, dtype=None, **kwargs):
+        if bias:
+            raise RuntimeError("FrozenNF4Linear does not currently support biases!")
+
+        # Biases are unsupported, so don't let nn.Linear allocate one.
+        super().__init__(in_dim, out_dim, bias=False, device=device, dtype=dtype, **kwargs)
+        self.weight.requires_grad_(False)
+        if self.weight.dtype != torch.bfloat16:
+            raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameters currently")
+
+        self.nf4_weight = NF4Tensor.from_tensor(self.weight.data).to(device).to(dtype)
+        # re-register self.weight as the frozen, quantized NF4 weight
+        del self.weight
+        self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False)
+
+        # TODO: likely need to handle state_dict save & load via hooks to properly manage
+        # types.
+
+    def forward(self, input: Tensor) -> Tensor:
+        return linear_nf4(input=input, weight=self.weight)
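
For reference, a minimal usage sketch of the new module (illustrative only, not part of the diff above), mirroring what test_output_bf16 and test_backward_bf16 exercise: a frozen NF4 weight, a bf16 activation and output, and gradients flowing only to the input.

    # Minimal usage sketch; assumes torchao is installed with the changes in this PR.
    import torch
    from torchao.modules import FrozenNF4Linear

    layer = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
    x = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
    y = layer(x)  # W4 A16: NF4-quantized weight, bf16 activation and output
    y.sum().backward()
    assert y.dtype == torch.bfloat16
    assert x.grad is not None          # gradient flows to the activation
    assert layer.weight.grad is None   # the quantized base weight stays frozen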