From 6927d289bd9b27e0b02f4a8962353834063c07b0 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Sun, 26 May 2024 22:44:57 +0200 Subject: [PATCH 01/18] Added first bits of Uint2Tensor and BitnetTensor Co-authored-by: James Melvin Ebenezer --- test/dtypes/test_uint2.py | 47 ++++++ torchao/dtypes/__init__.py | 3 + torchao/dtypes/uint2.py | 316 +++++++++++++++++++++++++++++++++++++ 3 files changed, 366 insertions(+) create mode 100644 test/dtypes/test_uint2.py create mode 100644 torchao/dtypes/uint2.py diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py new file mode 100644 index 0000000000..ebc324caf8 --- /dev/null +++ b/test/dtypes/test_uint2.py @@ -0,0 +1,47 @@ +from unittest import main + +import torch +import torch.nn as nn + +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + QuantizationTestCase, +) + + +from torchao.dtypes.uint2 import ( + UInt2Tensor, + BitnetTensor +) +from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, +) + +def _apply_weight_only_uint2_quant(model): + def fn(mod): + mod.weight = torch.nn.Parameter(BitnetTensor.from_float(mod.weight), requires_grad=False) + return mod + + _replace_with_custom_fn_if_matches_filter( + model, + lambda mod: fn(mod), + lambda mod, fqn: isinstance(mod, torch.nn.Linear), + ) + + +class TestUInt2(QuantizationTestCase): + def test_gpu_quant(self): + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + x = torch.randn(*x_shape) + m = nn.Sequential(nn.Linear(4, 16)) + y_ref = m(x) + _apply_weight_only_uint2_quant(m) + y_wo = m(x) + # sqnr = compute_error(y_ref, y_wo) + #opt = torch.compile(m, fullgraph=True, mode="max-autotune") + # make sure it runs + #opt(x) + + +if __name__ == "__main__": + main() diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 44077dab65..23c0e824b7 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,10 +1,13 @@ from .nf4tensor import NF4Tensor, to_nf4 +from .uint2 import UInt2Tensor, BitnetTensor from .uint4 import UInt4Tensor from .aqt import AffineQuantizedTensor, to_aq __all__ = [ "NF4Tensor", "to_nf4", + "UInt2Tensor", + "BitnetTensor", "UInt4Tensor" "AffineQuantizedTensor", "to_aq", diff --git a/torchao/dtypes/uint2.py b/torchao/dtypes/uint2.py new file mode 100644 index 0000000000..78a649c20a --- /dev/null +++ b/torchao/dtypes/uint2.py @@ -0,0 +1,316 @@ +import torch +import torch._prims_common as utils +import torch.utils._pytree as pytree +from torch.library import impl, Library +from .uint4 import qtensor_lib + + +def down_size(size): + assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" + return (*size[:-1], size[-1] // 4) + +def up_size(size): + return (*size[:-1], size[-1] * 4) + +#@torch.compile +def unpack_uint8_to_trinary2(uint8_data: torch.Tensor) -> torch.Tensor: + # since we are using uint8 we will decode 4 entries per byte + shape = uint8_data.shape + first_elements = ((uint8_data >> 6) & 0b11).to(torch.int8) - 1 + second_elements = ((uint8_data >> 4) & 0b11).to(torch.int8) - 1 + third_elements = ((uint8_data >> 2) & 0b11).to(torch.int8) - 1 + fourth_elements = (uint8_data & 0b11).to(torch.int8) - 1 + return torch.stack([first_elements, second_elements, third_elements, fourth_elements], dim=-1).view(up_size(shape)) + +#@torch.compile +def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + # since we are using uint8 we will decode 4 entries per byte + shape = uint8_data.shape + uint8_data = uint8_data.to(torch.uint8) + first_elements = ((uint8_data >> 
6) & 0b11) + second_elements = ((uint8_data >> 4) & 0b11) + third_elements = ((uint8_data >> 2) & 0b11) + fourth_elements = (uint8_data & 0b11) + return torch.stack((first_elements, second_elements, third_elements, fourth_elements), dim=-1).view(up_size(shape)) + +#packing uint8 +#@torch.compile +def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + # converting to uint8 for operations + shape = uint8_data.shape + assert shape[-1] % 4 == 0 + uint8_data = uint8_data.contiguous().view(-1) + packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) + return packed_data + + +def fill_defaults(args, n, defaults_tail): + """ + __torch_dispatch__ doesn't guarantee the number of arguments you are + passed (e.g., defaulted arguments are not passed); but usually it is + convenient to pad out the arguments list with defaults. This function + helps you do that. + Args: + args: the list of positional arguments passed to __torch_dispatch__ + n: the number of arguments you are expecting to get + defaults_tail: default values for the arguments, starting from the + end of the list + Example: + >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) + [1, 2, 3, 4, 5] + >>> fill_defaults([1, 2, 3], 5, [None, None, None]) + [1, 2, 3, None, None]] + """ + if n - len(defaults_tail) > len(args): + raise RuntimeError("not enough defaults to fill arguments") + r = list(args) + for i in range(len(args), n): + r.append(defaults_tail[i - n + len(defaults_tail)]) + return r + + +#qtensor_lib = Library("qtensors", "DEF") +qtensor_lib.define( + "quantize_per_tensor_uint2(Tensor input, float scale, int zero_point) -> Tensor" +) + + +@impl(qtensor_lib, "quantize_per_tensor_uint2", "CompositeExplicitAutograd") +def quantize_per_tensor_uint2( + input: torch.Tensor, + scale: float = 1.0, + zero_point: int = 1, +) -> torch.Tensor: + inv_scale = 1.0 / scale + return pack_uint2( + torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 2).to(torch.uint8) + ) + + +qtensor_lib.define( + "dequantize_per_tensor_uint2(Tensor input, float scale, int zero_point) -> Tensor" +) + + +@impl(qtensor_lib, "dequantize_per_tensor_uint2", "CompositeExplicitAutograd") +def dequantize_per_tensor_uint2( + input: torch.Tensor, + scale: float = 1.0, + zero_point: int = 1, +) -> torch.Tensor: + input = unpack_uint2(input) + return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale + + +class UInt2Tensor(torch.Tensor): + @staticmethod + def __new__(cls, elem, **kwargs): + assert elem.dtype is torch.uint8 + assert not kwargs.get("requires_grad", False) + kwargs["requires_grad"] = False + + return torch.Tensor._make_wrapper_subclass( + cls, up_size(elem.shape), dtype=torch.uint2, **kwargs + ) + + def __init__(self, elem, **kwargs): + self.elem = elem + + @classmethod + def from_unpacked(cls, unpacked): + return UInt2Tensor(pack_uint2(unpacked)) + + def tolist(self): + return self.to(torch.uint8).tolist() + + def __tensor_flatten__(self): + return ["elem"], None + + @staticmethod + def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): + assert meta is None + elem = flattened["elem"] + return UInt2Tensor(elem) + + def __hash__(self): + return hash(self.elem) + + def __eq__(self, other): + return torch.equal(self.elem, other.elem) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func is torch.ops.aten.view.default: + self, size = args + size = utils.infer_size(size, self.numel()) + assert not kwargs + # WARNING: views not 
preserved + return UInt2Tensor(self.elem.reshape(down_size(size))) + elif func is torch.ops.aten.view.dtype: + self, dtype = args + if dtype == torch.uint8: + return unpack_uint2(self.elem).view(torch.uint8) + return NotImplementedError(f"view {args}") + elif func is torch.ops.aten.to.dtype: + self, dtype = args + if dtype == torch.uint8: + return unpack_uint2(self.elem).view(torch.uint8) + return NotImplementedError(f"to {args}") + elif func is torch.ops.aten.eq.Tensor: + args = pytree.tree_map_only( + UInt2Tensor, lambda x: x.elem.view(torch.uint8), args + ) + kwargs = pytree.tree_map_only( + UInt2Tensor, lambda x: x.elem.view(torch.uint8), kwargs + ) + return torch.ops.aten.eq.Tensor(*args, **kwargs) + elif func is torch.ops.aten._to_copy.default: + (self,) = args + if kwargs == {"dtype": torch.uint8}: + return unpack_uint2(self.elem).view(self.shape) # no wrap + else: + raise NotImplementedError(f"_to_copy {kwargs}") + elif func is torch.ops.aten.unbind.int: + # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to + # create four tensors containing one element each. But we can't + # do this with uint2 because such a tensor's size is not divisible + # by bytes. What I am going to do instead is promote to uint8 + # when this happens + self, dim = fill_defaults(args, 2, [0]) + if dim != self.dim() - 1: + raise NotImplementedError(f"unbind dim={dim}") + else: + # We're unbinding the last dimension, need to promote + return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind( + dim + ) + elif func is torch.ops.aten.select.int: + self, dim, index = args + if dim != self.dim() - 1: + return UInt2Tensor(torch.ops.aten.select.int(self.elem, dim, index)) + else: + raise NotImplementedError(f"select dim={dim}") + elif func is torch.ops.aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == self.dim() - 1: + # hard case + if step != 1: + raise NotImplementedError(f"slice step={step}") + assert start % 4 == 0, start + assert end >= self.shape[dim] or end % 4 == 0, end + return UInt2Tensor( + torch.ops.aten.slice.Tensor(self.elem, dim, start // 4, end // 4, 1) + ) + else: + # easy case + return UInt2Tensor( + torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step) + ) + elif func is torch.ops.aten.t.default: + # assert False, "transpose is not properly implemented currently" + (self,) = args + unpacked = unpack_uint2(self.elem) + transposed = torch.ops.aten.t.default(unpacked) + transposed_and_packed = pack_uint2(transposed) + return UInt2Tensor(transposed_and_packed) + elif func is torch.ops.aten.transpose_copy.int: + self, dim0, dim1 = args + unpacked = unpack_uint2(self.elem).view(self.shape) + transposed = torch.ops.aten.transpose_copy.int(unpacked, dim0, dim1) + transposed_and_packed = pack_uint2(transposed) + return UInt2Tensor(transposed_and_packed) + elif func is torch.ops.aten.as_strided.default: + # size, stride, storage_offset are referring to tensor elements, not physical bytes + self, size, stride, storage_offset = args + size = down_size(size) + + new_stride = [] + for s in stride: + if s != 1: + # since two int4 equals to 1 uint8 + new_stride.append(s // 4) + else: + new_stride.append(s) + stride = new_stride + + storage_offset //= 4 + return UInt2Tensor( + torch.ops.aten.as_strided.default( + self.elem, size, stride, storage_offset + ) + ) + + raise NotImplementedError(f"{func}") + + __torch_function__ = torch._C._disabled_torch_function_impl + + +def _quantize_int2(x: torch.Tensor, target_dtype: torch.dtype) 
-> torch.Tensor: + quant = x.sign() + 1 + + if target_dtype == torch.uint2: + quant = BitnetTensor.from_unpacked( + quant.to(torch.uint8), + ) + else: + quant = quant.to(target_dtype) + + return quant + + +class BitnetTensor(UInt2Tensor): + @staticmethod + def __new__(cls, elem, **kwargs): + return super().__new__(cls, elem, **kwargs) + + def __init__(self, elem, **kwargs): + super().__init__(elem, **kwargs) + + def __tensor_flatten__(self): + return ["elem"], None + + @staticmethod + def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): + assert meta is None + elem = flattened["elem"] + return BitnetTensor(elem) + + @classmethod + # inconsistently. + def from_unpacked(cls, unpacked: torch.Tensor) -> "BitnetTensor": + return cls(pack_uint2(unpacked)) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func is torch.ops.aten.mm.default: + x, weight = args + y = torch.mm(x, weight.to(torch.uint8).to(x.dtype)) + return y + elif func is torch.ops.aten.addmm.default: + bias, x, weight = args + #x_view = x.view(-1, x.shape[-1]) # not clear why + x_view = x + y = torch.mm(x_view, weight.to(torch.uint8).to(x.dtype)) + #y = y.reshape(*x.shape[:-1], -1) + if bias is not None: + y += bias + return y + elif func is torch.ops.aten.t.default: + # TODO: add proper support for transpose + (self,) = args + unpacked = unpack_uint2(self.elem) + transposed = torch.ops.aten.t.default(unpacked) + return BitnetTensor.from_unpacked( + transposed + ) + elif func is torch.ops.aten.detach.default: + (self,) = args + return self + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def from_float(cls, w: torch.Tensor): + w_int2 = _quantize_int2( + w, torch.uint2 + ).to(device=w.device) + return w_int2 From 0d85b0600f7639d8e058655d0473dcbd1bbf8659 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Mon, 27 May 2024 01:13:39 +0200 Subject: [PATCH 02/18] add conversion to standard signed and unsigned dtypes --- torchao/dtypes/uint2.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/torchao/dtypes/uint2.py b/torchao/dtypes/uint2.py index 78a649c20a..4c6d81f8c5 100644 --- a/torchao/dtypes/uint2.py +++ b/torchao/dtypes/uint2.py @@ -166,8 +166,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return torch.ops.aten.eq.Tensor(*args, **kwargs) elif func is torch.ops.aten._to_copy.default: (self,) = args - if kwargs == {"dtype": torch.uint8}: - return unpack_uint2(self.elem).view(self.shape) # no wrap + dtype = kwargs["dtype"] + if dtype == torch.uint8: + return unpack_uint2(self.elem).view(self.shape) + if dtype in (torch.uint16, torch.uint32, torch.uint64): + return self.to(torch.uint8).to(dtype) else: raise NotImplementedError(f"_to_copy {kwargs}") elif func is torch.ops.aten.unbind.int: @@ -284,14 +287,11 @@ def from_unpacked(cls, unpacked: torch.Tensor) -> "BitnetTensor": def __torch_dispatch__(cls, func, types, args, kwargs=None): if func is torch.ops.aten.mm.default: x, weight = args - y = torch.mm(x, weight.to(torch.uint8).to(x.dtype)) + y = torch.mm(x, weight.to(torch.int8).to(x.dtype)) return y elif func is torch.ops.aten.addmm.default: bias, x, weight = args - #x_view = x.view(-1, x.shape[-1]) # not clear why - x_view = x - y = torch.mm(x_view, weight.to(torch.uint8).to(x.dtype)) - #y = y.reshape(*x.shape[:-1], -1) + y = torch.mm(x, weight.to(torch.int8).to(x.dtype)) if bias is not None: y += bias return y @@ -306,6 +306,17 @@ def __torch_dispatch__(cls, func, types, args, 
kwargs=None): elif func is torch.ops.aten.detach.default: (self,) = args return self + elif func is torch.ops.aten.to.dtype: + self, dtype = args + if dtype == torch.int8: + return unpack_uint2(self.elem).view(torch.int8) - 1 + elif func is torch.ops.aten._to_copy.default: + (self,) = args + dtype = kwargs["dtype"] + if dtype == torch.int8: + return unpack_uint2(self.elem).view(self.shape).view(torch.int8) - 1 + elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): + return self.to(torch.int8).to(dtype) return super().__torch_dispatch__(func, types, args, kwargs) @classmethod From f64e4575f9945c8915a80775f6ad3faf7f5c526a Mon Sep 17 00:00:00 2001 From: James Melvin Date: Tue, 28 May 2024 12:59:38 +0530 Subject: [PATCH 03/18] added triton kernel for pack and unpack --- torchao/dtypes/uint2.py | 143 ++++++++++++++++++++++++++++++---------- 1 file changed, 110 insertions(+), 33 deletions(-) diff --git a/torchao/dtypes/uint2.py b/torchao/dtypes/uint2.py index 4c6d81f8c5..696ad7d443 100644 --- a/torchao/dtypes/uint2.py +++ b/torchao/dtypes/uint2.py @@ -12,36 +12,114 @@ def down_size(size): def up_size(size): return (*size[:-1], size[-1] * 4) -#@torch.compile -def unpack_uint8_to_trinary2(uint8_data: torch.Tensor) -> torch.Tensor: - # since we are using uint8 we will decode 4 entries per byte - shape = uint8_data.shape - first_elements = ((uint8_data >> 6) & 0b11).to(torch.int8) - 1 - second_elements = ((uint8_data >> 4) & 0b11).to(torch.int8) - 1 - third_elements = ((uint8_data >> 2) & 0b11).to(torch.int8) - 1 - fourth_elements = (uint8_data & 0b11).to(torch.int8) - 1 - return torch.stack([first_elements, second_elements, third_elements, fourth_elements], dim=-1).view(up_size(shape)) - -#@torch.compile -def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - # since we are using uint8 we will decode 4 entries per byte - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - first_elements = ((uint8_data >> 6) & 0b11) - second_elements = ((uint8_data >> 4) & 0b11) - third_elements = ((uint8_data >> 2) & 0b11) - fourth_elements = (uint8_data & 0b11) - return torch.stack((first_elements, second_elements, third_elements, fourth_elements), dim=-1).view(up_size(shape)) - -#packing uint8 -#@torch.compile -def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - # converting to uint8 for operations - shape = uint8_data.shape - assert shape[-1] % 4 == 0 - uint8_data = uint8_data.contiguous().view(-1) - packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) - return packed_data +if torch.cuda.is_available() and torch.utils._triton.has_triton(): + import triton + import triton.language as tl + + @triton.jit + def triton_unpack_uint8_to_trinary2(uint8_data, output, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + uint8_value = tl.load(uint8_data + offsets, mask=mask) + + first_elements = ((uint8_value >> 6) & 0b11).to(tl.int8) - 1 + second_elements = ((uint8_value >> 4) & 0b11).to(tl.int8) - 1 + third_elements = ((uint8_value >> 2) & 0b11).to(tl.int8) - 1 + fourth_elements = (uint8_value & 0b11).to(tl.int8) - 1 + + tl.store(output + offsets * 4 + 0, first_elements, mask=mask) + tl.store(output + offsets * 4 + 1, second_elements, mask=mask) + tl.store(output + offsets * 4 + 2, third_elements, mask=mask) + tl.store(output + offsets * 4 + 3, fourth_elements, mask=mask) + + def unpack_uint8_to_trinary2(uint8_data: 
torch.Tensor) -> torch.Tensor: + shape = uint8_data.shape + output = torch.empty(up_size(shape), dtype=torch.int8, device=uint8_data.device) + n_elements = uint8_data.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + triton_unpack_uint8_to_trinary2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) + return output + + @triton.jit + def triton_unpack_uint2(uint8_data, output, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + uint8_value = tl.load(uint8_data + offsets, mask=mask) + + first_elements = (uint8_value >> 6) & 0b11 + second_elements = (uint8_value >> 4) & 0b11 + third_elements = (uint8_value >> 2) & 0b11 + fourth_elements = uint8_value & 0b11 + + tl.store(output + offsets * 4 + 0, first_elements, mask=mask) + tl.store(output + offsets * 4 + 1, second_elements, mask=mask) + tl.store(output + offsets * 4 + 2, third_elements, mask=mask) + tl.store(output + offsets * 4 + 3, fourth_elements, mask=mask) + + def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + shape = uint8_data.shape + output = torch.empty(up_size(shape), dtype=torch.uint8, device=uint8_data.device) + n_elements = uint8_data.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + triton_unpack_uint2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) + return output + + @triton.jit + def triton_pack_uint2(uint8_data, output, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets * 4 < n_elements + + first_elements = tl.load(uint8_data + offsets * 4 + 0, mask=mask) + second_elements = tl.load(uint8_data + offsets * 4 + 1, mask=mask) + third_elements = tl.load(uint8_data + offsets * 4 + 2, mask=mask) + fourth_elements = tl.load(uint8_data + offsets * 4 + 3, mask=mask) + + packed_data = (first_elements << 6) | (second_elements << 4) | (third_elements << 2) | fourth_elements + + tl.store(output + offsets, packed_data, mask=mask) + + def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + shape = uint8_data.shape + assert shape[-1] % 4 == 0 + n_elements = uint8_data.numel() + packed_shape = down_size(shape) + output = torch.empty(packed_shape, dtype=torch.uint8, device=uint8_data.device) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE'] * 4),) + triton_pack_uint2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) + return output + +else: + #@torch.compile + def unpack_uint8_to_trinary2(uint8_data: torch.Tensor) -> torch.Tensor: + # since we are using uint8 we will decode 4 entries per byte + shape = uint8_data.shape + first_elements = ((uint8_data >> 6) & 0b11).to(torch.int8) - 1 + second_elements = ((uint8_data >> 4) & 0b11).to(torch.int8) - 1 + third_elements = ((uint8_data >> 2) & 0b11).to(torch.int8) - 1 + fourth_elements = (uint8_data & 0b11).to(torch.int8) - 1 + return torch.stack([first_elements, second_elements, third_elements, fourth_elements], dim=-1).view(up_size(shape)) + + #@torch.compile + def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + # since we are using uint8 we will decode 4 entries per byte + shape = uint8_data.shape + uint8_data = uint8_data.to(torch.uint8) + first_elements = ((uint8_data >> 6) & 0b11) + second_elements = ((uint8_data >> 4) & 0b11) + third_elements = ((uint8_data >> 2) & 0b11) + fourth_elements = (uint8_data & 0b11) + return torch.stack((first_elements, second_elements, third_elements, fourth_elements), dim=-1).view(up_size(shape)) + + #packing uint8 + #@torch.compile + def 
pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + shape = uint8_data.shape + assert shape[-1] % 4 == 0 + uint8_data = uint8_data.contiguous().view(-1) + packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) + return packed_data def fill_defaults(args, n, defaults_tail): @@ -300,9 +378,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): (self,) = args unpacked = unpack_uint2(self.elem) transposed = torch.ops.aten.t.default(unpacked) - return BitnetTensor.from_unpacked( - transposed - ) + return BitnetTensor.from_unpacked(transposed) elif func is torch.ops.aten.detach.default: (self,) = args return self @@ -325,3 +401,4 @@ def from_float(cls, w: torch.Tensor): w, torch.uint2 ).to(device=w.device) return w_int2 + From 8b14c178033acecb9c2423b490f58c4c6381be04 Mon Sep 17 00:00:00 2001 From: James Melvin Priyarajan Date: Tue, 28 May 2024 08:54:27 +0000 Subject: [PATCH 04/18] fix: test cases and device allocation for triton kernels --- test/dtypes/test_uint2.py | 12 +++++------- torchao/dtypes/uint2.py | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index ebc324caf8..d03d23abc8 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -4,13 +4,10 @@ import torch.nn as nn from torch.testing._internal.common_quantization import ( - NodeSpec as ns, QuantizationTestCase, ) - from torchao.dtypes.uint2 import ( - UInt2Tensor, BitnetTensor ) from torchao.quantization.quant_api import ( @@ -31,16 +28,17 @@ def fn(mod): class TestUInt2(QuantizationTestCase): def test_gpu_quant(self): + device = 'cuda' if torch.cuda.is_available() else 'cpu' for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: - x = torch.randn(*x_shape) - m = nn.Sequential(nn.Linear(4, 16)) + x = torch.randn(*x_shape).to(device) + m = nn.Sequential(nn.Linear(4, 16)).to(device) y_ref = m(x) _apply_weight_only_uint2_quant(m) y_wo = m(x) # sqnr = compute_error(y_ref, y_wo) - #opt = torch.compile(m, fullgraph=True, mode="max-autotune") + # opt = torch.compile(m, fullgraph=True, mode="max-autotune") # make sure it runs - #opt(x) + # opt(x) if __name__ == "__main__": diff --git a/torchao/dtypes/uint2.py b/torchao/dtypes/uint2.py index 696ad7d443..e2b72dcdbe 100644 --- a/torchao/dtypes/uint2.py +++ b/torchao/dtypes/uint2.py @@ -34,8 +34,9 @@ def triton_unpack_uint8_to_trinary2(uint8_data, output, n_elements, BLOCK_SIZE: tl.store(output + offsets * 4 + 3, fourth_elements, mask=mask) def unpack_uint8_to_trinary2(uint8_data: torch.Tensor) -> torch.Tensor: + uint8_data = uint8_data.to('cuda') shape = uint8_data.shape - output = torch.empty(up_size(shape), dtype=torch.int8, device=uint8_data.device) + output = torch.empty(up_size(shape), dtype=torch.int8, device='cuda') n_elements = uint8_data.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) triton_unpack_uint8_to_trinary2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) @@ -59,8 +60,9 @@ def triton_unpack_uint2(uint8_data, output, n_elements, BLOCK_SIZE: tl.constexpr tl.store(output + offsets * 4 + 3, fourth_elements, mask=mask) def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + uint8_data = uint8_data.to('cuda') shape = uint8_data.shape - output = torch.empty(up_size(shape), dtype=torch.uint8, device=uint8_data.device) + output = torch.empty(up_size(shape), dtype=torch.uint8, device='cuda') n_elements = uint8_data.numel() grid = lambda meta: 
(triton.cdiv(n_elements, meta['BLOCK_SIZE']),) triton_unpack_uint2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) @@ -81,11 +83,12 @@ def triton_pack_uint2(uint8_data, output, n_elements, BLOCK_SIZE: tl.constexpr): tl.store(output + offsets, packed_data, mask=mask) def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + uint8_data = uint8_data.to('cuda') shape = uint8_data.shape assert shape[-1] % 4 == 0 n_elements = uint8_data.numel() packed_shape = down_size(shape) - output = torch.empty(packed_shape, dtype=torch.uint8, device=uint8_data.device) + output = torch.empty(packed_shape, dtype=torch.uint8, device='cuda') grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE'] * 4),) triton_pack_uint2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) return output @@ -249,6 +252,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return unpack_uint2(self.elem).view(self.shape) if dtype in (torch.uint16, torch.uint32, torch.uint64): return self.to(torch.uint8).to(dtype) + if dtype == torch.uint2: + return self else: raise NotImplementedError(f"_to_copy {kwargs}") elif func is torch.ops.aten.unbind.int: @@ -365,18 +370,18 @@ def from_unpacked(cls, unpacked: torch.Tensor) -> "BitnetTensor": def __torch_dispatch__(cls, func, types, args, kwargs=None): if func is torch.ops.aten.mm.default: x, weight = args - y = torch.mm(x, weight.to(torch.int8).to(x.dtype)) + y = torch.mm(x, weight.to(torch.int8).to(x.device).to(x.dtype)) return y elif func is torch.ops.aten.addmm.default: bias, x, weight = args - y = torch.mm(x, weight.to(torch.int8).to(x.dtype)) + y = torch.mm(x, weight.to(torch.int8).to(x.device).to(x.dtype)) if bias is not None: y += bias return y elif func is torch.ops.aten.t.default: # TODO: add proper support for transpose (self,) = args - unpacked = unpack_uint2(self.elem) + unpacked = unpack_uint2(self.elem).to(self.device) transposed = torch.ops.aten.t.default(unpacked) return BitnetTensor.from_unpacked(transposed) elif func is torch.ops.aten.detach.default: @@ -393,6 +398,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return unpack_uint2(self.elem).view(self.shape).view(torch.int8) - 1 elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): return self.to(torch.int8).to(dtype) + elif dtype == torch.uint2: + return self return super().__torch_dispatch__(func, types, args, kwargs) @classmethod From fcd7c08d8a1462c5553c4aef6bbbfbc4832ca42a Mon Sep 17 00:00:00 2001 From: James Melvin Priyarajan Date: Sat, 1 Jun 2024 12:44:02 +0000 Subject: [PATCH 05/18] fix: moved uint2 to prototype folder --- test/dtypes/test_uint2.py | 21 +++------------------ torchao/dtypes/__init__.py | 4 +--- torchao/{ => prototype}/dtypes/uint2.py | 5 +++-- 3 files changed, 7 insertions(+), 23 deletions(-) rename torchao/{ => prototype}/dtypes/uint2.py (98%) diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index d03d23abc8..ba2691a4af 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -1,18 +1,9 @@ from unittest import main - import torch import torch.nn as nn - -from torch.testing._internal.common_quantization import ( - QuantizationTestCase, -) - -from torchao.dtypes.uint2 import ( - BitnetTensor -) -from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, -) +from torch.testing._internal.common_quantization import QuantizationTestCase +from torchao.prototype.dtypes.uint2 import BitnetTensor +from torchao.quantization.quant_api import 
_replace_with_custom_fn_if_matches_filter def _apply_weight_only_uint2_quant(model): def fn(mod): @@ -25,7 +16,6 @@ def fn(mod): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) - class TestUInt2(QuantizationTestCase): def test_gpu_quant(self): device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -35,11 +25,6 @@ def test_gpu_quant(self): y_ref = m(x) _apply_weight_only_uint2_quant(m) y_wo = m(x) - # sqnr = compute_error(y_ref, y_wo) - # opt = torch.compile(m, fullgraph=True, mode="max-autotune") - # make sure it runs - # opt(x) - if __name__ == "__main__": main() diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 23c0e824b7..e0887b71ff 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,13 +1,11 @@ from .nf4tensor import NF4Tensor, to_nf4 -from .uint2 import UInt2Tensor, BitnetTensor +# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor from .uint4 import UInt4Tensor from .aqt import AffineQuantizedTensor, to_aq __all__ = [ "NF4Tensor", "to_nf4", - "UInt2Tensor", - "BitnetTensor", "UInt4Tensor" "AffineQuantizedTensor", "to_aq", diff --git a/torchao/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py similarity index 98% rename from torchao/dtypes/uint2.py rename to torchao/prototype/dtypes/uint2.py index e2b72dcdbe..f1be0cbce5 100644 --- a/torchao/dtypes/uint2.py +++ b/torchao/prototype/dtypes/uint2.py @@ -2,7 +2,7 @@ import torch._prims_common as utils import torch.utils._pytree as pytree from torch.library import impl, Library -from .uint4 import qtensor_lib +from ...dtypes.uint4 import qtensor_lib def down_size(size): @@ -94,6 +94,7 @@ def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: return output else: + # TODO: torch compile issue https://github.com/pytorch/pytorch/issues/127374 is fixed #@torch.compile def unpack_uint8_to_trinary2(uint8_data: torch.Tensor) -> torch.Tensor: # since we are using uint8 we will decode 4 entries per byte @@ -150,7 +151,7 @@ def fill_defaults(args, n, defaults_tail): return r -#qtensor_lib = Library("qtensors", "DEF") +# qtensor_lib = Library("qtensors", "DEF") qtensor_lib.define( "quantize_per_tensor_uint2(Tensor input, float scale, int zero_point) -> Tensor" ) From 30d95a185702cf8bd5cab262732eb0a4315a3818 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Sat, 1 Jun 2024 20:12:53 +0200 Subject: [PATCH 06/18] Add packing and unpacking functions for uint{2,3,4,5,6,7}. 
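
A minimal round-trip sketch of the new helpers (assuming the module is
importable as torchao.prototype.dtypes.uint_small at this point; the file is
renamed to uintgen.py in a later commit):

    import torch
    from torchao.prototype.dtypes.uint_small import pack_uint3, unpack_uint3

    x = torch.randint(0, 8, (4, 8), dtype=torch.uint8)  # 3-bit values, last dim divisible by 8
    packed = pack_uint3(x)                               # shape (4, 3): eight 3-bit values -> three bytes
    assert torch.equal(unpack_uint3(packed), x)          # lossless round trip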
Co-authored-by: James Melvin Ebenezer --- torchao/prototype/dtypes/uint_small.py | 359 +++++++++++++++++++++++++ 1 file changed, 359 insertions(+) create mode 100644 torchao/prototype/dtypes/uint_small.py diff --git a/torchao/prototype/dtypes/uint_small.py b/torchao/prototype/dtypes/uint_small.py new file mode 100644 index 0000000000..b9e3ca88e6 --- /dev/null +++ b/torchao/prototype/dtypes/uint_small.py @@ -0,0 +1,359 @@ +import torch + + +def down_size_uint2(size): + assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" + return (*size[:-1], size[-1] // 4) + + +def up_size_uint2(size): + return (*size[:-1], size[-1] * 4) + + +def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + # since we are using uint8 we will decode 4 entries per byte + shape = uint8_data.shape + uint8_data = uint8_data.to(torch.uint8) + first_elements = (uint8_data >> 6) & 0b11 + second_elements = (uint8_data >> 4) & 0b11 + third_elements = (uint8_data >> 2) & 0b11 + fourth_elements = uint8_data & 0b11 + return torch.stack( + (first_elements, second_elements, third_elements, fourth_elements), dim=-1 + ).view(up_size_uint2(shape)) + + +def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + """pack lowest 2 bits of 2 uint8 -> 1 uint8""" + shape = uint8_data.shape + assert shape[-1] % 4 == 0 + uint8_data = uint8_data.contiguous().view(-1) + packed_data = ( + (uint8_data[::4] & 0b11) << 6 + | (uint8_data[1::4] & 0b11) << 4 + | (uint8_data[2::4] & 0b11) << 2 + | (uint8_data[3::4] & 0b11) + ).view(down_size_uint2(shape)) + return packed_data + + +def down_size_uint3(size): + assert size[-1] % 8 == 0, f"{size} last dim not divisible by eight" + return (*size[:-1], size[-1] // 8 * 3) + + +def up_size_uint3(size): + assert size[-1] % 3 == 0, f"{size} last dim not divisible by three" + return (*size[:-1], size[-1] // 3 * 8) + + +def unpack_uint3(uint8_data: torch.Tensor) -> torch.Tensor: + """ + 3 -> 8 + 01234567|01234567|01234567 + AAABBBCC|CDDDEEEF|FFGGGHHH + """ + shape = uint8_data.shape + uint8_data = uint8_data.to(torch.uint8) + + return torch.stack( + ( + (uint8_data[::3] >> 5) & 0b111, + (uint8_data[::3] >> 2) & 0b111, + (uint8_data[::3] & 0b11) << 1 | (uint8_data[1::3] >> 7) & 0b1, + (uint8_data[1::3] >> 4) & 0b111, + (uint8_data[1::3] >> 1) & 0b111, + (uint8_data[1::3] & 0b1) << 2 | (uint8_data[2::3] >> 6) & 0b11, + (uint8_data[2::3] >> 3) & 0b111, + uint8_data[2::3] & 0b111, + ), + dim=-1, + ).view(up_size_uint3(shape)) + + +def pack_uint3(uint8_data: torch.Tensor) -> torch.Tensor: + """ + 8 -> 3 + 01234567|01234567|01234567 + AAABBBCC|CDDDEEEF|FFGGGHHH + """ + + shape = uint8_data.shape + assert shape[-1] % 8 == 0 + uint8_data = uint8_data.contiguous().view(-1) + + packed_data = torch.stack( + ( + ((uint8_data[::8] & 0b111) << 5 | (uint8_data[1::8] & 0b111) << 2 | (uint8_data[2::8] & 0b111) >> 1), + ((uint8_data[2::8] & 0b1) << 7 | (uint8_data[3::8] & 0b111) << 4 | (uint8_data[4::8] & 0b111) << 1 | ((uint8_data[5::8] >> 2) & 1)), + ((uint8_data[5::8] & 0b11) << 6 | (uint8_data[6::8] & 0b111) << 3 | (uint8_data[7::8] & 0b111)), + ), + dim=-1 + ).view(down_size_uint3(shape)) + + return packed_data + + +def down_size_uint4(size): + assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" + return (*size[:-1], size[-1] // 2) + + +def up_size_uint4(size): + return (*size[:-1], size[-1] * 2) + + +def unpack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: + shape = uint8_data.shape + uint8_data = uint8_data.to(torch.uint8) + first_elements = (uint8_data >> 4) & 0b1111 + second_elements 
= uint8_data & 0b1111 + return torch.stack((first_elements, second_elements), dim=-1).view( + up_size_uint4(shape) + ) + + +def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: + shape = uint8_data.shape + assert shape[-1] % 2 == 0 + uint8_data = uint8_data.contiguous().view(-1) + packed_data = (uint8_data[::2] << 4 | (uint8_data[1::2] & 0b1111)).view( + down_size_uint4(shape) + ) + return packed_data + + +def down_size_uint5(size): + assert size[-1] % 8 == 0, f"{size} last dim not divisible by 8" + return (*size[:-1], size[-1] // 8 * 5) + + +def up_size_uint5(size): + assert size[-1] % 5 == 0, f"{size} last dim not divisible by 5" + return (*size[:-1], size[-1] // 5 * 8) + + +def pack_uint5(uint8_data: torch.Tensor) -> torch.Tensor: + """Pack the 5 lowest bits of 8 input bytes into 5 bytes + + 8 -> 5 + 01234567|01234567|01234567|01234567|01234567 + AAAAABBB|BBCCCCCD|DDDDEEEE|EFFFFFGG|GGGHHHHH + + The packing pattern: + - First byte: (A0 A1 A2 A3 A4 B0 B1 B2) + - Second byte: (B3 B4 C0 C1 C2 C3 C4 D0) + - Third byte: (D1 D2 D3 D4 E0 E1 E2 E3) + - Fourth byte: (E4 F0 F1 F2 F3 F4 G0 G1) + - Fifth byte: (G2 G3 G4 H0 H1 H2 H3 H4) + """ + shape = uint8_data.shape + assert ( + shape[-1] % 8 == 0 + ), f"Input last dimension should be divisible by 8, but got {shape[-1]}" + + uint8_data = uint8_data.contiguous().view(-1, 8) + + packed_data = torch.stack( + ( + ((uint8_data[:, 0] & 0b00011111) << 3) | ((uint8_data[:, 1] & 0b00011100) >> 2), + ((uint8_data[:, 1] & 0b00000011) << 6) | ((uint8_data[:, 2] & 0b00011111) << 1) | ((uint8_data[:, 3] & 0b10000) >> 4), + ((uint8_data[:, 3] & 0b00001111) << 4) | ((uint8_data[:, 4] & 0b00011110) >> 1), + ((uint8_data[:, 4] & 0b00000001) << 7) | ((uint8_data[:, 5] & 0b00011111) << 2) | ((uint8_data[:, 6] & 0b0011000) >> 3), + ((uint8_data[:, 6] & 0b00000111) << 5) | (uint8_data[:, 7] & 0b00011111), + ), + dim=-1, + ).view(down_size_uint5(shape)) + + return packed_data + + +def unpack_uint5(packed_data: torch.Tensor) -> torch.Tensor: + """Unpack the 5 bytes into the 5 lowest bits of 8 bytes + 01234567|01234567|01234567|01234567|01234567 + AAAAABBB|BBCCCCCD|DDDDEEEE|EFFFFFGG|GGGHHHHH + """ + shape = packed_data.shape + assert ( + shape[-1] % 5 == 0 + ), f"Input last dimension should be divisible by 5, but got {shape[-1]}" + + packed_data = packed_data.contiguous().view(-1, 5) + + unpacked_data = torch.stack( + ( + ((packed_data[:, 0] >> 3) & 0b00011111), + ((packed_data[:, 0] & 0b00000111) << 2) | ((packed_data[:, 1] >> 6) & 0b00000011), + ((packed_data[:, 1] >> 1) & 0b00011111), + ((packed_data[:, 1] & 0b00000001) << 4) | ((packed_data[:, 2] >> 4) & 0b00001111), + ((packed_data[:, 2] & 0b00001111) << 1) | ((packed_data[:, 3] >> 7) & 0b00000001), + ((packed_data[:, 3] >> 2) & 0b00011111), + ((packed_data[:, 3] & 0b00000011) << 3) | ((packed_data[:, 4] >> 5) & 0b00000111), + packed_data[:, 4] & 0b00011111, + ), + dim=-1, + ).view(up_size_uint5(shape)) + + return unpacked_data + + +def down_size_uint6(size): + assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" + return (*size[:-1], size[-1] // 4 * 3) + + +def up_size_uint6(size): + assert size[-1] % 3 == 0, f"{size} last dim not divisible by three" + return (*size[:-1], size[-1] // 3 * 4) + + +def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: + """Pack the 6 lowest bits of 4 input bytes into 3 bytes + + 4 -> 3 + 01234567|01234567|01234567 + AAAAAABB|BBBBCCCC|CCDDDDDD + + The packing pattern: + - First byte: (A0 A1 A2 A3 A4 A5 B0 B1) + - Second byte: (B2 B3 B4 B5 C0 C1 C2 C3) + - Third 
byte: (C4 C5 D0 D1 D2 D3 D4 D5) + """ + shape = uint8_data.shape + assert ( + shape[-1] % 4 == 0 + ), f"Input last dimension should be divisible by 4, but got {shape[-1]}" + + uint8_data = uint8_data.contiguous().view(-1, 4) + + packed_data = torch.stack( + ( + ((uint8_data[:, 0] & 0b00111111) << 2) | ((uint8_data[:, 1] >> 4) & 0b00000011), + ((uint8_data[:, 1] & 0b00001111) << 4) | ((uint8_data[:, 2] >> 2) & 0b00001111), + ((uint8_data[:, 2] & 0b00000011) << 6) | (uint8_data[:, 3] & 0b00111111), + ), + dim=-1, + ).view(down_size_uint6(shape)) + + return packed_data + + +def unpack_uint6(packed_data: torch.Tensor) -> torch.Tensor: + """Unpack the 3 bytes into the 6 lowest bits of 4 outputs + 01234567|01234567|01234567 + AAAAAABB|BBBBCCCC|CCDDDDDD + """ + shape = packed_data.shape + assert ( + shape[-1] % 3 == 0 + ), f"Input last dimension should be divisible by 3, but got {shape[-1]}" + + packed_data = packed_data.contiguous().view(-1, 3) + + unpacked_data = torch.stack( + ( + (packed_data[:, 0] >> 2) & 0b00111111, + ((packed_data[:, 0] & 0b00000011) << 4) | ((packed_data[:, 1] >> 4) & 0b00001111), + ((packed_data[:, 1] & 0b00001111) << 2) | ((packed_data[:, 2] >> 6) & 0b00000011), + packed_data[:, 2] & 0b00111111, + ), + dim=-1, + ).view(up_size_uint6(shape)) + + return unpacked_data + + +def down_size_uint7(size): + assert size[-1] % 8 == 0, f"{size} last dim not divisible by 8" + return (*size[:-1], size[-1] // 8 * 7) + + +def up_size_uint7(size): + assert size[-1] % 7 == 0, f"{size} last dim not divisible by 7" + return (*size[:-1], size[-1] // 7 * 8) + + +def pack_uint7(uint8_data: torch.Tensor) -> torch.Tensor: + """Pack the 7 lowest bits of 8 input bytes into 7 bytes + + 8 -> 7 + 01234567|01234567|01234567|01234567|01234567|01234567|01234567 + AAAAAAAB|BBBBBBCC|CCCCCDDD|DDDDEEEE|EEEFFFFF|FFGGGGGG|GHHHHHHH + + The packing pattern: + - First byte: (A0 A1 A2 A3 A4 A5 A6 B0) + - Second byte: (B1 B2 B3 B4 B5 B6 C0 C1) + - Third byte: (C2 C3 C4 C5 C6 D0 D1 D2) + - Fourth byte: (D3 D4 D5 D6 E0 E1 E2 E3) + - Fifth byte: (E4 E5 E6 F0 F1 F2 F3 F4) + - Sixth byte: (F5 F6 G0 G1 G2 G3 G4 G5) + - Seventh byte:(G6 H0 H1 H2 H3 H4 H5 H6) + """ + shape = uint8_data.shape + assert ( + shape[-1] % 8 == 0 + ), f"Input last dimension should be divisible by 8, but got {shape[-1]}" + + uint8_data = uint8_data.contiguous().view(-1, 8) + + packed_data = torch.stack( + ( + ((uint8_data[:, 0] & 0b01111111) << 1) | ((uint8_data[:, 1] >> 6) & 0b00000001), + ((uint8_data[:, 1] & 0b00111111) << 2) | ((uint8_data[:, 2] >> 5) & 0b00000011), + ((uint8_data[:, 2] & 0b00011111) << 3) | ((uint8_data[:, 3] >> 4) & 0b00000111), + ((uint8_data[:, 3] & 0b00001111) << 4) | ((uint8_data[:, 4] >> 3) & 0b00001111), + ((uint8_data[:, 4] & 0b00000111) << 5) | ((uint8_data[:, 5] >> 2) & 0b00011111), + ((uint8_data[:, 5] & 0b00000011) << 6) | ((uint8_data[:, 6] >> 1) & 0b00111111), + ((uint8_data[:, 6] & 0b00000001) << 7) | ((uint8_data[:, 7] >> 0) & 0b01111111), + ), + dim=-1, + ).view(down_size_uint7(shape)) + + return packed_data + + +def unpack_uint7(packed_data: torch.Tensor) -> torch.Tensor: + """Unpack the 7 bytes into the 7 lowest bits of 8 bytes + 01234567|01234567|01234567|01234567|01234567|01234567|01234567 + AAAAAAAB|BBBBBBCC|CCCCCDDD|DDDDEEEE|EEEFFFFF|FFGGGGGG|GHHHHHHH + """ + shape = packed_data.shape + assert ( + shape[-1] % 7 == 0 + ), f"Input last dimension should be divisible by 7, but got {shape[-1]}" + + packed_data = packed_data.contiguous().view(-1, 7) + + unpacked_data = torch.stack( + ( + (packed_data[:, 0] 
>> 1) & 0b01111111, + ((packed_data[:, 0] & 0b00000001) << 6) | ((packed_data[:, 1] >> 2) & 0b01111111), + ((packed_data[:, 1] & 0b00000011) << 5) | ((packed_data[:, 2] >> 3) & 0b01111111), + ((packed_data[:, 2] & 0b00000111) << 4) | ((packed_data[:, 3] >> 4) & 0b01111111), + ((packed_data[:, 3] & 0b00001111) << 3) | ((packed_data[:, 4] >> 5) & 0b01111111), + ((packed_data[:, 4] & 0b00011111) << 2) | ((packed_data[:, 5] >> 6) & 0b01111111), + ((packed_data[:, 5] & 0b00111111) << 1) | ((packed_data[:, 6] >> 7) & 0b01111111), + packed_data[:, 6] & 0b01111111, + ), + dim=-1, + ).view(up_size_uint7(shape)) + + return unpacked_data + + +def test_uint_small_range(pack_fn, unpack_fn, bit_count): + x = torch.arange(0, 256, dtype=torch.uint8) + y = pack_fn(x) + z = unpack_fn(y) + k = z.view(-1, 2 ** bit_count) + check = torch.arange(0, 2 ** bit_count, dtype=torch.uint8).repeat(k.size(0), 1) + assert torch.all(k == check) + + +if __name__ == "__main__": + test_uint_small_range(pack_uint2, unpack_uint2, 2) + test_uint_small_range(pack_uint3, unpack_uint3, 3) + test_uint_small_range(pack_uint4, unpack_uint4, 4) + test_uint_small_range(pack_uint5, unpack_uint5, 5) + test_uint_small_range(pack_uint6, unpack_uint6, 6) + test_uint_small_range(pack_uint7, unpack_uint7, 7) From a07148752ec3bad33936581d40ba9d33be2e04c1 Mon Sep 17 00:00:00 2001 From: James Melvin Date: Mon, 3 Jun 2024 19:31:41 +0530 Subject: [PATCH 07/18] housekeeping: renamed uint_small to uintgen and simple comments --- torchao/prototype/dtypes/uint2.py | 3 +++ torchao/prototype/dtypes/{uint_small.py => uintgen.py} | 3 +++ 2 files changed, 6 insertions(+) rename torchao/prototype/dtypes/{uint_small.py => uintgen.py} (98%) diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py index f1be0cbce5..eac3a54c9b 100644 --- a/torchao/prototype/dtypes/uint2.py +++ b/torchao/prototype/dtypes/uint2.py @@ -4,6 +4,9 @@ from torch.library import impl, Library from ...dtypes.uint4 import qtensor_lib +""" +Converts a tensor of uint8 to a tensor of uint2 mostly applicable for bitnet 1.58 +""" def down_size(size): assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" diff --git a/torchao/prototype/dtypes/uint_small.py b/torchao/prototype/dtypes/uintgen.py similarity index 98% rename from torchao/prototype/dtypes/uint_small.py rename to torchao/prototype/dtypes/uintgen.py index b9e3ca88e6..d16feada68 100644 --- a/torchao/prototype/dtypes/uint_small.py +++ b/torchao/prototype/dtypes/uintgen.py @@ -1,5 +1,8 @@ import torch +""" +Contains generic functions to pack and unpack uint8 tensors into uint2, uint3, uint4, uint5, uint6, and uint7 tensors. +""" def down_size_uint2(size): assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" From f0d598232058be6354e0c4042c9cbe34c7a1bf4c Mon Sep 17 00:00:00 2001 From: Z <48565901+CoffeeVampir3@users.noreply.github.com> Date: Mon, 3 Jun 2024 20:19:59 -0600 Subject: [PATCH 08/18] Update uint2.py Added several operations for UInt2Tensor, still needs work. 
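
All of the new handlers follow the same pattern: unpack the packed uint8
storage back to one value per element, run the regular aten op on the
unpacked tensor, and return that plain result. Roughly (illustrative helper
only, not part of the diff below):

    def _dispatch_on_unpacked(op, tensor, *extra):
        # four 2-bit values are stored per byte; unpack before computing
        unpacked = unpack_uint2(tensor.elem).view(tensor.shape)
        return op(unpacked, *extra)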
---
 torchao/prototype/dtypes/uint2.py | 37 ++++++++++++++++++++++++++++---
 1 file changed, 34 insertions(+), 3 deletions(-)

diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py
index eac3a54c9b..c1f11d9ba2 100644
--- a/torchao/prototype/dtypes/uint2.py
+++ b/torchao/prototype/dtypes/uint2.py
@@ -297,7 +297,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                     torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step)
                 )
         elif func is torch.ops.aten.t.default:
-            # assert False, "transpose is not properly implemented currently"
             (self,) = args
             unpacked = unpack_uint2(self.elem)
             transposed = torch.ops.aten.t.default(unpacked)
@@ -329,13 +328,45 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                     self.elem, size, stride, storage_offset
                 )
             )
-
-        raise NotImplementedError(f"{func}")
+        # reductions, clamp and abs are computed on the unpacked values
+        elif func is torch.ops.aten.min:
+            self, dim, keepdim = fill_defaults(args, 3, [None, False])
+            unpacked = unpack_uint2(self.elem).view(self.shape)
+            min_result = torch.ops.aten.min(unpacked, dim, keepdim)
+            return min_result
+        elif func is torch.ops.aten.max:
+            self, dim, keepdim = fill_defaults(args, 3, [None, False])
+            unpacked = unpack_uint2(self.elem).view(self.shape)
+            max_result = torch.ops.aten.max(unpacked, dim, keepdim)
+            return max_result
+        elif func is torch.ops.aten.amin:
+            self, dim, keepdim = fill_defaults(args, 3, [None, False])
+            unpacked = unpack_uint2(self.elem).view(self.shape)
+            min_result = torch.ops.aten.amin(unpacked, dim, keepdim)
+            return min_result
+        elif func is torch.ops.aten.amax:
+            self, dim, keepdim = fill_defaults(args, 3, [None, False])
+            unpacked = unpack_uint2(self.elem).view(self.shape)
+            max_result = torch.ops.aten.amax(unpacked, dim, keepdim)
+            return max_result
+        elif func is torch.ops.aten.clamp:
+            self, min_v, max_v = fill_defaults(args, 3, [None, None])
+            unpacked = unpack_uint2(self.elem).view(self.shape)
+            clamped_result = torch.ops.aten.clamp(unpacked, min_v, max_v)
+            return clamped_result
+        elif func is torch.ops.aten.abs:
+            (self,) = args
+            unpacked = unpack_uint2(self.elem).view(self.shape)
+            abs_result = torch.ops.aten.abs(unpacked)
+            return abs_result
+        else:
+            raise NotImplementedError(f"{func}")
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
 
 def _quantize_int2(x: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
+    # should this just be addition or .sign? Is the .sign a trick here or could this be removed?
quant = x.sign() + 1 if target_dtype == torch.uint2: From 13fa9d80e3920f92fd47e353f00cee4018c02f4c Mon Sep 17 00:00:00 2001 From: James Melvin Date: Wed, 5 Jun 2024 14:12:01 +0530 Subject: [PATCH 09/18] added pytest ,compile tests and some cleanup --- test/dtypes/test_uint2.py | 31 +++++++++++++++------------- torchao/prototype/dtypes/__init__.py | 8 +++++++ torchao/prototype/dtypes/uintgen.py | 2 +- 3 files changed, 26 insertions(+), 15 deletions(-) create mode 100644 torchao/prototype/dtypes/__init__.py diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index ba2691a4af..906135ee1e 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -1,8 +1,7 @@ -from unittest import main +import pytest import torch import torch.nn as nn -from torch.testing._internal.common_quantization import QuantizationTestCase -from torchao.prototype.dtypes.uint2 import BitnetTensor +from torchao.prototype.dtypes import BitnetTensor from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter def _apply_weight_only_uint2_quant(model): @@ -16,15 +15,19 @@ def fn(mod): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) -class TestUInt2(QuantizationTestCase): - def test_gpu_quant(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' - for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: - x = torch.randn(*x_shape).to(device) - m = nn.Sequential(nn.Linear(4, 16)).to(device) - y_ref = m(x) - _apply_weight_only_uint2_quant(m) - y_wo = m(x) +@pytest.mark.parametrize("input_shape", [[2,4], + [5,5,5,4], + [1,4,4]]) +def test_uint2_quant(input_shape): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + x = torch.randn(*input_shape).to(device) + m = nn.Sequential(nn.Linear(4, 16)).to(device) + y_ref = m(x) + _apply_weight_only_uint2_quant(m) + y_wo = m(x) + y_compiled = torch.compile(m, fullgraph=True)(x) + # TODO: torch.allclose() WIP -if __name__ == "__main__": - main() + +if __name__ == '__main__': + test_uint2_quant([2,4]) \ No newline at end of file diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py new file mode 100644 index 0000000000..23621bc8bd --- /dev/null +++ b/torchao/prototype/dtypes/__init__.py @@ -0,0 +1,8 @@ + +from .uint2 import BitnetTensor, UInt2Tensor + +__all__ = [ + "BitnetTensor", + "UInt2Tensor", +] + diff --git a/torchao/prototype/dtypes/uintgen.py b/torchao/prototype/dtypes/uintgen.py index d16feada68..bade39e0c9 100644 --- a/torchao/prototype/dtypes/uintgen.py +++ b/torchao/prototype/dtypes/uintgen.py @@ -1,7 +1,7 @@ import torch """ -Contains generic functions to pack and unpack uint8 tensors into uint2, uint3, uint4, uint5, uint6, and uint7 tensors. +Contains generic functions to pack and unpack uintx (2-7) tensors into uint8 tensors. 
""" def down_size_uint2(size): From 1fdeb91038c24e06d1b9a5ccb8d2a49da88bdaf7 Mon Sep 17 00:00:00 2001 From: James Melvin Date: Sun, 16 Jun 2024 13:30:51 +0530 Subject: [PATCH 10/18] fix: implements pattern for uint2 and BitnetTensor --- test/dtypes/test_uint2.py | 12 +- torchao/prototype/dtypes/__init__.py | 3 +- torchao/prototype/dtypes/bitnet.py | 121 ++++++ torchao/prototype/dtypes/uint2.py | 615 +++++++++------------------ 4 files changed, 335 insertions(+), 416 deletions(-) create mode 100644 torchao/prototype/dtypes/bitnet.py diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index 906135ee1e..0e0ffb461a 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -15,9 +15,7 @@ def fn(mod): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) -@pytest.mark.parametrize("input_shape", [[2,4], - [5,5,5,4], - [1,4,4]]) +@pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]]) def test_uint2_quant(input_shape): device = 'cuda' if torch.cuda.is_available() else 'cpu' x = torch.randn(*input_shape).to(device) @@ -25,9 +23,9 @@ def test_uint2_quant(input_shape): y_ref = m(x) _apply_weight_only_uint2_quant(m) y_wo = m(x) - y_compiled = torch.compile(m, fullgraph=True)(x) - # TODO: torch.allclose() WIP - + assert y_ref.shape == y_wo.shape + # WIP - Need to use the latest build and test torch.compile + # y_compiled = torch.compile(m, fullgraph=True)(x) if __name__ == '__main__': - test_uint2_quant([2,4]) \ No newline at end of file + test_uint2_quant([2, 4]) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 23621bc8bd..9f16283ac5 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -1,5 +1,6 @@ -from .uint2 import BitnetTensor, UInt2Tensor +from .uint2 import UInt2Tensor +from .bitnet import BitnetTensor __all__ = [ "BitnetTensor", diff --git a/torchao/prototype/dtypes/bitnet.py b/torchao/prototype/dtypes/bitnet.py new file mode 100644 index 0000000000..997f7458e2 --- /dev/null +++ b/torchao/prototype/dtypes/bitnet.py @@ -0,0 +1,121 @@ +import torch +from torchao.prototype.dtypes.uint2 import UInt2Tensor, unpack_uint2, pack_uint2 + +BITNET_OPS_TABLE = {} + +def implements(aten_ops): + def decorator(fn): + for op in aten_ops: + BITNET_OPS_TABLE[op] = fn + return fn + return decorator + +def _quantize_int2(x: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: + # Quantize the input tensor to int2 + quant = x.sign() + 1 + + if target_dtype == torch.uint2: + quant = BitnetTensor.from_unpacked( + quant.to(torch.uint8), + ) + else: + quant = quant.to(target_dtype) + + return quant + +class BitnetTensor(UInt2Tensor): + def __new__(cls, input_tensor: torch.Tensor, **kwargs): + return super(BitnetTensor, cls).__new__(cls, input_tensor, **kwargs) + + def __init__(self, input_tensor: torch.Tensor, **kwargs): + super(BitnetTensor, self).__init__(input_tensor, **kwargs) + + @staticmethod + def __tensor_unflatten__(flattened, meta): + assert meta is None + elem = flattened["elem"] + return BitnetTensor(elem) + + @classmethod + def from_unpacked(cls, unpacked: torch.Tensor) -> "BitnetTensor": + return cls(pack_uint2(unpacked)) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + def allowed_subclasses(type): + return ( + issubclass(cls, type) or + issubclass(torch._subclasses.fake_tensor.FakeTensor, type) or + issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type) + ) + + if not all(allowed_subclasses(t) for t in types): + 
return NotImplemented("Bitnet, Up to the next one to handle") + + if func in BITNET_OPS_TABLE: + return BITNET_OPS_TABLE[func](func, args, kwargs) + raise NotImplementedError(f"Bitnet dispatch: attempting to run {func}, this is not supported") + + @classmethod + def from_float(cls, w: torch.Tensor): + w_int2 = _quantize_int2(w, torch.uint2).to(device=w.device) + return w_int2 + +@implements([torch.ops.aten.mm.default]) +def mm(func, args, kwargs): + x, weight = args + y = torch.mm(x, weight.to(torch.int8).to(x.device).to(x.dtype)) + return y + +@implements([torch.ops.aten.addmm.default]) +def addmm(func, args, kwargs): + bias, x, weight = args + y = torch.addmm(bias, x, weight.to(torch.int8).to(x.device).to(x.dtype)) + if bias is not None: + y += bias + return y + +@implements([torch.ops.aten.t.default]) +def t(func, args, kwargs): + (tensor,) = args + unpacked = unpack_uint2(tensor.elem).to(tensor.device) + transposed = unpacked.t() + return BitnetTensor(pack_uint2(transposed)) + +@implements([torch.ops.aten.detach.default]) +def detach(func, args, kwargs): + (tensor,) = args + return tensor + +@implements([torch.ops.aten.to.dtype]) +def to_dtype(func, args, kwargs): + (tensor, dtype) = args + if dtype == torch.int8: + return unpack_uint2(tensor.elem).view(torch.uint8) - 1 + elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): + return unpack_uint2(tensor.elem).to(torch.int8).to(dtype) + elif dtype == torch.uint8: + return unpack_uint2(tensor.elem).view(torch.uint8) + elif dtype == torch.uint2: + return tensor.elem + raise NotImplementedError(f"to {dtype} not supported") + +@implements([torch.ops.aten._to_copy.default]) +def _to_copy(func, args, kwargs): + (tensor,) = args + dtype = kwargs["dtype"] + if dtype == torch.int8: + return BitnetTensor(unpack_uint2(tensor).view(tensor.shape).view(torch.int8) - 1) + elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): + return BitnetTensor(tensor.to(torch.int8).to(dtype)) + elif dtype == torch.uint2: + return BitnetTensor(tensor) + raise NotImplementedError(f"to {dtype} not supported") + +if __name__ == "__main__": + # Test case using BitnetTensor + a = torch.randint(0, 15, (2, 8), dtype=torch.uint8) + a_bitnet = BitnetTensor(a) + a_bitnet = a_bitnet.to(torch.uint2) + print(f"a_bitnet: {a_bitnet}") + diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py index c1f11d9ba2..d661595cb5 100644 --- a/torchao/prototype/dtypes/uint2.py +++ b/torchao/prototype/dtypes/uint2.py @@ -1,151 +1,11 @@ import torch import torch._prims_common as utils -import torch.utils._pytree as pytree -from torch.library import impl, Library -from ...dtypes.uint4 import qtensor_lib - -""" -Converts a tensor of uint8 to a tensor of uint2 mostly applicable for bitnet 1.58 -""" - -def down_size(size): - assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" - return (*size[:-1], size[-1] // 4) - -def up_size(size): - return (*size[:-1], size[-1] * 4) - -if torch.cuda.is_available() and torch.utils._triton.has_triton(): - import triton - import triton.language as tl - - @triton.jit - def triton_unpack_uint8_to_trinary2(uint8_data, output, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - uint8_value = tl.load(uint8_data + offsets, mask=mask) - - first_elements = ((uint8_value >> 6) & 0b11).to(tl.int8) - 1 - second_elements = ((uint8_value >> 4) & 0b11).to(tl.int8) - 1 - third_elements = 
((uint8_value >> 2) & 0b11).to(tl.int8) - 1 - fourth_elements = (uint8_value & 0b11).to(tl.int8) - 1 - - tl.store(output + offsets * 4 + 0, first_elements, mask=mask) - tl.store(output + offsets * 4 + 1, second_elements, mask=mask) - tl.store(output + offsets * 4 + 2, third_elements, mask=mask) - tl.store(output + offsets * 4 + 3, fourth_elements, mask=mask) - - def unpack_uint8_to_trinary2(uint8_data: torch.Tensor) -> torch.Tensor: - uint8_data = uint8_data.to('cuda') - shape = uint8_data.shape - output = torch.empty(up_size(shape), dtype=torch.int8, device='cuda') - n_elements = uint8_data.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - triton_unpack_uint8_to_trinary2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) - return output - - @triton.jit - def triton_unpack_uint2(uint8_data, output, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - uint8_value = tl.load(uint8_data + offsets, mask=mask) - - first_elements = (uint8_value >> 6) & 0b11 - second_elements = (uint8_value >> 4) & 0b11 - third_elements = (uint8_value >> 2) & 0b11 - fourth_elements = uint8_value & 0b11 - - tl.store(output + offsets * 4 + 0, first_elements, mask=mask) - tl.store(output + offsets * 4 + 1, second_elements, mask=mask) - tl.store(output + offsets * 4 + 2, third_elements, mask=mask) - tl.store(output + offsets * 4 + 3, fourth_elements, mask=mask) - - def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - uint8_data = uint8_data.to('cuda') - shape = uint8_data.shape - output = torch.empty(up_size(shape), dtype=torch.uint8, device='cuda') - n_elements = uint8_data.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - triton_unpack_uint2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) - return output - - @triton.jit - def triton_pack_uint2(uint8_data, output, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.arange(0, BLOCK_SIZE) - mask = offsets * 4 < n_elements - - first_elements = tl.load(uint8_data + offsets * 4 + 0, mask=mask) - second_elements = tl.load(uint8_data + offsets * 4 + 1, mask=mask) - third_elements = tl.load(uint8_data + offsets * 4 + 2, mask=mask) - fourth_elements = tl.load(uint8_data + offsets * 4 + 3, mask=mask) - - packed_data = (first_elements << 6) | (second_elements << 4) | (third_elements << 2) | fourth_elements - - tl.store(output + offsets, packed_data, mask=mask) - - def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - uint8_data = uint8_data.to('cuda') - shape = uint8_data.shape - assert shape[-1] % 4 == 0 - n_elements = uint8_data.numel() - packed_shape = down_size(shape) - output = torch.empty(packed_shape, dtype=torch.uint8, device='cuda') - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE'] * 4),) - triton_pack_uint2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) - return output - -else: - # TODO: torch compile issue https://github.com/pytorch/pytorch/issues/127374 is fixed - #@torch.compile - def unpack_uint8_to_trinary2(uint8_data: torch.Tensor) -> torch.Tensor: - # since we are using uint8 we will decode 4 entries per byte - shape = uint8_data.shape - first_elements = ((uint8_data >> 6) & 0b11).to(torch.int8) - 1 - second_elements = ((uint8_data >> 4) & 0b11).to(torch.int8) - 1 - third_elements = ((uint8_data >> 2) & 0b11).to(torch.int8) - 1 - fourth_elements = (uint8_data & 0b11).to(torch.int8) - 1 - return torch.stack([first_elements, second_elements, third_elements, fourth_elements], 
dim=-1).view(up_size(shape)) - - #@torch.compile - def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - # since we are using uint8 we will decode 4 entries per byte - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - first_elements = ((uint8_data >> 6) & 0b11) - second_elements = ((uint8_data >> 4) & 0b11) - third_elements = ((uint8_data >> 2) & 0b11) - fourth_elements = (uint8_data & 0b11) - return torch.stack((first_elements, second_elements, third_elements, fourth_elements), dim=-1).view(up_size(shape)) - - #packing uint8 - #@torch.compile - def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - shape = uint8_data.shape - assert shape[-1] % 4 == 0 - uint8_data = uint8_data.contiguous().view(-1) - packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) - return packed_data +from dataclasses import dataclass +from typing import Dict, Any, Tuple +UINT2_OPS_TABLE: Dict[Any, Any] = {} def fill_defaults(args, n, defaults_tail): - """ - __torch_dispatch__ doesn't guarantee the number of arguments you are - passed (e.g., defaulted arguments are not passed); but usually it is - convenient to pad out the arguments list with defaults. This function - helps you do that. - Args: - args: the list of positional arguments passed to __torch_dispatch__ - n: the number of arguments you are expecting to get - defaults_tail: default values for the arguments, starting from the - end of the list - Example: - >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) - [1, 2, 3, 4, 5] - >>> fill_defaults([1, 2, 3], 5, [None, None, None]) - [1, 2, 3, None, None]] - """ if n - len(defaults_tail) > len(args): raise RuntimeError("not enough defaults to fill arguments") r = list(args) @@ -153,294 +13,233 @@ def fill_defaults(args, n, defaults_tail): r.append(defaults_tail[i - n + len(defaults_tail)]) return r +def implements(aten_ops): + def decorator(fn): + for op in aten_ops: + UINT2_OPS_TABLE[op] = fn + return fn + return decorator -# qtensor_lib = Library("qtensors", "DEF") -qtensor_lib.define( - "quantize_per_tensor_uint2(Tensor input, float scale, int zero_point) -> Tensor" -) - - -@impl(qtensor_lib, "quantize_per_tensor_uint2", "CompositeExplicitAutograd") -def quantize_per_tensor_uint2( - input: torch.Tensor, - scale: float = 1.0, - zero_point: int = 1, -) -> torch.Tensor: - inv_scale = 1.0 / scale - return pack_uint2( - torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 2).to(torch.uint8) - ) - - -qtensor_lib.define( - "dequantize_per_tensor_uint2(Tensor input, float scale, int zero_point) -> Tensor" -) - +def down_size(size): + assert size[-1] % 4 == 0, f"{size} last dim not divisible by 4" + return (*size[:-1], size[-1] // 4) -@impl(qtensor_lib, "dequantize_per_tensor_uint2", "CompositeExplicitAutograd") -def dequantize_per_tensor_uint2( - input: torch.Tensor, - scale: float = 1.0, - zero_point: int = 1, -) -> torch.Tensor: - input = unpack_uint2(input) - return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale +def up_size(size): + return (*size[:-1], size[-1] * 4) +def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + shape = uint8_data.shape + uint8_data = uint8_data.to(torch.uint8) + first_elements = ((uint8_data >> 6) & 0b11) + second_elements = ((uint8_data >> 4) & 0b11) + third_elements = ((uint8_data >> 2) & 0b11) + fourth_elements = (uint8_data & 0b11) + return torch.stack((first_elements, second_elements, third_elements, fourth_elements), dim=-1).view(up_size(shape)) + +def 
pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: + shape = uint8_data.shape + assert shape[-1] % 4 == 0, f"{shape}, last dim not divisible by 4" + uint8_data = uint8_data.contiguous().view(-1) + packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) + return packed_data + +@dataclass +class SubclassTensorArgs: + original_shape: torch.Size + original_strides: Tuple + storage_offset: int + dtype: torch.dtype + device: torch.device + requires_grad: bool class UInt2Tensor(torch.Tensor): - @staticmethod - def __new__(cls, elem, **kwargs): - assert elem.dtype is torch.uint8 - assert not kwargs.get("requires_grad", False) - kwargs["requires_grad"] = False - - return torch.Tensor._make_wrapper_subclass( - cls, up_size(elem.shape), dtype=torch.uint2, **kwargs + def __new__(cls, input_tensor: torch.Tensor): + assert input_tensor.dtype == torch.uint8 + tensor_meta = SubclassTensorArgs( + input_tensor.size(), + input_tensor.stride(), + input_tensor.storage_offset(), + torch.uint2, + input_tensor.device, + input_tensor.requires_grad ) - - def __init__(self, elem, **kwargs): - self.elem = elem + uint2i_tensor = torch.Tensor._make_wrapper_subclass( + cls, + up_size(tensor_meta.original_shape), + tensor_meta.original_strides, + tensor_meta.storage_offset, + dtype=tensor_meta.dtype, + device=tensor_meta.device, + requires_grad=tensor_meta.requires_grad + ) + return uint2i_tensor + + def __init__(self, input_tensor: torch.Tensor, **kwargs): + self.elem = input_tensor @classmethod - def from_unpacked(cls, unpacked): + def from_packed(cls, unpacked): return UInt2Tensor(pack_uint2(unpacked)) - + def tolist(self): - return self.to(torch.uint8).tolist() - + return unpack_uint2(self.elem).tolist() + def __tensor_flatten__(self): return ["elem"], None - + @staticmethod - def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): + def __tensor_unflatten__(flattened, meta): assert meta is None elem = flattened["elem"] return UInt2Tensor(elem) - + def __hash__(self): return hash(self.elem) - + def __eq__(self, other): return torch.equal(self.elem, other.elem) + def __repr__(self): + data = unpack_uint2(self.elem).tolist() + return f"UInt2Tensor({data}, dtype=torch.uint2)" + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - if func is torch.ops.aten.view.default: - self, size = args - size = utils.infer_size(size, self.numel()) - assert not kwargs - # WARNING: views not preserved - return UInt2Tensor(self.elem.reshape(down_size(size))) - elif func is torch.ops.aten.view.dtype: - self, dtype = args - if dtype == torch.uint8: - return unpack_uint2(self.elem).view(torch.uint8) - return NotImplementedError(f"view {args}") - elif func is torch.ops.aten.to.dtype: - self, dtype = args - if dtype == torch.uint8: - return unpack_uint2(self.elem).view(torch.uint8) - return NotImplementedError(f"to {args}") - elif func is torch.ops.aten.eq.Tensor: - args = pytree.tree_map_only( - UInt2Tensor, lambda x: x.elem.view(torch.uint8), args - ) - kwargs = pytree.tree_map_only( - UInt2Tensor, lambda x: x.elem.view(torch.uint8), kwargs - ) - return torch.ops.aten.eq.Tensor(*args, **kwargs) - elif func is torch.ops.aten._to_copy.default: - (self,) = args - dtype = kwargs["dtype"] - if dtype == torch.uint8: - return unpack_uint2(self.elem).view(self.shape) - if dtype in (torch.uint16, torch.uint32, torch.uint64): - return self.to(torch.uint8).to(dtype) - if dtype == torch.uint2: - return self - else: - raise 
NotImplementedError(f"_to_copy {kwargs}") - elif func is torch.ops.aten.unbind.int: - # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to - # create four tensors containing one element each. But we can't - # do this with uint2 because such a tensor's size is not divisible - # by bytes. What I am going to do instead is promote to uint8 - # when this happens - self, dim = fill_defaults(args, 2, [0]) - if dim != self.dim() - 1: - raise NotImplementedError(f"unbind dim={dim}") - else: - # We're unbinding the last dimension, need to promote - return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind( - dim - ) - elif func is torch.ops.aten.select.int: - self, dim, index = args - if dim != self.dim() - 1: - return UInt2Tensor(torch.ops.aten.select.int(self.elem, dim, index)) - else: - raise NotImplementedError(f"select dim={dim}") - elif func is torch.ops.aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == self.dim() - 1: - # hard case - if step != 1: - raise NotImplementedError(f"slice step={step}") - assert start % 4 == 0, start - assert end >= self.shape[dim] or end % 4 == 0, end - return UInt2Tensor( - torch.ops.aten.slice.Tensor(self.elem, dim, start // 4, end // 4, 1) - ) - else: - # easy case - return UInt2Tensor( - torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step) - ) - elif func is torch.ops.aten.t.default: - (self,) = args - unpacked = unpack_uint2(self.elem) - transposed = torch.ops.aten.t.default(unpacked) - transposed_and_packed = pack_uint2(transposed) - return UInt2Tensor(transposed_and_packed) - elif func is torch.ops.aten.transpose_copy.int: - self, dim0, dim1 = args - unpacked = unpack_uint2(self.elem).view(self.shape) - transposed = torch.ops.aten.transpose_copy.int(unpacked, dim0, dim1) - transposed_and_packed = pack_uint2(transposed) - return UInt2Tensor(transposed_and_packed) - elif func is torch.ops.aten.as_strided.default: - # size, stride, storage_offset are referring to tensor elements, not physical bytes - self, size, stride, storage_offset = args - size = down_size(size) - new_stride = [] - for s in stride: - if s != 1: - # since two int4 equals to 1 uint8 - new_stride.append(s // 4) - else: - new_stride.append(s) - stride = new_stride - - storage_offset //= 4 - return UInt2Tensor( - torch.ops.aten.as_strided.default( - self.elem, size, stride, storage_offset - ) + def allowed_subclasses(type): + return ( + issubclass(cls, type) or + issubclass(torch._subclasses.fake_tensor.FakeTensor, type) or + issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type) ) - ## /clamp_/round <- values - elif func is torch.ops.aten.min: - self, dim, keepdim = fill_defaults(args, 0, False) - unpacked = unpack_uint2(self.elem).view(self.shape) - min_result = torch.ops.aten.min(unpacked, dim, keepdim) - return min_result - elif func is torch.ops.aten.max: - self, dim, keepdim = fill_defaults(args, 0, False) - unpacked = unpack_uint2(self.elem).view(self.shape) - max_result = torch.ops.aten.max(unpacked, dim, keepdim) - return max_result - elif func is torch.ops.aten.amin: - self, dim, keepdim = fill_defaults(args, 0, False) - unpacked = unpack_uint2(self.elem).view(self.shape) - min_result = torch.ops.aten.amin(unpacked, dim, keepdim) - return min_result - elif func is torch.ops.aten.amax: - self, dim, keepdim = fill_defaults(args, 0, False) - unpacked = unpack_uint2(self.elem).view(self.shape) - max_result = torch.ops.aten.amax(unpacked, dim, keepdim) - return max_result - elif func is 
torch.ops.aten.clamp: - self, min_v, max_v = fill_defaults(args, None, None) - unpacked = unpack_uint2(self.elem).view(self.shape) - clamped_result = torch.ops.aten.clamp(unpacked, min_v, max_v) - return clamped_result - elif func is torch.ops.aten.abs: - self = args - unpacked = unpack_uint2(self.elem).view(self.shape) - abs_result = torch.ops.aten.abs(unpacked) - return abs_result - else: - raise NotImplementedError(f"{func}") - - __torch_function__ = torch._C._disabled_torch_function_impl - - -def _quantize_int2(x: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: - # should this just be addition or .sign? Is the .sign a trick here or could this be removed? - quant = x.sign() + 1 - - if target_dtype == torch.uint2: - quant = BitnetTensor.from_unpacked( - quant.to(torch.uint8), - ) + + if not all(allowed_subclasses(t) for t in types): + return NotImplemented("Up to the next one to handle") + + if func in UINT2_OPS_TABLE: + return UINT2_OPS_TABLE[func](func, args, kwargs) + raise NotImplementedError(f"UINT2 dispatch: attempting to run {func}, this is not supported") + +@implements([torch.ops.aten.view.default]) +def uint2_view(func, args, kwargs): + tensor, size = args + size = utils.infer_size(size, tensor.numel()) + assert not kwargs + dsize = down_size(size) + reshaped_elem = tensor.elem.view(dsize) + return UInt2Tensor(reshaped_elem) + +@implements([torch.ops.aten.view.dtype]) +def view_dtype(func, args, kwargs): + tensor, dtype = args + if dtype is torch.uint8: + return unpack_uint2(tensor.elem).to(torch.uint8) + raise NotImplementedError(f"view {dtype} not supported") + +@implements([torch.ops.aten.clone.default]) +def clone(func, args, kwargs): + tensor = args[0] + return UInt2Tensor(tensor.elem.clone()) + +@implements([torch.ops.aten._unsafe_view.default]) +def unsafe_view(func, args, kwargs): + tensor, size = args + size = utils.infer_size(size, tensor.numel()) + assert not kwargs + dsize = down_size(size) + reshaped_elem = tensor.elem.view(dsize) + return UInt2Tensor(reshaped_elem) + +@implements([torch.ops.aten.unbind.int]) +def unbind(func, args, kwargs): + tensor, dim = fill_defaults(args, 2, [0]) + if dim != tensor.dim() - 1: + raise NotImplementedError(f"unbind dim={dim}") else: - quant = quant.to(target_dtype) - - return quant - - -class BitnetTensor(UInt2Tensor): - @staticmethod - def __new__(cls, elem, **kwargs): - return super().__new__(cls, elem, **kwargs) - - def __init__(self, elem, **kwargs): - super().__init__(elem, **kwargs) - - def __tensor_flatten__(self): - return ["elem"], None - - @staticmethod - def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): - assert meta is None - elem = flattened["elem"] - return BitnetTensor(elem) - - @classmethod - # inconsistently. 
- def from_unpacked(cls, unpacked: torch.Tensor) -> "BitnetTensor": - return cls(pack_uint2(unpacked)) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - if func is torch.ops.aten.mm.default: - x, weight = args - y = torch.mm(x, weight.to(torch.int8).to(x.device).to(x.dtype)) - return y - elif func is torch.ops.aten.addmm.default: - bias, x, weight = args - y = torch.mm(x, weight.to(torch.int8).to(x.device).to(x.dtype)) - if bias is not None: - y += bias - return y - elif func is torch.ops.aten.t.default: - # TODO: add proper support for transpose - (self,) = args - unpacked = unpack_uint2(self.elem).to(self.device) - transposed = torch.ops.aten.t.default(unpacked) - return BitnetTensor.from_unpacked(transposed) - elif func is torch.ops.aten.detach.default: - (self,) = args - return self - elif func is torch.ops.aten.to.dtype: - self, dtype = args - if dtype == torch.int8: - return unpack_uint2(self.elem).view(torch.int8) - 1 - elif func is torch.ops.aten._to_copy.default: - (self,) = args - dtype = kwargs["dtype"] - if dtype == torch.int8: - return unpack_uint2(self.elem).view(self.shape).view(torch.int8) - 1 - elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): - return self.to(torch.int8).to(dtype) - elif dtype == torch.uint2: - return self - return super().__torch_dispatch__(func, types, args, kwargs) - - @classmethod - def from_float(cls, w: torch.Tensor): - w_int2 = _quantize_int2( - w, torch.uint2 - ).to(device=w.device) - return w_int2 - + x = tensor.elem.to(torch.uint8).unbind(dim) + return x + +@implements([torch.ops.aten._to_copy.default]) +def to_copy(func, args, kwargs): + (tensor,) = args + dtype = kwargs["dtype"] + if dtype == torch.uint8: + return unpack_uint2(tensor.elem).view(tensor.shape).view(torch.uint8) + elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): + return tensor.to(torch.uint8).to(dtype) + elif dtype == torch.uint2: + return tensor + raise NotImplementedError(f"to_copy {dtype} not supported") + +@implements([torch.ops.aten.select.int]) +def select(func, args, kwargs): + tensor, dim, index = args + if dim != tensor.dim() - 1: + selected_elem = tensor.elem.select(dim, index) + return UInt2Tensor(selected_elem) + else: + raise NotImplementedError(f"select dim={dim}") + +@implements([torch.ops.aten.reshape.default]) +def reshape(func, args, kwargs): + tensor, size = args + size = utils.infer_size(size, tensor.numel()) + assert not kwargs + dsize = down_size(size) + reshaped_elem = tensor.elem.view(dsize) + return UInt2Tensor(reshaped_elem) + +def slice_tensor(func, args, kwargs): + tensor, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == tensor.dim() - 1: + if step != 1: + raise NotImplementedError(f"slice step={step}") + assert start % 4 == 0, start + assert end is None or end % 4 == 0, end + end = end if end is not None else tensor.shape[dim] + sliced_elem = tensor.elem[..., start // 4 : end // 4 : step] + return UInt2Tensor(sliced_elem) + else: + sliced_elem = tensor.elem[..., start:end:step] + return UInt2Tensor(sliced_elem) + +@implements([torch.ops.aten.equal.default]) +def equal(func, args, kwargs): + tensor, other = args + return torch.equal(tensor.elem, other.elem) + +@implements([torch.ops.aten.detach.default]) +def detach(func, args, kwargs): + (tensor,) = args + return tensor.elem.detach() + +@implements([torch.ops.aten.to.dtype]) +def to_dtype(func, args, kwargs): + (tensor, dtype) = args + if 
dtype == torch.uint8: + return tensor.elem.view(tensor.shape).view(torch.uint8) + raise NotImplementedError(f"to {dtype} not supported") + + +if __name__ == "__main__": + import torch.nn as nn + uint8_data = torch.randint(0, 4, (2, 8), dtype=torch.uint8) + print(f"uint8_data: {uint8_data}") + uint2_data = UInt2Tensor(uint8_data) + print(f"uint2_data: {uint2_data}") + + x = UInt2Tensor(torch.tensor([ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], dtype=torch.uint8)) + print(f"x: {x}") + + input_shapes = [[2, 4], [5, 5, 5, 4], [1, 4, 4]] + for input_shape in input_shapes: + x = torch.randn(*input_shape) + m = nn.Sequential(nn.Linear(4, 16)) + y_ref = m(x) + y_wo = m(x) + y_compiled = torch.compile(m, fullgraph=True)(x) + From 4eb56797921d02fd27cac7e7f655e5d74ac4a33d Mon Sep 17 00:00:00 2001 From: James Melvin Date: Sun, 16 Jun 2024 19:38:47 +0530 Subject: [PATCH 11/18] fix: torch.uint2 available after torch 2.3 --- test/dtypes/test_uint2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index 0e0ffb461a..d9c474267d 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -3,6 +3,10 @@ import torch.nn as nn from torchao.prototype.dtypes import BitnetTensor from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.utils import TORCH_VERSION_AFTER_2_3 + +if not TORCH_VERSION_AFTER_2_3: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) def _apply_weight_only_uint2_quant(model): def fn(mod): From 666a724ad89eb6cc01031f673f720d8c9b852d98 Mon Sep 17 00:00:00 2001 From: James Melvin Date: Mon, 17 Jun 2024 12:48:25 +0530 Subject: [PATCH 12/18] fix: test cases for BitnetTensor, UInt2Tensor and bitpacking gen --- test/dtypes/test_bitnet.py | 33 ++++++++++++++++++ test/dtypes/test_uint2.py | 47 ++++++++++++------------- test/prototype/test_bitpacking_gen.py | 26 ++++++++++++++ torchao/prototype/dtypes/bitnet.py | 49 +++++++++++++++++---------- torchao/prototype/dtypes/uint2.py | 36 +++++++------------- torchao/prototype/dtypes/uintgen.py | 18 ---------- 6 files changed, 124 insertions(+), 85 deletions(-) create mode 100644 test/dtypes/test_bitnet.py create mode 100644 test/prototype/test_bitpacking_gen.py diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py new file mode 100644 index 0000000000..f8a5e54bbf --- /dev/null +++ b/test/dtypes/test_bitnet.py @@ -0,0 +1,33 @@ +import pytest +import torch +from torchao.prototype.dtypes import BitnetTensor +from torchao.prototype.dtypes.uint2 import unpack_uint2 + +@pytest.fixture +def bitnet_tensor(): + input_tensor = torch.randint(0, 15, (4,4), dtype=torch.uint8) + return BitnetTensor.from_unpacked(input_tensor) + +def test_copy(bitnet_tensor): + copied_tensor = bitnet_tensor.clone() + assert torch.equal(bitnet_tensor.elem, copied_tensor.elem) + +def test_transpose(bitnet_tensor): + transposed_tensor = bitnet_tensor.t() + expected_tensor = unpack_uint2(bitnet_tensor.elem).t() + assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) + +def test_multiply(bitnet_tensor): + w_t = torch.randint(0, 15, (4, 16), dtype=torch.uint8) + w = BitnetTensor.from_unpacked(w_t) + y = torch.addmm(torch.Tensor([1]), bitnet_tensor, w) + +@pytest.mark.parametrize("dtype", [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64]) +def test_conversion(bitnet_tensor, dtype): 
+ converted_tensor = bitnet_tensor.to(dtype) + expected_tensor = unpack_uint2(bitnet_tensor.elem).to(dtype) + assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) + +if __name__ == "__main__": + pytest.main() + diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index d9c474267d..c4ff28a688 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -1,35 +1,30 @@ import pytest import torch import torch.nn as nn -from torchao.prototype.dtypes import BitnetTensor -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.utils import TORCH_VERSION_AFTER_2_3 +from torchao.prototype.dtypes import UInt2Tensor +from torchao.prototype.dtypes.uint2 import unpack_uint2 -if not TORCH_VERSION_AFTER_2_3: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) +@pytest.fixture +def uint2_tensor(): + input_tensor = torch.randint(0, 15, (4,4), dtype = torch.uint8) + return UInt2Tensor(input_tensor) -def _apply_weight_only_uint2_quant(model): - def fn(mod): - mod.weight = torch.nn.Parameter(BitnetTensor.from_float(mod.weight), requires_grad=False) - return mod +def test_copy(uint2_tensor): + copied_tensor = uint2_tensor.clone() + assert torch.equal(uint2_tensor.elem, copied_tensor.elem) - _replace_with_custom_fn_if_matches_filter( - model, - lambda mod: fn(mod), - lambda mod, fqn: isinstance(mod, torch.nn.Linear), - ) +def test_transpose(uint2_tensor): + transposed_tensor = uint2_tensor.t() + expected_tensor = unpack_uint2(uint2_tensor.elem).t() + assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) -@pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]]) -def test_uint2_quant(input_shape): - device = 'cuda' if torch.cuda.is_available() else 'cpu' - x = torch.randn(*input_shape).to(device) - m = nn.Sequential(nn.Linear(4, 16)).to(device) - y_ref = m(x) - _apply_weight_only_uint2_quant(m) - y_wo = m(x) - assert y_ref.shape == y_wo.shape - # WIP - Need to use the latest build and test torch.compile - # y_compiled = torch.compile(m, fullgraph=True)(x) +@pytest.mark.parametrize("dtype", [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64]) +def test_conversion(uint2_tensor, dtype): + converted_tensor = uint2_tensor.to(dtype) + expected_tensor = unpack_uint2(uint2_tensor.elem).to(dtype) + assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) if __name__ == '__main__': - test_uint2_quant([2, 4]) + pytest.main(__file__) + + diff --git a/test/prototype/test_bitpacking_gen.py b/test/prototype/test_bitpacking_gen.py new file mode 100644 index 0000000000..b729c6250d --- /dev/null +++ b/test/prototype/test_bitpacking_gen.py @@ -0,0 +1,26 @@ +import pytest +import torch + +from torchao.prototype.dtypes.uintgen import ( + pack_uint2, unpack_uint2, pack_uint3, unpack_uint3, pack_uint4, unpack_uint4, + pack_uint5, unpack_uint5, pack_uint6, unpack_uint6, pack_uint7, unpack_uint7 +) + +@pytest.mark.parametrize("pack_fn, unpack_fn, bit_count", [ + (pack_uint2, unpack_uint2, 2), + (pack_uint3, unpack_uint3, 3), + (pack_uint4, unpack_uint4, 4), + (pack_uint5, unpack_uint5, 5), + (pack_uint6, unpack_uint6, 6), + (pack_uint7, unpack_uint7, 7), +]) +def test_uint_packing(pack_fn, unpack_fn, bit_count): + x = torch.arange(0, 256, dtype=torch.uint8) + y = pack_fn(x) + z = unpack_fn(y) + k = z.view(-1, 2 ** bit_count) + check = torch.arange(0, 2 ** bit_count, dtype=torch.uint8).repeat(k.size(0), 1) + assert torch.all(k == check), f"Failed for 
{bit_count}-bit packing" + +if __name__ == "__main__": + pytest.main(__file__) \ No newline at end of file diff --git a/torchao/prototype/dtypes/bitnet.py b/torchao/prototype/dtypes/bitnet.py index 997f7458e2..cdbc0656df 100644 --- a/torchao/prototype/dtypes/bitnet.py +++ b/torchao/prototype/dtypes/bitnet.py @@ -13,14 +13,7 @@ def decorator(fn): def _quantize_int2(x: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: # Quantize the input tensor to int2 quant = x.sign() + 1 - - if target_dtype == torch.uint2: - quant = BitnetTensor.from_unpacked( - quant.to(torch.uint8), - ) - else: - quant = quant.to(target_dtype) - + quant = BitnetTensor.from_unpacked(quant.to(torch.uint8)) return quant class BitnetTensor(UInt2Tensor): @@ -60,19 +53,43 @@ def allowed_subclasses(type): def from_float(cls, w: torch.Tensor): w_int2 = _quantize_int2(w, torch.uint2).to(device=w.device) return w_int2 + + def clone(self): + return BitnetTensor(self.elem.clone()) + + def copy_(self, src): + self.elem.copy_(src.elem) + return self + + def to(self, *args, **kwargs): + if len(args) == 1 and isinstance(args[0], torch.dtype): + dtype = args[0] + if dtype == torch.int8: + return unpack_uint2(self.elem).view(self.shape).view(torch.int8) + elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): + return unpack_uint2(self.elem).to(torch.int8).to(dtype) + elif dtype == torch.uint8: + return unpack_uint2(self.elem).view(torch.uint8) + elif dtype == torch.uint2: + return self + return super().to(*args, **kwargs) @implements([torch.ops.aten.mm.default]) def mm(func, args, kwargs): x, weight = args - y = torch.mm(x, weight.to(torch.int8).to(x.device).to(x.dtype)) + x = unpack_uint2(x.elem).to(torch.float32) + weight = unpack_uint2(weight.elem).to(torch.float32) + y = torch.mm(x, weight) return y @implements([torch.ops.aten.addmm.default]) def addmm(func, args, kwargs): bias, x, weight = args - y = torch.addmm(bias, x, weight.to(torch.int8).to(x.device).to(x.dtype)) + x = unpack_uint2(x.elem).to(torch.float32) + weight = unpack_uint2(weight.elem).to(torch.float32) if bias is not None: - y += bias + bias = bias.to(torch.float32) + y = torch.addmm(bias, x, weight) return y @implements([torch.ops.aten.t.default]) @@ -112,10 +129,8 @@ def _to_copy(func, args, kwargs): return BitnetTensor(tensor) raise NotImplementedError(f"to {dtype} not supported") -if __name__ == "__main__": - # Test case using BitnetTensor - a = torch.randint(0, 15, (2, 8), dtype=torch.uint8) - a_bitnet = BitnetTensor(a) - a_bitnet = a_bitnet.to(torch.uint2) - print(f"a_bitnet: {a_bitnet}") +@implements([torch.ops.aten.clone.default]) +def clone(func, args, kwargs): + (tensor,) = args + return tensor.clone() diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py index d661595cb5..c8d2746545 100644 --- a/torchao/prototype/dtypes/uint2.py +++ b/torchao/prototype/dtypes/uint2.py @@ -217,29 +217,17 @@ def detach(func, args, kwargs): def to_dtype(func, args, kwargs): (tensor, dtype) = args if dtype == torch.uint8: - return tensor.elem.view(tensor.shape).view(torch.uint8) - raise NotImplementedError(f"to {dtype} not supported") + return unpack_uint2(tensor.elem).view(torch.uint8) + elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): + return unpack_uint2(tensor.elem).to(torch.uint8).to(dtype) + elif dtype == torch.uint2: + return tensor.elem + raise NotImplementedError(f"to {dtype} not supported") -if __name__ == "__main__": - import torch.nn as 
nn - uint8_data = torch.randint(0, 4, (2, 8), dtype=torch.uint8) - print(f"uint8_data: {uint8_data}") - uint2_data = UInt2Tensor(uint8_data) - print(f"uint2_data: {uint2_data}") - - x = UInt2Tensor(torch.tensor([ - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) - print(f"x: {x}") - - input_shapes = [[2, 4], [5, 5, 5, 4], [1, 4, 4]] - for input_shape in input_shapes: - x = torch.randn(*input_shape) - m = nn.Sequential(nn.Linear(4, 16)) - y_ref = m(x) - y_wo = m(x) - y_compiled = torch.compile(m, fullgraph=True)(x) - +@implements([torch.ops.aten.t.default]) +def t(func, args, kwargs): + (tensor,) = args + unpacked = unpack_uint2(tensor.elem).to(tensor.device) + transposed = unpacked.t() + return UInt2Tensor(pack_uint2(transposed)) diff --git a/torchao/prototype/dtypes/uintgen.py b/torchao/prototype/dtypes/uintgen.py index bade39e0c9..1312816f1e 100644 --- a/torchao/prototype/dtypes/uintgen.py +++ b/torchao/prototype/dtypes/uintgen.py @@ -342,21 +342,3 @@ def unpack_uint7(packed_data: torch.Tensor) -> torch.Tensor: ).view(up_size_uint7(shape)) return unpacked_data - - -def test_uint_small_range(pack_fn, unpack_fn, bit_count): - x = torch.arange(0, 256, dtype=torch.uint8) - y = pack_fn(x) - z = unpack_fn(y) - k = z.view(-1, 2 ** bit_count) - check = torch.arange(0, 2 ** bit_count, dtype=torch.uint8).repeat(k.size(0), 1) - assert torch.all(k == check) - - -if __name__ == "__main__": - test_uint_small_range(pack_uint2, unpack_uint2, 2) - test_uint_small_range(pack_uint3, unpack_uint3, 3) - test_uint_small_range(pack_uint4, unpack_uint4, 4) - test_uint_small_range(pack_uint5, unpack_uint5, 5) - test_uint_small_range(pack_uint6, unpack_uint6, 6) - test_uint_small_range(pack_uint7, unpack_uint7, 7) From a2a435937702e65d41ff0c0559b410806cdc11d3 Mon Sep 17 00:00:00 2001 From: James Melvin Date: Mon, 17 Jun 2024 14:05:40 +0530 Subject: [PATCH 13/18] fix: removed torch.uint2 --- torchao/prototype/dtypes/bitnet.py | 11 ++++++----- torchao/prototype/dtypes/uint2.py | 8 ++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/torchao/prototype/dtypes/bitnet.py b/torchao/prototype/dtypes/bitnet.py index cdbc0656df..f48d8df5ea 100644 --- a/torchao/prototype/dtypes/bitnet.py +++ b/torchao/prototype/dtypes/bitnet.py @@ -10,7 +10,7 @@ def decorator(fn): return fn return decorator -def _quantize_int2(x: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: +def _quantize_int2(x: torch.Tensor) -> torch.Tensor: # Quantize the input tensor to int2 quant = x.sign() + 1 quant = BitnetTensor.from_unpacked(quant.to(torch.uint8)) @@ -51,7 +51,7 @@ def allowed_subclasses(type): @classmethod def from_float(cls, w: torch.Tensor): - w_int2 = _quantize_int2(w, torch.uint2).to(device=w.device) + w_int2 = _quantize_int2(w).to(device=w.device) return w_int2 def clone(self): @@ -62,6 +62,7 @@ def copy_(self, src): return self def to(self, *args, **kwargs): + (tensor,) = args if len(args) == 1 and isinstance(args[0], torch.dtype): dtype = args[0] if dtype == torch.int8: @@ -70,7 +71,7 @@ def to(self, *args, **kwargs): return unpack_uint2(self.elem).to(torch.int8).to(dtype) elif dtype == torch.uint8: return unpack_uint2(self.elem).view(torch.uint8) - elif dtype == torch.uint2: + elif isinstance(tensor, BitnetTensor): return self return super().to(*args, **kwargs) @@ -113,7 +114,7 @@ def to_dtype(func, args, kwargs): return unpack_uint2(tensor.elem).to(torch.int8).to(dtype) elif 
dtype == torch.uint8: return unpack_uint2(tensor.elem).view(torch.uint8) - elif dtype == torch.uint2: + elif isinstance(tensor, BitnetTensor): return tensor.elem raise NotImplementedError(f"to {dtype} not supported") @@ -125,7 +126,7 @@ def _to_copy(func, args, kwargs): return BitnetTensor(unpack_uint2(tensor).view(tensor.shape).view(torch.int8) - 1) elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): return BitnetTensor(tensor.to(torch.int8).to(dtype)) - elif dtype == torch.uint2: + elif isinstance(tensor, BitnetTensor): return BitnetTensor(tensor) raise NotImplementedError(f"to {dtype} not supported") diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py index c8d2746545..e071e822d9 100644 --- a/torchao/prototype/dtypes/uint2.py +++ b/torchao/prototype/dtypes/uint2.py @@ -59,7 +59,7 @@ def __new__(cls, input_tensor: torch.Tensor): input_tensor.size(), input_tensor.stride(), input_tensor.storage_offset(), - torch.uint2, + cls, input_tensor.device, input_tensor.requires_grad ) @@ -68,7 +68,7 @@ def __new__(cls, input_tensor: torch.Tensor): up_size(tensor_meta.original_shape), tensor_meta.original_strides, tensor_meta.storage_offset, - dtype=tensor_meta.dtype, + dtype=torch.uint8, #Not sure if this is correct device=tensor_meta.device, requires_grad=tensor_meta.requires_grad ) @@ -167,7 +167,7 @@ def to_copy(func, args, kwargs): return unpack_uint2(tensor.elem).view(tensor.shape).view(torch.uint8) elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): return tensor.to(torch.uint8).to(dtype) - elif dtype == torch.uint2: + elif isinstance(tensor, UInt2Tensor): return tensor raise NotImplementedError(f"to_copy {dtype} not supported") @@ -220,7 +220,7 @@ def to_dtype(func, args, kwargs): return unpack_uint2(tensor.elem).view(torch.uint8) elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64): return unpack_uint2(tensor.elem).to(torch.uint8).to(dtype) - elif dtype == torch.uint2: + elif isinstance(tensor, UInt2Tensor): return tensor.elem raise NotImplementedError(f"to {dtype} not supported") From 60970c510e9466edc77a00390e1a514326579f83 Mon Sep 17 00:00:00 2001 From: James Melvin Date: Mon, 17 Jun 2024 21:59:30 +0530 Subject: [PATCH 14/18] fix: wrap detach in UIntTensor, torch.compile test --- test/dtypes/test_bitnet.py | 27 ++++++++++++++++++++++++++- torchao/prototype/dtypes/bitnet.py | 26 +++++++++++++++++--------- torchao/prototype/dtypes/uint2.py | 3 ++- 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py index f8a5e54bbf..b0407143b2 100644 --- a/test/dtypes/test_bitnet.py +++ b/test/dtypes/test_bitnet.py @@ -1,7 +1,9 @@ import pytest import torch +import torch.nn as nn from torchao.prototype.dtypes import BitnetTensor from torchao.prototype.dtypes.uint2 import unpack_uint2 +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter @pytest.fixture def bitnet_tensor(): @@ -28,6 +30,29 @@ def test_conversion(bitnet_tensor, dtype): expected_tensor = unpack_uint2(bitnet_tensor.elem).to(dtype) assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) +def _apply_weight_only_uint2_quant(model): + def fn(mod): + mod.weight = torch.nn.Parameter(BitnetTensor.from_float(mod.weight), requires_grad=False) + return mod + + _replace_with_custom_fn_if_matches_filter( + model, + lambda mod: fn(mod), + lambda mod, fqn: isinstance(mod, 
torch.nn.Linear), + ) + +@pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]]) +def test_uint2_quant(input_shape): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + x = torch.randn(*input_shape).to(device) + m = nn.Sequential(nn.Linear(4, 16)).to(device) + y_ref = m(x) + _apply_weight_only_uint2_quant(m) + y_wo = m(x) + assert y_ref.shape == y_wo.shape + y_compiled = torch.compile(m, fullgraph=True)(x) + + if __name__ == "__main__": - pytest.main() + pytest.main(__file__) diff --git a/torchao/prototype/dtypes/bitnet.py b/torchao/prototype/dtypes/bitnet.py index f48d8df5ea..fd2d0a4790 100644 --- a/torchao/prototype/dtypes/bitnet.py +++ b/torchao/prototype/dtypes/bitnet.py @@ -24,8 +24,8 @@ def __init__(self, input_tensor: torch.Tensor, **kwargs): super(BitnetTensor, self).__init__(input_tensor, **kwargs) @staticmethod - def __tensor_unflatten__(flattened, meta): - assert meta is None + def __tensor_unflatten__(flattened, *meta): + # TODO - meta is not None, is it ok? elem = flattened["elem"] return BitnetTensor(elem) @@ -51,7 +51,8 @@ def allowed_subclasses(type): @classmethod def from_float(cls, w: torch.Tensor): - w_int2 = _quantize_int2(w).to(device=w.device) + w_intq = _quantize_int2(w) + w_int2 = w_intq.to(device=w.device) return w_int2 def clone(self): @@ -62,7 +63,6 @@ def copy_(self, src): return self def to(self, *args, **kwargs): - (tensor,) = args if len(args) == 1 and isinstance(args[0], torch.dtype): dtype = args[0] if dtype == torch.int8: @@ -71,23 +71,31 @@ def to(self, *args, **kwargs): return unpack_uint2(self.elem).to(torch.int8).to(dtype) elif dtype == torch.uint8: return unpack_uint2(self.elem).view(torch.uint8) - elif isinstance(tensor, BitnetTensor): + elif isinstance(self, BitnetTensor): return self + if 'device' in kwargs: + device = kwargs['device'] + return BitnetTensor(self.elem.to(device=device)) + return super().to(*args, **kwargs) @implements([torch.ops.aten.mm.default]) def mm(func, args, kwargs): x, weight = args - x = unpack_uint2(x.elem).to(torch.float32) - weight = unpack_uint2(weight.elem).to(torch.float32) + if isinstance(x, BitnetTensor): + x = unpack_uint2(x.elem).to(torch.float32) + if isinstance(weight, BitnetTensor): + weight = unpack_uint2(weight.elem).to(torch.float32) y = torch.mm(x, weight) return y @implements([torch.ops.aten.addmm.default]) def addmm(func, args, kwargs): bias, x, weight = args - x = unpack_uint2(x.elem).to(torch.float32) - weight = unpack_uint2(weight.elem).to(torch.float32) + if isinstance(x, BitnetTensor): + x = unpack_uint2(x.elem).to(torch.float32) + if isinstance(weight, BitnetTensor): + weight = unpack_uint2(weight.elem).to(torch.float32) if bias is not None: bias = bias.to(torch.float32) y = torch.addmm(bias, x, weight) diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py index e071e822d9..88e9699a49 100644 --- a/torchao/prototype/dtypes/uint2.py +++ b/torchao/prototype/dtypes/uint2.py @@ -211,7 +211,8 @@ def equal(func, args, kwargs): @implements([torch.ops.aten.detach.default]) def detach(func, args, kwargs): (tensor,) = args - return tensor.elem.detach() + detached_elem = tensor.elem.detach() + return UInt2Tensor(detached_elem) @implements([torch.ops.aten.to.dtype]) def to_dtype(func, args, kwargs): From c9e9583d4043fd271cb1198ac78d91ac367e3823 Mon Sep 17 00:00:00 2001 From: James Melvin Priyarajan Date: Tue, 18 Jun 2024 00:28:39 +0000 Subject: [PATCH 15/18] fix: CI errors on compile tests --- test/dtypes/test_bitnet.py | 1 + 
torchao/prototype/dtypes/bitnet.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py index b0407143b2..61663e9adf 100644 --- a/test/dtypes/test_bitnet.py +++ b/test/dtypes/test_bitnet.py @@ -43,6 +43,7 @@ def fn(mod): @pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]]) def test_uint2_quant(input_shape): + torch.set_float32_matmul_precision("high") device = 'cuda' if torch.cuda.is_available() else 'cpu' x = torch.randn(*input_shape).to(device) m = nn.Sequential(nn.Linear(4, 16)).to(device) diff --git a/torchao/prototype/dtypes/bitnet.py b/torchao/prototype/dtypes/bitnet.py index fd2d0a4790..61fb159927 100644 --- a/torchao/prototype/dtypes/bitnet.py +++ b/torchao/prototype/dtypes/bitnet.py @@ -61,7 +61,18 @@ def clone(self): def copy_(self, src): self.elem.copy_(src.elem) return self + + def tolist(self): + data = unpack_uint2(self.elem).tolist() + return data + def __repr__(self): + try: + data = unpack_uint2(self.elem).tolist() + except AssertionError: + data = f"Tensor of shape {self.shape} and dtype {self.elem.dtype}" + return f"BitnetTensor({data}, dtype={self.elem.dtype})" + def to(self, *args, **kwargs): if len(args) == 1 and isinstance(args[0], torch.dtype): dtype = args[0] @@ -143,3 +154,8 @@ def clone(func, args, kwargs): (tensor,) = args return tensor.clone() +@implements([torch.ops.aten.allclose.default]) +def allclose(func, args, kwargs): + (a, b) = args + return torch.allclose(a.elem, b.elem, **kwargs) + From 70412163c05b18fbd86dce9538a28b76b975295c Mon Sep 17 00:00:00 2001 From: James Melvin Priyarajan Date: Tue, 18 Jun 2024 00:54:08 +0000 Subject: [PATCH 16/18] fix: skip tests less than torch 2.4 --- test/dtypes/test_bitnet.py | 4 ++++ test/dtypes/test_uint2.py | 5 +++++ torchao/prototype/dtypes/uint2.py | 5 +++++ 3 files changed, 14 insertions(+) diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py index 61663e9adf..63d782b07a 100644 --- a/test/dtypes/test_bitnet.py +++ b/test/dtypes/test_bitnet.py @@ -4,6 +4,10 @@ from torchao.prototype.dtypes import BitnetTensor from torchao.prototype.dtypes.uint2 import unpack_uint2 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) @pytest.fixture def bitnet_tensor(): diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index c4ff28a688..664c187712 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -3,6 +3,11 @@ import torch.nn as nn from torchao.prototype.dtypes import UInt2Tensor from torchao.prototype.dtypes.uint2 import unpack_uint2 +from torchao.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + @pytest.fixture def uint2_tensor(): diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py index 88e9699a49..c0e88e94d2 100644 --- a/torchao/prototype/dtypes/uint2.py +++ b/torchao/prototype/dtypes/uint2.py @@ -232,3 +232,8 @@ def t(func, args, kwargs): unpacked = unpack_uint2(tensor.elem).to(tensor.device) transposed = unpacked.t() return UInt2Tensor(pack_uint2(transposed)) + +@implements([torch.ops.aten.allclose.default]) +def allclose(func, args, kwargs): + tensor, other = args + return torch.allclose(tensor.elem, other.elem) From 5ef3f6b539da6f5fdd46b98a7575ce4f10405102 Mon Sep 17 
00:00:00 2001 From: Mark Saroufim Date: Mon, 17 Jun 2024 20:17:25 -0700 Subject: [PATCH 17/18] Added pytest fixture --- test/dtypes/test_bitnet.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py index 63d782b07a..9fb4b8496e 100644 --- a/test/dtypes/test_bitnet.py +++ b/test/dtypes/test_bitnet.py @@ -9,6 +9,19 @@ if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501 + + # setup (currently do nothing) + + # tests will run here + yield + + # teardown + # avoid dynamo cache limit issues + torch._dynamo.reset() + @pytest.fixture def bitnet_tensor(): input_tensor = torch.randint(0, 15, (4,4), dtype=torch.uint8) From ae4ead170cfa1e3b640a3edb32809ce9cc8fe1cc Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 17 Jun 2024 21:21:25 -0700 Subject: [PATCH 18/18] remove tensor core flag --- test/dtypes/test_bitnet.py | 1 - test/dtypes/test_uint2.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py index 9fb4b8496e..1abdd0c1ed 100644 --- a/test/dtypes/test_bitnet.py +++ b/test/dtypes/test_bitnet.py @@ -60,7 +60,6 @@ def fn(mod): @pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]]) def test_uint2_quant(input_shape): - torch.set_float32_matmul_precision("high") device = 'cuda' if torch.cuda.is_available() else 'cpu' x = torch.randn(*input_shape).to(device) m = nn.Sequential(nn.Linear(4, 16)).to(device) diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index 664c187712..4cdfd88baf 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -8,7 +8,6 @@ if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) - @pytest.fixture def uint2_tensor(): input_tensor = torch.randint(0, 15, (4,4), dtype = torch.uint8) @@ -32,4 +31,3 @@ def test_conversion(uint2_tensor, dtype): if __name__ == '__main__': pytest.main(__file__) -
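
Usage note on the bit-packing helpers introduced in this series: pack_uint2 stores four 2-bit values per uint8 byte (high bits first) and unpack_uint2 reverses it, so the last dimension must be divisible by four and shrinks or grows by a factor of four. A minimal round-trip sketch follows; it assumes the prototype module path torchao.prototype.dtypes.uint2 as laid out by these patches, which is internal and may move in later releases.

import torch
from torchao.prototype.dtypes.uint2 import pack_uint2, unpack_uint2, down_size, up_size

# Four 2-bit entries (values 0..3) are packed into each uint8 byte, so the
# last dimension must be divisible by 4 and shrinks by 4x after packing.
values = torch.randint(0, 4, (2, 8), dtype=torch.uint8)
packed = pack_uint2(values)
assert packed.shape == down_size(values.shape)   # (2, 2)
assert up_size(packed.shape) == tuple(values.shape)  # (2, 8)

# Unpacking restores the original 2-bit entries exactly.
assert torch.equal(unpack_uint2(packed), values)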
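
The BitnetTensor pieces above are exercised only through the pytest files; the sketch below pulls them together for weight-only ternary (BitNet 1.58-style) quantization of nn.Linear modules, mirroring test_uint2_quant from the patches. It is a sketch under assumptions, not a stable API: it relies on the prototype layout torchao.prototype.dtypes, the private helper _replace_with_custom_fn_if_matches_filter from torchao.quantization.quant_api, and (per the version gate added later in the series) a PyTorch build newer than 2.4. Only output shape is checked here, matching what the committed test asserts.

import torch
import torch.nn as nn

from torchao.prototype.dtypes import BitnetTensor
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

def apply_weight_only_uint2_quant(model: nn.Module) -> None:
    # BitnetTensor.from_float maps each weight to {-1, 0, +1} via sign(), shifts
    # the result to {0, 1, 2}, and packs four 2-bit codes into each uint8 byte.
    def quantize_linear(mod):
        mod.weight = nn.Parameter(BitnetTensor.from_float(mod.weight), requires_grad=False)
        return mod

    _replace_with_custom_fn_if_matches_filter(
        model,
        quantize_linear,
        lambda mod, fqn: isinstance(mod, nn.Linear),
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(4, 16)).to(device)
x = torch.randn(2, 4, device=device)

y_ref = model(x)
apply_weight_only_uint2_quant(model)
# The forward pass is routed through the aten.addmm / aten.t dispatch overrides
# added in these patches, which unpack the 2-bit weights on the fly.
y_q = model(x)
assert y_ref.shape == y_q.shape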