Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Renaming fpx to floatx #877

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/benchmark_fp6.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import torch
import pandas as pd
import torch.nn.functional as F
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
from torchao.utils import benchmark_torch_function_in_microseconds
from tqdm import tqdm


def benchmark(m: int, k: int, n: int):
float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2))
fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
fp16_weight = fp6_weight.dequantize(torch.half)

fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
Expand Down
72 changes: 36 additions & 36 deletions test/dtypes/test_fpx.py → test/dtypes/test_floatx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
parametrize,
run_tests,
)
from torchao.dtypes.fpx import (
FpxTensorCoreAQTLayout,
FpxTensorCoreLayoutType,
to_scaled_tc_fpx,
from_scaled_tc_fpx,
from torchao.dtypes.floatx import (
FloatxTensorCoreAQTLayout,
FloatxTensorCoreLayoutType,
to_scaled_tc_floatx,
from_scaled_tc_floatx,
)
from torchao.dtypes.fpx.fpx import _pack_tc_fpx, _pack_tc_fp6
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
from torchao.dtypes.floatx.floatx import _pack_tc_floatx, _pack_tc_fp6
from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32
from torchao.quantization import (
quantize_,
fpx_weight_only,
Expand All @@ -25,71 +25,71 @@


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_FPx_DTYPES = [(3, 2), (2, 2)]
_Floatx_DTYPES = [(3, 2), (2, 2)]


class TestFpxTensorCoreAQTLayout(TestCase):
class TestFloatxTensorCoreAQTLayout(TestCase):
@parametrize("device", _DEVICES)
def test_pack_tc_fp6_correctness(self, device):
x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)

expected = _pack_tc_fpx(x, 6)
expected = _pack_tc_floatx(x, 6)
actual = _pack_tc_fp6(x)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("device", _DEVICES)
def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device):
def test_to_scaled_tc_floatx_compile(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device)

expected = to_scaled_tc_fpx(x, ebits, mbits)
actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits)
expected = to_scaled_tc_floatx(x, ebits, mbits)
actual = torch.compile(to_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("device", _DEVICES)
def test_from_tc_fpx_correctness(self, ebits, mbits, device):
def test_from_tc_floatx_correctness(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device) * 100

# quantize and dequantize so that the values are exactly representable in FPx
x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)
# quantize and dequantize so that the values are exactly representable in Floatx
x = _floatx_unpacked_to_f32(_f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits)

tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)
actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
tc_floatx, scale = to_scaled_tc_floatx(x, ebits, mbits)
actual = from_scaled_tc_floatx(tc_floatx, ebits, mbits, scale=scale)
torch.testing.assert_close(actual, x)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("device", _DEVICES)
def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
M, N = 256, 64
nbits = 1 + ebits + mbits
x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device)
scale = torch.randn(M, device=device)

expected = from_scaled_tc_fpx(x, ebits, mbits, scale)
actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
expected = from_scaled_tc_floatx(x, ebits, mbits, scale)
actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits, scale)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
def test_to_copy_device(self, ebits, mbits):
from torchao.quantization.quant_primitives import (
choose_qparams_affine_fpx,
quantize_affine_fpx,
choose_qparams_affine_floatx,
quantize_affine_floatx,
)

x = torch.randn(256, 64)
scale = choose_qparams_affine_fpx(x, ebits, mbits)
x = quantize_affine_fpx(x, scale, ebits, mbits)
layout_type = FpxTensorCoreLayoutType(ebits, mbits)
fpx_layout_tensor = FpxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
assert fpx_layout_tensor.device.type == "cuda"
fpx_layout_tensor = fpx_layout_tensor.cpu()
assert fpx_layout_tensor.device.type == "cpu"
scale = choose_qparams_affine_floatx(x, ebits, mbits)
x = quantize_affine_floatx(x, scale, ebits, mbits)
layout_type = FloatxTensorCoreLayoutType(ebits, mbits)
floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
assert floatx_layout_tensor.device.type == "cuda"
floatx_layout_tensor = floatx_layout_tensor.cpu()
assert floatx_layout_tensor.device.type == "cpu"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("bias", [False, True])
def test_fpx_weight_only(self, ebits, mbits, bias):
N, OC, IC = 4, 256, 64
Expand All @@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias):
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestFpxTensorCoreAQTLayout)
instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion test/dtypes/test_uintx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from torchao.dtypes.uintx.Uintx import to_uintx
from torchao.dtypes.uintx.uintx import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
Expand Down
66 changes: 33 additions & 33 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
int_scaled_matmul,
choose_qparams_and_quantize_affine_hqq,
FP8_TYPES,
choose_qparams_affine_fpx,
quantize_affine_fpx,
dequantize_affine_fpx,
choose_qparams_affine_floatx,
quantize_affine_floatx,
dequantize_affine_floatx,
)
from torchao.quantization.utils import (
pack_tinygemm_scales_and_zeros,
Expand Down Expand Up @@ -199,10 +199,10 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
if output_dtype is None:
output_dtype = self.dtype

from torchao.dtypes.fpx import FpxTensorCoreLayoutType
if isinstance(self.layout_type, FpxTensorCoreLayoutType):
from torchao.dtypes.floatx import FloatxTensorCoreLayoutType
if isinstance(self.layout_type, FloatxTensorCoreLayoutType):
int_data, scale = self.layout_tensor.get_plain()
return dequantize_affine_fpx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype)
return dequantize_affine_floatx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype)
else:
data, scale, zero_point = self.layout_tensor.get_plain()
return dequantize_affine(
Expand Down Expand Up @@ -389,8 +389,8 @@ def from_hp_to_fpx(
input_float: torch.Tensor,
layout_type: LayoutType,
):
from torchao.dtypes.fpx import FpxTensorCoreLayoutType
assert isinstance(layout_type, FpxTensorCoreLayoutType), f"Only FpxTensorCoreLayoutType is supported for fpx, got {layout_type}"
from torchao.dtypes.floatx import FloatxTensorCoreLayoutType
assert isinstance(layout_type, FloatxTensorCoreLayoutType), f"Only FloatxTensorCoreLayoutType is supported for floatx, got {layout_type}"
original_shape = input_float.shape
input_float = layout_type.pre_process(input_float)
# per axis quantization, where axis = 1
Expand All @@ -399,12 +399,12 @@ def from_hp_to_fpx(

ebits, mbits = layout_type.ebits, layout_type.mbits
# Note: these ops are hardcoded to have per axis quantization (axis=1) right now
scale = choose_qparams_affine_fpx(input_float, ebits, mbits)
fpx_unpacked = quantize_affine_fpx(input_float, scale, ebits, mbits)
fpx_packed = layout_type.post_process(fpx_unpacked)
scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
floatx_packed = layout_type.post_process(floatx_unpacked)

layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
layout_tensor = layout_tensor_ctr(fpx_packed, scale, None, layout_type)
layout_tensor = layout_tensor_ctr(floatx_packed, scale, None, layout_type)
return cls(
layout_tensor,
block_size,
Expand Down Expand Up @@ -502,7 +502,7 @@ class MarlinSparseLayoutType(LayoutType):
def pre_process(self, input: torch.Tensor) -> torch.Tensor:
"""Preprocess the input tensor to be in the correct format for the Marlin sparse kernel.
- 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format
- 2º: tensor is injected with 2:4 sparsity
- 2º: tensor is injected with 2:4 sparsity
- 3º: transposes it again because the quantization process will compute the scales for dim=-1

Args:
Expand Down Expand Up @@ -673,8 +673,8 @@ def from_plain(
@register_layout_cls(MarlinSparseLayoutType)
class MarlinSparseAQTLayout(AQTLayout):
"""
Layout storage class for sparse_marlin_24 layout for affine quantized tensor.
Layout storage class for sparse_marlin_24 layout for affine quantized tensor.

Can be used with 4 bits and 8 bits quantization.

Original marlin documentation and information:
Expand Down Expand Up @@ -760,9 +760,9 @@ def __tensor_unflatten__(
def get_plain(self):
from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import
int_data_expanded, scales_expanded = unpack_from_marlin_24(
self.int_data,
self.scale,
self.meta,
self.int_data,
self.scale,
self.meta,
self.original_shape,
self.group_size,
self.num_bits,
Expand Down Expand Up @@ -794,7 +794,7 @@ def from_plain(

if q_w_24.dtype != torch.int32:
raise ValueError("Only `torch.int32` weights are supported.")

in_features, out_features = q_w_24.shape
if in_features % 128 != 0 or out_features != 256 == 0:
raise ValueError(
Expand Down Expand Up @@ -824,11 +824,11 @@ def from_plain(
marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size)

return cls(
marlin_24_q_w_comp, marlin_24_s, zero_point,
marlin_24_q_w_comp, marlin_24_s, zero_point,
meta, layout_type, q_w_24.shape,
group_size, num_bits
)

def get_layout_type(self) -> LayoutType:
return self.layout_type

Expand Down Expand Up @@ -956,7 +956,7 @@ def __repr__(self):
f"scale={scale},\n"
f"transposed={self.transposed}, "
f"layout_type={layout_type})")


@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
Expand Down Expand Up @@ -1308,16 +1308,16 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
y += bias.to(m.dtype)
return y

def _linear_f16_act_fpx_weight_check(input_tensor, weight_tensor, bias):
from torchao.dtypes.fpx import FpxTensorCoreLayoutType
def _linear_f16_act_floatx_weight_check(input_tensor, weight_tensor, bias):
from torchao.dtypes.floatx import FloatxTensorCoreLayoutType
return (
# input is native float32 tensor
not is_traceable_wrapper_subclass(input_tensor) and
input_tensor.is_floating_point() and
input_tensor.dtype == torch.float16 and
# weight is fpx Tensor
# weight is floatx Tensor
isinstance(weight_tensor, AffineQuantizedTensor) and
isinstance(weight_tensor.layout_type, FpxTensorCoreLayoutType) and
isinstance(weight_tensor.layout_type, FloatxTensorCoreLayoutType) and
(
# weight is using fp6 quantization
(weight_tensor.layout_type.ebits == 3 and
Expand All @@ -1332,8 +1332,8 @@ def _linear_f16_act_fpx_weight_check(input_tensor, weight_tensor, bias):
)
)

def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias):
from torchao.dtypes.fpx import _SPLIT_K_MAP
def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias):
from torchao.dtypes.floatx import _SPLIT_K_MAP
from torchao.ops import quant_llm_linear

act = input_tensor
Expand All @@ -1350,7 +1350,7 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias):
weight.layout_type.ebits,
weight.layout_type.mbits,
act_reshaped,
weight.layout_tensor.packed_fpx_data,
weight.layout_tensor.packed_floatx_data,
weight.layout_tensor.scale,
splitK=splitK,
)
Expand Down Expand Up @@ -1378,10 +1378,10 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
""" Ensures input tensor is correctly formated for _scaled_mm """
input_scale = input_scale.unsqueeze(-1)

if input_scale.dim() > 2:
input_scale = input_scale.reshape(-1, input_scale.shape[-1])

return input_scale

def _linear_fp_act_fp8_weight_impl(
Expand Down Expand Up @@ -1457,7 +1457,7 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b
workspace_24 = marlin_24_workspace(original_shape[1])

out = marlin_24_gemm(
input_2d, sparse_w_int4, meta, scale,
input_2d, sparse_w_int4, meta, scale,
workspace_24, num_bits, size_m, size_n, size_k
)

Expand All @@ -1476,7 +1476,7 @@ def _register_aqt_quantized_linear_dispatches():
(_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
(_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
(_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
(_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
(_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
(_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl),
]:
register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
Expand Down
10 changes: 5 additions & 5 deletions torchao/dtypes/fpx/README.md → torchao/dtypes/floatx/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Quant-LLM

This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32/FP16/BF16 weights to FPx and integration with torchao API.
This is a FP16 x Floatx mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32/FP16/BF16 weights to Floatx and integration with torchao API.

## Usage

Expand All @@ -13,7 +13,7 @@ from torchao.quantization import (
model = ...
model.half() # not necessary, but recommeneded to maintain accuracy

# for generic FPx EyMz where x = 1 + y + z
# for generic Floatx EyMz where x = 1 + y + z
# fp6 with ebits = 3 and mbits = 2
quantize_(model, fpx_weight_only(3, 2))

Expand All @@ -25,15 +25,15 @@ It's also possible to pre-process the weight and call the kernel directly.

```python
import torch
from torchao.dtypes.fpx import to_scaled_tc_fpx
from torchao.dtypes.floatx import to_scaled_tc_floatx
from torchao.ops import quant_llm_linear

fp32_weight = torch.randn(1024, 512).cuda()
ebits, mbits = 3, 2

# pre-process the weight. this will quantize the weight to FP6 and pack it in a special
# layout for tensor cores. refer to paper for more details.
fp6_weight, scales = to_scaled_tc_fpx(fp32_weight, ebits, mbits)
fp6_weight, scales = to_scaled_tc_floatx(fp32_weight, ebits, mbits)

fp16_act = torch.randn(1, 512).cuda().half()
outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape (1, 1024)
Expand All @@ -48,7 +48,7 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape

Benchmarks are run on a machine with a single 4070Ti SUPER GPU using the scripts in [_models/llama](../../_models/llama). tokens/s is measured using [generate.py](../../_models/llama/generate.py) which generates text in a latency optimized way (batchsize=1). wikitext perplexity is measured using [eval.py](../../_models/llama/eval.py) which uses [lm_eval](https://github.com/EleutherAI/lm-evaluation-harness). The model used is [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf).

FPx quantization is run with `--precision float16`. The rest uses the default precision of `bfloat16`.
Floatx quantization is run with `--precision float16`. The rest uses the default precision of `bfloat16`.

Quantization | wikitext perplexity | tokens/s
--------------------|---------------------|----------
Expand Down
1 change: 1 addition & 0 deletions torchao/dtypes/floatx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .floatx import FloatxTensorCoreLayoutType, FloatxTensorCoreAQTLayout, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP
Loading
Loading