forked from pytorch/ao
Improve FP6-LLM 2+4bit weight splitting + user API (pytorch#279)
* add annotation
* add weight splitting logic
* update from fp6_quant
* merge to_tc_float6_e3m2
* add more optimized version
* add some notes
* add from_tc_float6_e3m2
* add some docs
* make fp6_llm.py
* add test for linear
* fix fp6 llm
* switch to v2 since it's faster
* fix type hint for old python
* simplify further
* fix typing for old python
* add test
* eliminate indexing. faster on CUDA
* skip fp6_llm on cpu
* improve error message
* add support for extra batch dims
* cast output to original dtype
* fix precision error due to dtype
1 parent 231116a
commit 9938a3f
Showing 4 changed files with 263 additions and 4 deletions.
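For orientation, here is a minimal usage sketch of the user API added in this commit, pieced together from the new tests below (shapes are illustrative; Fp6LlmLinear.from_float requires in_features divisible by 64 and out_features divisible by 256, and the fused forward kernel runs on CUDA):

import torch
from torch import nn
from torchao.quantization.fp6_llm import Fp6LlmLinear, convert_fp6_llm

# swap every eligible nn.Linear in a model for Fp6LlmLinear, in place
model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).cuda()
convert_fp6_llm(model)

x = torch.randn(4, 64, device="cuda", dtype=torch.half)
out = model(x)  # FP16 activations x FP6 weights via the fused kernel

# a single layer can also be converted directly
fp6_linear = Fp6LlmLinear.from_float(nn.Linear(64, 256, device="cuda"))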
@@ -0,0 +1,99 @@
import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm
from torchao.ops import prepack_fp6_weight


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestFp6LlmLinear(TestCase):
    @parametrize("device", _DEVICES)
    def test_to_tc_float6_e3m2_correctness(self, device):
        x = torch.randn(256, 64, device=device)

        expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8)
        actual = to_tc_float6_e3m2(x)
        torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1))

    @parametrize("device", _DEVICES)
    def test_to_tc_float6_e3m2_compile(self, device):
        x = torch.randn(256, 64, device=device)

        expected = to_tc_float6_e3m2(x)
        actual = torch.compile(to_tc_float6_e3m2)(x)
        torch.testing.assert_close(actual, expected)

    @parametrize("device", _DEVICES)
    def test_from_tc_float6_e3m2_correctness(self, device):
        x = torch.randn(256, 64, device=device)
        x = from_float6_e3m2(to_float6_e3m2(x))  # quantize and dequantize so that the values are exactly representable in FP6

        actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape)
        torch.testing.assert_close(actual, x)

    @parametrize("device", _DEVICES)
    def test_from_tc_float6_e3m2_compile(self, device):
        M, N = 256, 64
        x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device)

        expected = from_tc_float6_e3m2(x, M, N)
        actual = torch.compile(from_tc_float6_e3m2)(x, M, N)
        torch.testing.assert_close(actual, expected)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @parametrize("leading_dims", [(4,), (2, 4)])
    @parametrize("bias", [False, True])
    def test_fp6_llm_linear_forward(self, bias, leading_dims):
        OC, IC = 256, 64
        device = "cuda"

        linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
        fp6_linear = Fp6LlmLinear.from_float(linear)
        assert (fp6_linear.bias is not None) == bias

        x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
        fp6_linear(x)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @parametrize("bias", [False, True])
    def test_fp6_llm_linear_compile(self, bias):
        N, OC, IC = 4, 256, 64
        device = "cuda"

        linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
        fp6_linear = Fp6LlmLinear.from_float(linear)

        x = torch.randn(N, IC, device=device, dtype=torch.half)
        expected = fp6_linear(x)
        actual = torch.compile(fp6_linear)(x)
        torch.testing.assert_close(actual, expected)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_convert_fp6_llm(self):
        device = "cuda"
        model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).to(device)
        convert_fp6_llm(model)

        assert isinstance(model[0], Fp6LlmLinear)
        assert model[0].bias is None
        assert isinstance(model[1], Fp6LlmLinear)
        assert model[1].bias is not None

        x = torch.randn(4, 64, device=device)
        model(x)


instantiate_parametrized_tests(TestFp6LlmLinear)


if __name__ == "__main__":
    run_tests()
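As a side note on the "2+4bit weight splitting" named in the commit title: each 6-bit E3M2 code (1 sign, 3 exponent, 2 mantissa bits) is stored as a 2-bit fragment (the top two bits) and a 4-bit fragment (the bottom four bits) in two separate packed regions, which is what the >> 4 and & 0b1111 split in to_tc_float6_e3m2 below does. A tiny sketch of the split and its inverse on a single code (the value is arbitrary):

code = 0b101101                  # any 6-bit FP6 E3M2 code
hi2 = (code >> 4) & 0b11         # top 2 bits, packed into the 2-bit region
lo4 = code & 0b1111              # bottom 4 bits, packed into the 4-bit region
assert (hi2 << 4) | lo4 == code  # reconstruction, as done in from_tc_float6_e3m2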
@@ -0,0 +1,160 @@
from typing import Optional

import torch
from torch import nn, Tensor
from torchao.dtypes.float6_e3m2 import FLOAT6_E3M2_MAX, to_float6_e3m2, from_float6_e3m2
from torchao.ops import fp16act_fp6weight_linear


def _pack_2bit(x: Tensor) -> Tensor:
    return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4]


def _unpack_2bit(x: Tensor) -> Tensor:
    return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2)


def _pack_4bit(x: Tensor) -> Tensor:
    return (x[..., ::2] << 4) | x[..., 1::2]


def _unpack_4bit(x: Tensor) -> Tensor:
    return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2)


# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing
# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h
def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor:
    assert tensor.ndim == 2
    M, N = tensor.shape
    assert (M % 64 == 0) and (N % 64 == 0)

    tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True)

    # Pass 1 from original code
    tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8)
    tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6)
    tensor_fp6 = tensor_fp6.reshape(-1, 32, 2)
    tensor_fp6 = tensor_fp6.permute(1, 0, 2)
    tensor_fp6 = tensor_fp6.flatten()

    tensor_2bit = _pack_2bit((tensor_fp6 >> 4) & 0b11)
    tensor_4bit = _pack_4bit(tensor_fp6 & 0b1111)

    # Pass 2 from original code
    tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2)
    tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2)

    # Pass 3 from original code
    # BitInterleaving_2bit
    # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32
    # while we still unpack/pack the values from/to uint8
    tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16)
    tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]]
    tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]]
    tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]]
    tensor_2bit = _pack_2bit(tensor_2bit).view(-1)

    # BitInterleaving_4bit
    # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32
    # while we still unpack/pack the values from/to uint8
    tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8)
    tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]]
    tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]]
    tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]]
    tensor_4bit = _pack_4bit(tensor_4bit).view(-1)

    return torch.cat([tensor_2bit, tensor_4bit], dim=0)


# more optimized version of _to_tc_float6_e3m2_original() by merging ops
# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h
def to_tc_float6_e3m2(tensor: Tensor) -> Tensor:
    assert tensor.ndim == 2
    M, N = tensor.shape
    assert (M % 64 == 0) and (N % 64 == 0)

    tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True)
    tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8)
    tensor_fp6 = tensor_fp6.flip(3)

    tensor_2bit = (tensor_fp6 >> 4) & 0b11
    tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6)
    tensor_2bit = _pack_2bit(tensor_2bit.flatten())

    tensor_4bit = tensor_fp6 & 0b1111
    tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6)
    tensor_4bit = _pack_4bit(tensor_4bit.flatten())

    return torch.cat([tensor_2bit, tensor_4bit], dim=0)


def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor:
    assert tensor.ndim == 1
    assert (M % 64 == 0) and (N % 64 == 0)
    size_2bit = M * N // 4
    size_4bit = M * N // 2
    assert tensor.numel() == size_2bit + size_4bit

    tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit])

    tensor_2bit = _unpack_2bit(tensor_2bit)
    tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2)
    tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4)

    tensor_4bit = _unpack_4bit(tensor_4bit)
    tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2)
    tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5)

    tensor_fp6 = (tensor_2bit << 4) | tensor_4bit
    tensor_fp6 = tensor_fp6.flip(3).reshape(M, N)
    return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype)


class Fp6LlmLinear(nn.Module):
    """FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112.
    """

    def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None) -> None:
        super().__init__()
        self.register_buffer("weight", weight)
        self.register_buffer("scales", scales)
        self.register_buffer("bias", bias)
        self.out_features = weight.shape[0]
        self.in_features = weight.shape[1] * 16 // 3

    def forward(self, x: Tensor) -> Tensor:
        # TODO: splitK map
        out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1)
        if self.bias is not None:
            out = out + self.bias
        return out.view(*x.shape[:-1], self.out_features).to(x.dtype)

    @classmethod
    def from_float(cls, linear: nn.Linear):
        assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0)

        fp32_weight = linear.weight.detach().float()
        scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX
        scales[scales == 0.0] = 1.0  # avoid 0 scale

        tc_fp6_weight = to_tc_float6_e3m2(fp32_weight / scales.view(-1, 1))
        tc_fp6_weight = tc_fp6_weight.view(linear.out_features, -1).view(torch.int32)

        bias = linear.bias.detach().half() if linear.bias is not None else None
        return cls(tc_fp6_weight, scales.half(), bias)

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'


def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[list[str]] = None, cur_fqn: str = "") -> None:
    for name, child in model.named_children():
        new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}"

        if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)):
            if (child.in_features % 64 == 0) and (child.out_features % 256 == 0):
                new_child = Fp6LlmLinear.from_float(child)
                setattr(model, name, new_child)
        else:
            convert_fp6_llm(child, skip_fqn_list, new_fqn)
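
To round things off, a hedged sketch of how the pieces above fit together: from_tc_float6_e3m2 can unpack the tensor-core-layout weight held by an Fp6LlmLinear back into a dense matrix for inspection. Shapes below are illustrative, and the reconstruction of the packed buffer is my own inference from from_float(); the result only approximates the original weight, up to FP6 E3M2 rounding error.

import torch
from torch import nn
from torchao.quantization.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2

linear = nn.Linear(64, 256)  # in_features % 64 == 0, out_features % 256 == 0
fp6_linear = Fp6LlmLinear.from_float(linear)

# undo the int32 view and the (out_features, -1) reshape done in from_float()
packed = fp6_linear.weight.view(torch.uint8).reshape(-1)

# unpack the tensor-core layout and re-apply the per-row scales
w_dequant = from_tc_float6_e3m2(packed, 256, 64) * fp6_linear.scales.float().view(-1, 1)
# w_dequant approximates linear.weight up to FP6 quantization error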