Ruff auto-format

jainapurva committed Sep 6, 2024
1 parent 5091d35 commit 1aa8839
Showing 8 changed files with 675 additions and 253 deletions.
1 change: 1 addition & 0 deletions torchao/dtypes/__init__.py
@@ -1,4 +1,5 @@
from .nf4tensor import NF4Tensor, to_nf4

# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uint4 import UInt4Tensor
from .affine_quantized_tensor import (
578 changes: 411 additions & 167 deletions torchao/dtypes/affine_quantized_tensor.py

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion torchao/dtypes/fpx/__init__.py
@@ -1,4 +1,10 @@
from .fpx import FpxTensorCoreLayoutType, FpxTensorCoreAQTLayout, to_scaled_tc_fpx, from_scaled_tc_fpx, _SPLIT_K_MAP
from .fpx import (
FpxTensorCoreLayoutType,
FpxTensorCoreAQTLayout,
to_scaled_tc_fpx,
from_scaled_tc_fpx,
_SPLIT_K_MAP,
)

__all__ = [
"FpxTensorCoreAQTLayout",
176 changes: 142 additions & 34 deletions torchao/dtypes/fpx/fpx.py
@@ -4,7 +4,11 @@
import torch
from torch import Tensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones
from torchao.prototype.custom_fp_utils import (
_f32_to_fpx_unpacked,
_fpx_unpacked_to_f32,
_n_ones,
)
from torchao.dtypes.utils import (
LayoutType,
)
@@ -17,11 +21,23 @@


def _pack(x: Tensor, n_bits: int) -> Tensor:
return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])
return reduce(
torch.bitwise_or,
[
x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits)
for i in range(8 // n_bits)
],
)


def _unpack(x: Tensor, n_bits: int) -> Tensor:
return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)
return torch.stack(
[
(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1)
for i in range(8 // n_bits)
],
dim=-1,
).flatten(-2)
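
As a side note (not part of the diff): a minimal, self-contained sketch of what _pack/_unpack do for n_bits=2 — four 2-bit values per byte, with earlier elements landing in the higher-order bits. The names pack2/unpack2 are hypothetical stand-ins for the private helpers above.

import torch
from functools import reduce

def pack2(x):  # mirrors _pack(x, n_bits=2) from this file
    return reduce(torch.bitwise_or, [x[..., i::4] << (8 - (i + 1) * 2) for i in range(4)])

def unpack2(x):  # mirrors _unpack(x, n_bits=2) from this file
    return torch.stack([(x >> (8 - (i + 1) * 2)) & 0b11 for i in range(4)], dim=-1).flatten(-2)

vals = torch.tensor([[1, 2, 3, 0]], dtype=torch.uint8)
packed = pack2(vals)                       # 0b01_10_11_00 == tensor([[108]])
assert torch.equal(unpack2(packed), vals)  # lossless round trip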


# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
@@ -35,8 +51,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:

if not undo:
bit_order = {
1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
1: [
1,
5,
9,
13,
17,
21,
25,
29,
3,
7,
11,
15,
19,
23,
27,
31,
0,
4,
8,
12,
16,
20,
24,
28,
2,
6,
10,
14,
18,
22,
26,
30,
],
2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
4: [1, 5, 3, 7, 0, 4, 2, 6],
}[n_bits]
@@ -45,8 +93,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
# this is inverse of the above, obtained by running
# [v.index(i) for i in range(len(v))]
bit_order = {
1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
1: [
16,
0,
24,
8,
17,
1,
25,
9,
18,
2,
26,
10,
19,
3,
27,
11,
20,
4,
28,
12,
21,
5,
29,
13,
22,
6,
30,
14,
23,
7,
31,
15,
],
2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
4: [4, 0, 6, 2, 5, 1, 7, 3],
}[n_bits]
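
A quick standalone check (not part of the diff) that the "undo" tables are the inverse permutations of the forward tables, as the inline comment states — shown here for n_bits=2:

fwd = [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]            # forward order, n_bits=2
undo = [fwd.index(i) for i in range(len(fwd))]
assert undo == [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7]   # matches the n_bits=2 undo entry above
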
@@ -82,8 +162,12 @@ def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
tensor_ybit = _pack(tensor_ybit, y)

tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code
tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code
tensor_ybit = (
tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
) # Pass 2 from original code
tensor_ybit = _bit_interleave(
tensor_ybit.flatten(), y
) # Pass 3 from original code
fragments.append(tensor_ybit)
used_bits += y

@@ -125,7 +209,9 @@ def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Te

# workaround: global lookup table
exp_bias = _ONES_TABLE[ebits - 1]
max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (
_ONES_TABLE[mbits + 1] / (2**mbits)
)

tensor = tensor.float()
scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
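
For concreteness, a small worked check of the max_normal expression above (standalone arithmetic, not from the diff; it assumes _ONES_TABLE[n] == 2**n - 1, i.e. _n_ones(n)). The expression treats the all-ones exponent field as a normal value (no inf/nan encodings) and the largest mantissa as (2**(mbits + 1) - 1) / 2**mbits.

ebits, mbits = 3, 2                                   # fp6 e3m2
exp_bias = (1 << (ebits - 1)) - 1                     # 3
max_normal = 2 ** ((1 << ebits) - 1 - exp_bias) * (((1 << (mbits + 1)) - 1) / 2 ** mbits)
assert max_normal == 28.0                             # and 7.5 for ebits, mbits = 2, 3
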
@@ -151,8 +237,10 @@ def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
tensor_ybit = tensor[offset : offset + size_ybit]
offset += size_ybit

tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3
tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2
tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3
tensor_ybit = (
tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2)
) # undo Pass 2

tensor_ybit = _unpack(tensor_ybit.flatten(), y)
tensor_ybit = tensor_ybit << (nbits - used_bits - y)
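
For orientation (not part of the diff), a hedged round-trip sketch using the to_scaled_tc_fpx / from_scaled_tc_fpx pair from this file. It assumes to_scaled_tc_fpx returns (packed_weight, per-row scale), that both functions are re-exported from torchao.dtypes.fpx (as the __init__.py diff above suggests), and that the weight dimensions are multiples of 64 as the tensor-core packing expects.

import torch
from torchao.dtypes.fpx import to_scaled_tc_fpx, from_scaled_tc_fpx

w = torch.randn(256, 1024)                              # (out_features, in_features), multiples of 64 assumed
packed, scale = to_scaled_tc_fpx(w, ebits=3, mbits=2)   # quantize + pack to fp6 e3m2 tensor-core layout
w_hat = from_scaled_tc_fpx(packed, ebits=3, mbits=2, scale=scale)
print((w - w_hat).abs().max())                          # small fp6 quantization error
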
@@ -223,7 +311,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 5,
14336: 7,
28672: 7,
57344: 7
57344: 7,
},
{ # tokens: [65:128]
3072: 9,
@@ -234,7 +322,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 5,
14336: 7,
28672: 7,
57344: 6
57344: 6,
},
{ # tokens: [129:192]
3072: 6,
@@ -245,7 +333,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 5,
14336: 5,
28672: 5,
57344: 4
57344: 4,
},
{ # tokens: [193:256]
3072: 9,
@@ -256,7 +344,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 4,
14336: 8,
28672: 6,
57344: 4
57344: 4,
},
{ # tokens: [257:320]
3072: 7,
@@ -267,7 +355,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 1,
14336: 3,
28672: 3,
57344: 4
57344: 4,
},
{ # tokens: [321:384]
3072: 3,
@@ -278,7 +366,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 8,
14336: 3,
28672: 4,
57344: 3
57344: 3,
},
{ # tokens: [385:448]
3072: 5,
@@ -289,7 +377,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 3,
14336: 1,
28672: 1,
57344: 3
57344: 3,
},
{ # tokens: [449:512]
3072: 2,
@@ -300,7 +388,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 2,
14336: 6,
28672: 4,
57344: 1
57344: 1,
},
{ # tokens: [513:576]
3072: 2,
@@ -311,7 +399,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 3,
14336: 3,
28672: 1,
57344: 1
57344: 1,
},
{ # tokens: [577:640]
3072: 5,
@@ -322,7 +410,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 1,
14336: 1,
28672: 1,
57344: 1
57344: 1,
},
{ # tokens: [641:704]
3072: 3,
@@ -333,7 +421,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 2,
14336: 1,
28672: 1,
57344: 1
57344: 1,
},
{ # tokens: [705:768]
3072: 3,
@@ -344,20 +432,22 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
10240: 1,
14336: 1,
28672: 1,
57344: 1
}
57344: 1,
},
]
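
_SPLIT_K_MAP above buckets the token count into 64-token ranges (one dict per bucket) and maps a weight dimension (presumably the output-feature dimension) to a split-K factor for the fp6 GEMM kernel. The consumer of the table is outside the hunks shown here; the lookup below is a hedged sketch only — the helper name, the bucketing, and the fallback to 1 are assumptions, not torchao's actual code.

def _split_k_for(n_tokens: int, out_features: int) -> int:
    # assumed bucketing: [1:64] -> 0, [65:128] -> 1, ..., [705:768] -> 11; anything larger falls back to 1
    if not 1 <= n_tokens <= 768:
        return 1
    return _SPLIT_K_MAP[(n_tokens - 1) // 64].get(out_features, 1)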


# quantization api integrations


@dataclass(frozen=True)
class FpxTensorCoreLayoutType(LayoutType):
"""Layout type for FpxTensorCoreAQTLayout
"""
"""Layout type for FpxTensorCoreAQTLayout"""

ebits: int
mbits: int


@register_layout_cls(FpxTensorCoreLayoutType)
class FpxTensorCoreAQTLayout(AQTLayout):
"""FpxTensorCoreAQTLayout represents a Tensor with dtype fpx(ebits=a, mbits=b),
@@ -381,6 +471,7 @@ class FpxTensorCoreAQTLayout(AQTLayout):
it will then pack the weight and instantiate the FpxTensorCoreAQTLayout tensor
FpxTensorCoreAQTLayout.__init__() takes a packed fpx Tensor of shape (M, N // 8 * nbit)
"""

def __new__(
cls,
packed_fpx_data: torch.Tensor,
@@ -389,11 +480,16 @@ def __new__(
):
assert packed_fpx_data.ndim == 2
assert packed_fpx_data.dtype == torch.uint8
shape = (packed_fpx_data.shape[0], packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8)
shape = (
packed_fpx_data.shape[0],
packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8,
)
kwargs = {}
kwargs["device"] = packed_fpx_data.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else packed_fpx_data.layout
kwargs.get("layout")
if kwargs.get("layout", False)
else packed_fpx_data.layout
)
kwargs["dtype"] = packed_fpx_data.dtype
kwargs["requires_grad"] = False
@@ -416,12 +512,17 @@ def __tensor_flatten__(self):
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_fpx_data, scale = tensor_data_dict["packed_fpx_data"], tensor_data_dict["scale"]
layout_type, = tensor_attributes
packed_fpx_data, scale = (
tensor_data_dict["packed_fpx_data"],
tensor_data_dict["scale"],
)
(layout_type,) = tensor_attributes
return cls(packed_fpx_data, scale, layout_type)

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
unpacked_fpx_data = unpack_tc_fpx(self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits)
unpacked_fpx_data = unpack_tc_fpx(
self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits
)
return unpacked_fpx_data, self.scale

@classmethod
@@ -440,7 +541,9 @@ def from_plain(
bit, M is mantissa bit
"""
assert isinstance(layout_type, FpxTensorCoreLayoutType)
packed_fpx_data = pack_tc_fpx(unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits)
packed_fpx_data = pack_tc_fpx(
unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits
)
return cls(packed_fpx_data, scale, layout_type)

def __repr__(self):
@@ -478,7 +581,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
)
elif func is aten._to_copy.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
func,
args,
kwargs,
args[0]._apply_fn_to_data(
lambda x: x.to(device=kwargs.pop("device", None))
),
)

raise NotImplementedError(
1 change: 0 additions & 1 deletion torchao/dtypes/uint4.py
@@ -105,7 +105,6 @@ def __new__(cls, elem, **kwargs):
)

def __init__(self, elem, **kwargs):

self.elem = elem

@classmethod