diff --git a/test/modules/test_nf4_linear.py b/test/modules/test_nf4_linear.py new file mode 100644 index 0000000000..621824bc6f --- /dev/null +++ b/test/modules/test_nf4_linear.py @@ -0,0 +1,115 @@ +import logging +import unittest + +import torch +from torch import nn +from torch.testing._internal.common_utils import TestCase +from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor +import torch.nn.functional as F + + +bnb_available = False + +try: + import bitsandbytes as bnb + + bnb_available = True +except ImportError: + pass + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + + +def _build_input_weight(embed_dim: int, device: torch.device): + torch.manual_seed(0) + input_weight = torch.empty( + embed_dim, embed_dim, device=device, dtype=torch.bfloat16 + ) + input_weight.normal_(0, 1) + return input_weight + +def _build_bnb_linear(input_weight, device): + assert bnb_available, "Needs bitsandbytes support" + param = bnb.nn.Params4bit( + input_weight, requires_grad=False, quant_type="nf4" + ).cuda(device) + bnb_linear = bnb.nn.LinearNF4( + input_weight.size(0), input_weight.size(1), bias=False + ) + bnb_linear.weight = param + bnb_linear.to(device) + return bnb_linear + + +class TestNF4Linear(TestCase): + def test_output_bf16(self): + # Test to ensure W4 A16 produces A16 + inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True) + nf4_tensor = NF4Tensor.from_tensor( + inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16) + ) + out = linear_nf4(input=inp, weight=nf4_tensor) + assert out.dtype == torch.bfloat16 + + def test_backward_bf16(self): + # Test to ensure backward pass gives activation a bf16 gradient and no gradient + # to the linear's weight, as it is frozen. + nf4_tensor = NF4Tensor.from_tensor( + inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16) + ) + inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True) + linear_nf4(inp, nf4_tensor).sum().backward() + assert inp.grad is not None and inp.grad.dtype == torch.bfloat16 + assert nf4_tensor.grad is None + + @unittest.skipIf(not bnb_available, "Need bnb availble") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_reconstruction_qlora_vs_bnb(self): + # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47 + torch.manual_seed(0) + device = "cuda" + embed_dim = 512 + input_weight = _build_input_weight(embed_dim, device) + nf4_weight = NF4Tensor.from_tensor(input_weight) + bnb_linear = _build_bnb_linear(input_weight, device) + bnb_reconstruction = bnb_linear( + torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device) + ) + bnb_diff = (bnb_reconstruction.T - input_weight).abs().max() + nugs_diff = (nf4_weight.get_original_weight() - input_weight).abs().max() + # Since we are subtle different we assume that we both reconstruct with + # a similar precision + assert bnb_diff < 1 + assert nugs_diff < 1 + assert (nugs_diff - bnb_diff).abs() < 2e-1 + + @unittest.skipIf(not bnb_available, "Need bnb availble") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_nf4_bnb_linear(self): + """ + This test ensures that nf4_linear is "no worse" than BNB by ensuring the + error compared to a bf16 linear is not more than BNB's implementation. + """ + torch.manual_seed(0) + dim = 512 + device = "cuda" + input_weight = _build_input_weight(dim, device) + nf4_weight = NF4Tensor.from_tensor(input_weight) + bnb_linear = _build_bnb_linear(input_weight, device) + + inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda") + + out_nf4 = linear_nf4(inp, nf4_weight).sum() + out_bnb = bnb_linear(inp).sum() + out_ref = F.linear(inp, input_weight).sum() + + err_bnb = (out_bnb - out_ref).abs().max() + err_native = (out_nf4 - out_ref).abs().max() + assert err_native < 0.5 * dim + assert err_bnb < 0.5 * dim + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index b7ad0c6a33..5b420cc0f7 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -2,9 +2,33 @@ from typing import Dict, Tuple import torch +from torch import Tensor import torch.nn.functional as F +aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional + +from typing import Any +NF4_OPS_TABLE: Dict[Any, Any] = {} + + + +def implements(aten_ops): + """Use this decorator to implement a function for an aten op in __torch_dispatch__""" + + def decorator(func): + for op in aten_ops: + NF4_OPS_TABLE[op] = func + return func + + return decorator + +@implements([torch.ops.aten.detach.default, torch.ops.aten.detach]) +def noop_detach(func, *args, **kwargs): + return args[0][0] + + @dataclass class SubclassTensorArgs: original_shape: torch.Size @@ -110,7 +134,7 @@ def from_tensor( assert inpt_tensor.dtype == torch.bfloat16 assert ( inpt_tensor.numel() % block_size == 0 - ), "Input tensor must be divisible by block size" + ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}" assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16" assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!" # I think I want do this @@ -204,7 +228,7 @@ def double_quantize_scalers( # Second round of quantization assert ( scalers_1.numel() % scaler_block_size == 0 - ), "Number of scalers must be divisible by scaler block size" + ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} " n_scaler_blocks = scalers_1.numel() // scaler_block_size scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size) @@ -397,12 +421,28 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): """TODO we are not supporting torch dispatch at the moment instead we have created a Autograd.Function to handle the linear """ - raise NotImplementedError("NF4Tensor does not support torch dispatch") + # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs + # And don't support mixed tensor subclasses. This will trigger the handler for + # the next type in the dispatch list + def allowed_subclasses(type): + return ( + issubclass(cls, type) + or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) + or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type) + ) + + if not all(allowed_subclasses(t) for t in types): + return NotImplemented("Up to the next one to handle") + + if func in NF4_OPS_TABLE: + return NF4_OPS_TABLE[func](func, args, kwargs) + raise NotImplementedError( + f"NF4Tensor dispatch: attempting to run {func}, this is not supported" + ) # Do not force the Float8Tensor type on the returned tensor __torch_function__ = torch._C._disabled_torch_function_impl - class LinearNF4(torch.autograd.Function): @staticmethod def forward(ctx, input: torch.Tensor, weight: NF4Tensor):