Skip to content

Commit

Permalink
Add Nf4Linear and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rohan-varma committed Mar 5, 2024
1 parent 969038f commit 233efd3
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 4 deletions.
115 changes: 115 additions & 0 deletions test/modules/test_nf4_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import logging
import unittest

import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor
import torch.nn.functional as F


bnb_available = False

try:
import bitsandbytes as bnb

bnb_available = True
except ImportError:
pass

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


def _build_input_weight(embed_dim: int, device: torch.device):
torch.manual_seed(0)
input_weight = torch.empty(
embed_dim, embed_dim, device=device, dtype=torch.bfloat16
)
input_weight.normal_(0, 1)
return input_weight

def _build_bnb_linear(input_weight, device):
assert bnb_available, "Needs bitsandbytes support"
param = bnb.nn.Params4bit(
input_weight, requires_grad=False, quant_type="nf4"
).cuda(device)
bnb_linear = bnb.nn.LinearNF4(
input_weight.size(0), input_weight.size(1), bias=False
)
bnb_linear.weight = param
bnb_linear.to(device)
return bnb_linear


class TestNF4Linear(TestCase):
def test_output_bf16(self):
# Test to ensure W4 A16 produces A16
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
nf4_tensor = NF4Tensor.from_tensor(
inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
)
out = linear_nf4(input=inp, weight=nf4_tensor)
assert out.dtype == torch.bfloat16

def test_backward_bf16(self):
# Test to ensure backward pass gives activation a bf16 gradient and no gradient
# to the linear's weight, as it is frozen.
nf4_tensor = NF4Tensor.from_tensor(
inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
)
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
linear_nf4(inp, nf4_tensor).sum().backward()
assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
assert nf4_tensor.grad is None

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_reconstruction_qlora_vs_bnb(self):
# From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
torch.manual_seed(0)
device = "cuda"
embed_dim = 512
input_weight = _build_input_weight(embed_dim, device)
nf4_weight = NF4Tensor.from_tensor(input_weight)
bnb_linear = _build_bnb_linear(input_weight, device)
bnb_reconstruction = bnb_linear(
torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device)
)
bnb_diff = (bnb_reconstruction.T - input_weight).abs().max()
nugs_diff = (nf4_weight.get_original_weight() - input_weight).abs().max()
# Since we are subtle different we assume that we both reconstruct with
# a similar precision
assert bnb_diff < 1
assert nugs_diff < 1
assert (nugs_diff - bnb_diff).abs() < 2e-1

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_nf4_bnb_linear(self):
"""
This test ensures that nf4_linear is "no worse" than BNB by ensuring the
error compared to a bf16 linear is not more than BNB's implementation.
"""
torch.manual_seed(0)
dim = 512
device = "cuda"
input_weight = _build_input_weight(dim, device)
nf4_weight = NF4Tensor.from_tensor(input_weight)
bnb_linear = _build_bnb_linear(input_weight, device)

inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")

out_nf4 = linear_nf4(inp, nf4_weight).sum()
out_bnb = bnb_linear(inp).sum()
out_ref = F.linear(inp, input_weight).sum()

err_bnb = (out_bnb - out_ref).abs().max()
err_native = (out_nf4 - out_ref).abs().max()
assert err_native < 0.5 * dim
assert err_bnb < 0.5 * dim


if __name__ == "__main__":
unittest.main()
48 changes: 44 additions & 4 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,33 @@
from typing import Dict, Tuple

import torch
from torch import Tensor
import torch.nn.functional as F


aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional

from typing import Any
NF4_OPS_TABLE: Dict[Any, Any] = {}



def implements(aten_ops):
"""Use this decorator to implement a function for an aten op in __torch_dispatch__"""

def decorator(func):
for op in aten_ops:
NF4_OPS_TABLE[op] = func
return func

return decorator

@implements([torch.ops.aten.detach.default, torch.ops.aten.detach])
def noop_detach(func, *args, **kwargs):
return args[0][0]


@dataclass
class SubclassTensorArgs:
original_shape: torch.Size
Expand Down Expand Up @@ -110,7 +134,7 @@ def from_tensor(
assert inpt_tensor.dtype == torch.bfloat16
assert (
inpt_tensor.numel() % block_size == 0
), "Input tensor must be divisible by block size"
), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16"
assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!"
# I think I want do this
Expand Down Expand Up @@ -204,7 +228,7 @@ def double_quantize_scalers(
# Second round of quantization
assert (
scalers_1.numel() % scaler_block_size == 0
), "Number of scalers must be divisible by scaler block size"
), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} "
n_scaler_blocks = scalers_1.numel() // scaler_block_size
scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size)

Expand Down Expand Up @@ -397,12 +421,28 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
"""TODO we are not supporting torch dispatch at the moment
instead we have created a Autograd.Function to handle the linear
"""
raise NotImplementedError("NF4Tensor does not support torch dispatch")
# All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs
# And don't support mixed tensor subclasses. This will trigger the handler for
# the next type in the dispatch list
def allowed_subclasses(type):
return (
issubclass(cls, type)
or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
)

if not all(allowed_subclasses(t) for t in types):
return NotImplemented("Up to the next one to handle")

if func in NF4_OPS_TABLE:
return NF4_OPS_TABLE[func](func, args, kwargs)
raise NotImplementedError(
f"NF4Tensor dispatch: attempting to run {func}, this is not supported"
)

# Do not force the Float8Tensor type on the returned tensor
__torch_function__ = torch._C._disabled_torch_function_impl


class LinearNF4(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
Expand Down

0 comments on commit 233efd3

Please sign in to comment.