Improve FP6-LLM 2+4bit weight splitting + user API (pytorch#279)
* add annotation

* add weight splitting logic

* update from fp6_quant

* merge to_tc_float6_e3m2

* add more optimized version

* add some notes

* add from_tc_float6_e3m2

* add some docs

* make fp6_llm.py

* add test for linear

* fix fp6 llm

* switch to v2 since it's faster

* fix type hint for old python

* simplify further

* fix typing for old python

* add test

* eliminate indexing; faster on CUDA

* skip fp6_llm on cpu

* improve error message

* add support for extra batch dims

* cast output to original dtype

* fix precision error due to dtype
gau-nernst authored May 26, 2024
1 parent 231116a commit 9938a3f
Showing 4 changed files with 263 additions and 4 deletions.
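For orientation, the new user API converts eligible nn.Linear modules in place. A minimal usage sketch, distilled from test_convert_fp6_llm added below (requires CUDA, since the FP16-activation x FP6-weight kernel is CUDA-only; layer sizes are example values that satisfy the in_features % 64 and out_features % 256 requirements):

import torch
from torch import nn
from torchao.quantization.fp6_llm import convert_fp6_llm, Fp6LlmLinear

# toy model whose layers meet the shape requirements, converted in place
model = nn.Sequential(nn.Linear(64, 256), nn.Linear(256, 256)).cuda()
convert_fp6_llm(model)                           # eligible nn.Linear layers become Fp6LlmLinear
assert isinstance(model[0], Fp6LlmLinear)
out = model(torch.randn(4, 64, device="cuda"))   # output is cast back to the input dtype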
4 changes: 2 additions & 2 deletions test/dtypes/test_float6_e3m2.py
@@ -12,7 +12,7 @@
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


-class TestFp6(TestCase):
+class TestFloat6E3M2(TestCase):

@parametrize("device", _DEVICES)
@parametrize("dtype", _DTYPES)
@@ -120,7 +120,7 @@ def test_from_float6_e3m2_compile(self, device, no_bit_packing):
        torch.testing.assert_close(actual, expected)


-instantiate_parametrized_tests(TestFp6)
+instantiate_parametrized_tests(TestFloat6E3M2)


if __name__ == "__main__":
99 changes: 99 additions & 0 deletions test/quantization/test_fp6_llm.py
@@ -0,0 +1,99 @@
import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm
from torchao.ops import prepack_fp6_weight


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestFp6LlmLinear(TestCase):
@parametrize("device", _DEVICES)
def test_to_tc_float6_e3m2_correctness(self, device):
x = torch.randn(256, 64, device=device)

expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8)
actual = to_tc_float6_e3m2(x)
torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1))

@parametrize("device", _DEVICES)
def test_to_tc_float6_e3m2_compile(self, device):
x = torch.randn(256, 64, device=device)

expected = to_tc_float6_e3m2(x)
actual = torch.compile(to_tc_float6_e3m2)(x)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
def test_from_tc_float6_e3m2_correctness(self, device):
x = torch.randn(256, 64, device=device)
x = from_float6_e3m2(to_float6_e3m2(x)) # quantize and dequantize so that the values are exactly representable in FP6

actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape)
torch.testing.assert_close(actual, x)

@parametrize("device", _DEVICES)
def test_from_tc_float6_e3m2_compile(self, device):
M, N = 256, 64
x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device)

expected = from_tc_float6_e3m2(x, M, N)
actual = torch.compile(from_tc_float6_e3m2)(x, M, N)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("leading_dims", [(4,), (2, 4)])
@parametrize("bias", [False, True])
def test_fp6_llm_linear_forward(self, bias, leading_dims):
OC, IC = 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
fp6_linear = Fp6LlmLinear.from_float(linear)
assert (fp6_linear.bias is not None) == bias

x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
fp6_linear(x)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("bias", [False, True])
def test_fp6_llm_linear_compile(self, bias):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
fp6_linear = Fp6LlmLinear.from_float(linear)

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fp6_linear(x)
actual = torch.compile(fp6_linear)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_convert_fp6_llm(self):
device = "cuda"
model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).to(device)
convert_fp6_llm(model)

assert isinstance(model[0], Fp6LlmLinear)
assert model[0].bias is None
assert isinstance(model[1], Fp6LlmLinear)
assert model[1].bias is not None

x = torch.randn(4, 64, device=device)
model(x)


instantiate_parametrized_tests(TestFp6LlmLinear)


if __name__ == "__main__":
    run_tests()
4 changes: 2 additions & 2 deletions torchao/csrc/cuda/fp6_llm/fp6_linear.cu
@@ -144,8 +144,8 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats,
    int num_in_feats = _in_feats.size(0);
    int num_in_channels = _in_feats.size(1);
    int num_out_channels = _weights.size(0);
-   assert( num_in_channels%64 == 0 );
-   assert( (num_in_channels/16*3) == _weights.size(1) ); // Making sure the K dimension is matched.
+   TORCH_CHECK(num_in_channels%64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels);
+   TORCH_CHECK((num_in_channels/16*3) == _weights.size(1)); // Making sure the K dimension is matched.
    //
    int M = num_out_channels;
    int K = num_in_channels;
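As a side note, the K-dimension check above follows from the packing: 6-bit weights are stored in 32-bit words, so each row of K in_features occupies K*6/32 = K*3/16 ints (the same ratio Fp6LlmLinear uses below to recover in_features from the packed weight). A quick sanity check of that arithmetic, with a hypothetical K:

# hypothetical example verifying the packed-column arithmetic behind the TORCH_CHECK above
K = 4096                             # example in_features; must be a multiple of 64
packed_int32_per_row = K * 6 // 32   # 6 bits per weight, 32 bits per int
assert packed_int32_per_row == K // 16 * 3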
160 changes: 160 additions & 0 deletions torchao/quantization/fp6_llm.py
@@ -0,0 +1,160 @@
from typing import Optional

import torch
from torch import nn, Tensor
from torchao.dtypes.float6_e3m2 import FLOAT6_E3M2_MAX, to_float6_e3m2, from_float6_e3m2
from torchao.ops import fp16act_fp6weight_linear


def _pack_2bit(x: Tensor) -> Tensor:
    return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4]


def _unpack_2bit(x: Tensor) -> Tensor:
    return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2)


def _pack_4bit(x: Tensor) -> Tensor:
    return (x[..., ::2] << 4) | x[..., 1::2]


def _unpack_4bit(x: Tensor) -> Tensor:
    return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2)


# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing
# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h
def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor:
    assert tensor.ndim == 2
    M, N = tensor.shape
    assert (M % 64 == 0) and (N % 64 == 0)

    tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True)

    # Pass 1 from original code
    tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8)
    tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6)
    tensor_fp6 = tensor_fp6.reshape(-1, 32, 2)
    tensor_fp6 = tensor_fp6.permute(1, 0, 2)
    tensor_fp6 = tensor_fp6.flatten()

    tensor_2bit = _pack_2bit((tensor_fp6 >> 4) & 0b11)
    tensor_4bit = _pack_4bit(tensor_fp6 & 0b1111)

    # Pass 2 from original code
    tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2)
    tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2)

    # Pass 3 from original code
    # BitInterleaving_2bit
    # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32
    # while we still unpack/pack the values from/to uint8
    tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16)
    tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]]
    tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]]
    tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]]
    tensor_2bit = _pack_2bit(tensor_2bit).view(-1)

    # BitInterleaving_4bit
    # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32
    # while we still unpack/pack the values from/to uint8
    tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8)
    tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]]
    tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]]
    tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]]
    tensor_4bit = _pack_4bit(tensor_4bit).view(-1)

    return torch.cat([tensor_2bit, tensor_4bit], dim=0)


# more optimized version of _to_tc_float6_e3m2_original() by merging ops
# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h
def to_tc_float6_e3m2(tensor: Tensor) -> Tensor:
    assert tensor.ndim == 2
    M, N = tensor.shape
    assert (M % 64 == 0) and (N % 64 == 0)

    tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True)
    tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8)
    tensor_fp6 = tensor_fp6.flip(3)

    tensor_2bit = (tensor_fp6 >> 4) & 0b11
    tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6)
    tensor_2bit = _pack_2bit(tensor_2bit.flatten())

    tensor_4bit = tensor_fp6 & 0b1111
    tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6)
    tensor_4bit = _pack_4bit(tensor_4bit.flatten())

    return torch.cat([tensor_2bit, tensor_4bit], dim=0)


def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor:
    assert tensor.ndim == 1
    assert (M % 64 == 0) and (N % 64 == 0)
    size_2bit = M * N // 4
    size_4bit = M * N // 2
    assert tensor.numel() == size_2bit + size_4bit

    tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit])

    tensor_2bit = _unpack_2bit(tensor_2bit)
    tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2)
    tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4)

    tensor_4bit = _unpack_4bit(tensor_4bit)
    tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2)
    tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5)

    tensor_fp6 = (tensor_2bit << 4) | tensor_4bit
    tensor_fp6 = tensor_fp6.flip(3).reshape(M, N)
    return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype)


class Fp6LlmLinear(nn.Module):
"""FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112.
"""

def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None) -> None:
super().__init__()
self.register_buffer("weight", weight)
self.register_buffer("scales", scales)
self.register_buffer("bias", bias)
self.out_features = weight.shape[0]
self.in_features = weight.shape[1] * 16 // 3

def forward(self, x: Tensor) -> Tensor:
# TODO: splitK map
out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1)
if self.bias is not None:
out = out + self.bias
return out.view(*x.shape[:-1], self.out_features).to(x.dtype)

@classmethod
def from_float(cls, linear: nn.Linear):
assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0)

fp32_weight = linear.weight.detach().float()
scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX
scales[scales == 0.0] = 1.0 # avoid 0 scale

tc_fp6_weight = to_tc_float6_e3m2(fp32_weight / scales.view(-1, 1))
tc_fp6_weight = tc_fp6_weight.view(linear.out_features, -1).view(torch.int32)

bias = linear.bias.detach().half() if linear.bias is not None else None
return cls(tc_fp6_weight, scales.half(), bias)

def extra_repr(self) -> str:
return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'


def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[list[str]] = None, cur_fqn: str = "") -> None:
    for name, child in model.named_children():
        new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}"

        if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)):
            if (child.in_features % 64 == 0) and (child.out_features % 256 == 0):
                new_child = Fp6LlmLinear.from_float(child)
                setattr(model, name, new_child)
        else:
            convert_fp6_llm(child, skip_fqn_list, new_fqn)
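For reference, a small round-trip sketch of the weight-splitting helpers above, mirroring test_from_tc_float6_e3m2_correctness (plain PyTorch ops, so CPU works; shapes are example values):

import torch
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2

M, N = 256, 64                 # both dimensions must be multiples of 64
x = from_float6_e3m2(to_float6_e3m2(torch.randn(M, N)))  # quantize-dequantize so values are exactly representable in FP6
packed = to_tc_float6_e3m2(x)  # 1D uint8: M*N//4 bytes of 2-bit segments + M*N//2 bytes of 4-bit segments
assert packed.numel() == M * N * 3 // 4
torch.testing.assert_close(from_tc_float6_e3m2(packed, M, N), x)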
