-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
969038f
commit 233efd3
Showing
2 changed files
with
159 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import logging | ||
import unittest | ||
|
||
import torch | ||
from torch import nn | ||
from torch.testing._internal.common_utils import TestCase | ||
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor | ||
import torch.nn.functional as F | ||
|
||
|
||
bnb_available = False | ||
|
||
try: | ||
import bitsandbytes as bnb | ||
|
||
bnb_available = True | ||
except ImportError: | ||
pass | ||
|
||
logging.basicConfig( | ||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO | ||
) | ||
|
||
|
||
def _build_input_weight(embed_dim: int, device: torch.device): | ||
torch.manual_seed(0) | ||
input_weight = torch.empty( | ||
embed_dim, embed_dim, device=device, dtype=torch.bfloat16 | ||
) | ||
input_weight.normal_(0, 1) | ||
return input_weight | ||
|
||
def _build_bnb_linear(input_weight, device): | ||
assert bnb_available, "Needs bitsandbytes support" | ||
param = bnb.nn.Params4bit( | ||
input_weight, requires_grad=False, quant_type="nf4" | ||
).cuda(device) | ||
bnb_linear = bnb.nn.LinearNF4( | ||
input_weight.size(0), input_weight.size(1), bias=False | ||
) | ||
bnb_linear.weight = param | ||
bnb_linear.to(device) | ||
return bnb_linear | ||
|
||
|
||
class TestNF4Linear(TestCase): | ||
def test_output_bf16(self): | ||
# Test to ensure W4 A16 produces A16 | ||
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True) | ||
nf4_tensor = NF4Tensor.from_tensor( | ||
inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16) | ||
) | ||
out = linear_nf4(input=inp, weight=nf4_tensor) | ||
assert out.dtype == torch.bfloat16 | ||
|
||
def test_backward_bf16(self): | ||
# Test to ensure backward pass gives activation a bf16 gradient and no gradient | ||
# to the linear's weight, as it is frozen. | ||
nf4_tensor = NF4Tensor.from_tensor( | ||
inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16) | ||
) | ||
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True) | ||
linear_nf4(inp, nf4_tensor).sum().backward() | ||
assert inp.grad is not None and inp.grad.dtype == torch.bfloat16 | ||
assert nf4_tensor.grad is None | ||
|
||
@unittest.skipIf(not bnb_available, "Need bnb availble") | ||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
def test_reconstruction_qlora_vs_bnb(self): | ||
# From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47 | ||
torch.manual_seed(0) | ||
device = "cuda" | ||
embed_dim = 512 | ||
input_weight = _build_input_weight(embed_dim, device) | ||
nf4_weight = NF4Tensor.from_tensor(input_weight) | ||
bnb_linear = _build_bnb_linear(input_weight, device) | ||
bnb_reconstruction = bnb_linear( | ||
torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device) | ||
) | ||
bnb_diff = (bnb_reconstruction.T - input_weight).abs().max() | ||
nugs_diff = (nf4_weight.get_original_weight() - input_weight).abs().max() | ||
# Since we are subtle different we assume that we both reconstruct with | ||
# a similar precision | ||
assert bnb_diff < 1 | ||
assert nugs_diff < 1 | ||
assert (nugs_diff - bnb_diff).abs() < 2e-1 | ||
|
||
@unittest.skipIf(not bnb_available, "Need bnb availble") | ||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
def test_nf4_bnb_linear(self): | ||
""" | ||
This test ensures that nf4_linear is "no worse" than BNB by ensuring the | ||
error compared to a bf16 linear is not more than BNB's implementation. | ||
""" | ||
torch.manual_seed(0) | ||
dim = 512 | ||
device = "cuda" | ||
input_weight = _build_input_weight(dim, device) | ||
nf4_weight = NF4Tensor.from_tensor(input_weight) | ||
bnb_linear = _build_bnb_linear(input_weight, device) | ||
|
||
inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda") | ||
|
||
out_nf4 = linear_nf4(inp, nf4_weight).sum() | ||
out_bnb = bnb_linear(inp).sum() | ||
out_ref = F.linear(inp, input_weight).sum() | ||
|
||
err_bnb = (out_bnb - out_ref).abs().max() | ||
err_native = (out_nf4 - out_ref).abs().max() | ||
assert err_native < 0.5 * dim | ||
assert err_bnb < 0.5 * dim | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters