Commit 81f3d38 (1 parent: 969038f)
Showing 4 changed files with 216 additions and 3 deletions.
@@ -0,0 +1,132 @@
import logging
import unittest

import torch
from torch.testing._internal.common_utils import TestCase
from torchao.modules import FrozenNF4Linear
from torchao.dtypes.nf4tensor import NF4Tensor

bnb_available = False

try:
    import bitsandbytes as bnb
    bnb_available = True
except ImportError:
    pass

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestNF4Linear(TestCase):
    """
    Test torchao.modules.FrozenNF4Linear
    """
    def test_bias_unsupported(self):
        with self.assertRaisesRegex(RuntimeError, "does not currently support biases"):
            _ = FrozenNF4Linear(1, 1, bias=True)

    def test_non_bf16_unsupported(self):
        with self.assertRaisesRegex(RuntimeError, "only supported with bf16"):
            _ = FrozenNF4Linear(1, 1)

    def test_frozen_nf4_linear(self):
        nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
        self.assertTrue(isinstance(nf4_linear.weight, NF4Tensor))
        self.assertEqual(torch.bfloat16, nf4_linear.weight.get_original_weight().dtype)

    def test_output_bf16(self):
        # Test to ensure W4 A16 produces A16
        nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
        out = nf4_linear(inp)
        assert out.dtype == torch.bfloat16

    def test_backward_bf16(self):
        # Test to ensure backward pass gives activation a bf16 gradient and no gradient
        # to the linear's weight, as it is frozen.
        nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
        nf4_linear(inp).sum().backward()
        assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
        assert nf4_linear.weight.grad is None

    def _build_bnb_linear(self, input_weight):
        assert bnb_available, "Needs bitsandbytes support"
        param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4")
        bnb_linear = bnb.nn.LinearNF4(input_weight.size(0), input_weight.size(1), bias=False)
        bnb_linear.weight = param
        bnb_linear.cuda()
        return bnb_linear

    @unittest.skipIf(not bnb_available, "Need bnb available")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fwd_backward_bnb_parity(self):
        """
        Ensures fwd + backward logits and grads are at parity w/ bnb
        """
        nf4_linear = FrozenNF4Linear(512, 512, device='cuda', dtype=torch.bfloat16)
        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
        bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight)

        inp = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda', requires_grad=True)
        with torch.no_grad():
            inp_bnb = inp.clone()
        inp_bnb.requires_grad_(True)
        out_native = nf4_linear(inp).sum()
        out_bnb = bnb_nf4_linear(inp_bnb).sum()
        self.assertEqual(out_native, out_bnb)
        out_native.backward()
        out_bnb.backward()
        # Compare the input gradients; the scalar outputs themselves carry no .grad.
        self.assertEqual(inp.grad.sum(), inp_bnb.grad.sum())

    @unittest.skipIf(not bnb_available, "Need bnb available")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_nf4_reconstruction_vs_bnb(self):
        """
        Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when
        reconstructing the respective original weights.
        """
        dim = 512
        nf4_linear = FrozenNF4Linear(dim, dim, device='cuda', dtype=torch.bfloat16)
        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
        bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight)

        # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65
        bnb_reconstruction = bnb_nf4_linear(
            torch.eye(dim, dim, dtype=torch.bfloat16, device='cuda')
        )
        # Ensure nf4_linear and bnb reconstructions are close to each other.
        diff = (bnb_reconstruction.T - nf4_linear.weight.get_original_weight()).abs().max()
        assert diff.item() < 1e-2

    @unittest.skipIf(not bnb_available, "Need bnb available")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_nf4_bnb_linear(self):
        """
        This test ensures that nf4_linear is "no worse" than BNB by ensuring the
        error compared to a bf16 linear is not more than BNB's implementation.
        """
        dim = 512
        nf4_linear = FrozenNF4Linear(dim, dim, device='cuda', dtype=torch.bfloat16)
        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
        bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight)
        bf16_linear = torch.nn.Linear(dim, dim, device='cuda', dtype=torch.bfloat16)

        inp = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda')

        out_nf4 = nf4_linear(inp).sum()
        out_bnb = bnb_nf4_linear(inp).sum()
        out_ref = bf16_linear(inp).sum()

        err_bnb = (out_bnb - out_ref).abs().max()
        err_native = (out_nf4 - out_ref).abs().max()
        # Native NF4 error should be no worse than bnb's.
        assert err_native.item() <= err_bnb.item()


if __name__ == "__main__":
    unittest.main()
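Taken together, these tests pin down the contract of linear_nf4: dequantize the NF4 weight for the matmul, keep activations in bf16, and route gradients only to the input. A minimal conceptual sketch under those assumptions (the _LinearNF4Sketch name and its body are illustrative, not torchao's actual op):

import torch
import torch.nn.functional as F
from torchao.dtypes.nf4tensor import NF4Tensor


class _LinearNF4Sketch(torch.autograd.Function):
    """Illustrative stand-in for linear_nf4: W4 weight, A16 activations."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
        # Dequantize the NF4 weight back to bf16 for the matmul, so the
        # output stays bf16 (W4 A16 -> A16, as test_output_bf16 checks).
        dequantized = weight.get_original_weight()
        ctx.save_for_backward(dequantized)
        return F.linear(input, dequantized)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Gradient flows only to the input; the frozen weight gets None,
        # matching test_backward_bf16 above.
        (dequantized,) = ctx.saved_tensors
        return grad_output @ dequantized, None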
@@ -0,0 +1 @@
from .nf4_linear import FrozenNF4Linear
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
from torch import Tensor
from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4

class FrozenNF4Linear(nn.Linear):
    """
    A linear layer similar to ``torch.nn.Linear`` but uses a quantized
    NF4Tensor as its weight. This class also freezes its ``weight`` parameter
    and is meant to be used as the base Linear layer for modeling
    use cases such as QLoRA where base model parameters are frozen.
    NOTE: biases are currently not supported.
    """
    def __init__(self, in_dim: int, out_dim: int, bias: bool = False, device=None, dtype=None, **kwargs):
        if bias:
            raise RuntimeError("FrozenNF4Linear does not currently support biases!")

        super().__init__(in_dim, out_dim, bias=False, device=device, dtype=dtype, **kwargs)
        self.weight.requires_grad_(False)
        if self.weight.dtype != torch.bfloat16:
            raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameters currently")

        self.nf4_weight = NF4Tensor.from_tensor(self.weight.data).to(device).to(dtype)
        # re-register self.weight as the NF4 weight, replacing the original-precision one
        del self.weight
        self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False)

        # TODO: likely need to handle state_dict save & load via hooks to properly manage
        # types.

    def forward(self, input: Tensor) -> Tensor:
        return linear_nf4(input=input, weight=self.weight)
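As the docstring notes, the intended use is as the frozen base layer in QLoRA-style fine-tuning. A minimal usage sketch, using the import path from this commit; the LoRA wrapper below is a hypothetical illustration, not part of this change:

import torch
from torch import nn
from torchao.modules import FrozenNF4Linear


class LoRALinearSketch(nn.Module):
    """Hypothetical QLoRA-style block: frozen NF4 base plus a trainable low-rank adapter."""

    def __init__(self, dim: int, rank: int = 8):
        super().__init__()
        self.base = FrozenNF4Linear(dim, dim, device='cpu', dtype=torch.bfloat16)
        self.lora_a = nn.Linear(dim, rank, bias=False, dtype=torch.bfloat16)
        self.lora_b = nn.Linear(rank, dim, bias=False, dtype=torch.bfloat16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantized frozen path plus trainable low-rank correction.
        return self.base(x) + self.lora_b(self.lora_a(x))


block = LoRALinearSketch(512)
out = block(torch.randn(2, 512, dtype=torch.bfloat16))
out.sum().backward()
assert block.base.weight.grad is None        # base stays frozen
assert block.lora_a.weight.grad is not None  # only the adapters train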