From 77541c5e899db71689a178b6c30d53130f856839 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 23 May 2024 09:15:46 -0700 Subject: [PATCH] Add a prototype of MX format training and inference Summary: The MX numerical formats are new low precision formats with recent acceptance into the OCP spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf This PR adds a reference native PyTorch implementation of training and inference primitives for using MX accelerated matrix multiplications. Currently, we use a reference layout (scale and raw data stored separately) and an emulated matrix multiplication. Test Plan: ``` // tests pytest -s test/prototype/mx_formats/* // benchmarks python torchao/prototype/mx_formats/benchmarks/bench_qdq.py ``` Reviewers: Subscribers: Tasks: Tags: --- README.md | 1 + dev-requirements.txt | 2 + test/prototype/mx_formats/test_custom_cast.py | 388 ++++++++++ test/prototype/mx_formats/test_mx_linear.py | 212 ++++++ test/prototype/mx_formats/test_mx_tensor.py | 267 +++++++ torchao/prototype/mx_formats/README.md | 92 +++ torchao/prototype/mx_formats/__init__.py | 0 .../mx_formats/benchmarks/bench_qdq.py | 161 ++++ torchao/prototype/mx_formats/config.py | 2 + torchao/prototype/mx_formats/constants.py | 48 ++ torchao/prototype/mx_formats/custom_cast.py | 713 ++++++++++++++++++ torchao/prototype/mx_formats/fp_formats.py | 550 ++++++++++++++ torchao/prototype/mx_formats/mx_linear.py | 160 ++++ torchao/prototype/mx_formats/mx_ops.py | 145 ++++ torchao/prototype/mx_formats/mx_tensor.py | 411 ++++++++++ torchao/prototype/mx_formats/utils.py | 7 + 16 files changed, 3159 insertions(+) create mode 100644 test/prototype/mx_formats/test_custom_cast.py create mode 100644 test/prototype/mx_formats/test_mx_linear.py create mode 100644 test/prototype/mx_formats/test_mx_tensor.py create mode 100644 torchao/prototype/mx_formats/README.md create mode 100644 torchao/prototype/mx_formats/__init__.py create mode 100644 torchao/prototype/mx_formats/benchmarks/bench_qdq.py create mode 100644 torchao/prototype/mx_formats/config.py create mode 100644 torchao/prototype/mx_formats/constants.py create mode 100644 torchao/prototype/mx_formats/custom_cast.py create mode 100644 torchao/prototype/mx_formats/fp_formats.py create mode 100644 torchao/prototype/mx_formats/mx_linear.py create mode 100644 torchao/prototype/mx_formats/mx_ops.py create mode 100644 torchao/prototype/mx_formats/mx_tensor.py create mode 100644 torchao/prototype/mx_formats/utils.py diff --git a/README.md b/README.md index 150c67b512..d31b5ceac7 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,7 @@ To learn more try out our APIs, you can check out API examples in 3. Support for lower precision [dtypes](./torchao/dtypes) such as - [nf4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) without writing custom Triton or CUDA code - [uint4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/uint4.py) + - [MX](https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats) implementing the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), prototype as the hardware support is not available yet 4. 
[Bleeding Edge Kernels](./torchao/prototype/) for experimental kernels without backwards compatibility guarantees - [GaLore](https://github.com/pytorch/ao/tree/main/torchao/prototype/galore) for memory efficient finetuning - [fused HQQ Gemm Kernel](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) for compute bound workloads diff --git a/dev-requirements.txt b/dev-requirements.txt index 6dadb274aa..68b17dc888 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -9,6 +9,8 @@ transformers bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers matplotlib pandas +fire # QOL for commandline scripts +tabulate # QOL for printing tables to stdout # Custom CUDA Extensions ninja diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py new file mode 100644 index 0000000000..a1c3f13ee5 --- /dev/null +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -0,0 +1,388 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch + +import torchao.prototype.mx_formats.config as config +from torch.utils._triton import has_triton +from torchao.prototype.mx_formats.constants import ( + DTYPE_FP4, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, + F4_E2M1_EXP_BIAS, + F6_E2M3_EXP_BIAS, + F6_E3M2_EXP_BIAS, +) + +from torchao.prototype.mx_formats.custom_cast import ( + f32_to_f4_unpacked, + f32_to_f6_e2m3_unpacked, + f32_to_f6_e3m2_unpacked, + f4_unpacked_to_f32, + f6_e2m3_unpacked_to_f32, + f6_e3m2_unpacked_to_f32, + get_bits, + pack_uint4, + triton_f4_to_bf16, + unpack_uint4, +) + +from torchao.prototype.mx_formats.fp_formats import ( + _assert_equals, + dtype_to_interesting_values, + float4_e2m1_interesting_values, + float6_e2m3_interesting_values, + float6_e3m2_interesting_values, + get_sem_bits, + sem_bits_to_sem_vals, + sem_vals_to_f32, +) + +from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +torch.manual_seed(0) + + +@pytest.mark.skip( + reason="TODO debug CI failure, low pri since this is not used in the MX code" # noqa: E501 +) +def test_fp32(): + dtype = torch.float + interesting_values = dtype_to_interesting_values[dtype] + for fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, _notes in interesting_values: + _assert_equals(fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, dtype) + + +@pytest.mark.skip( + reason="TODO debug CI failure, low pri since this is not used in the MX code" # noqa: E501 +) +def test_bf16(): + dtype = torch.bfloat16 + interesting_values = dtype_to_interesting_values[dtype] + for fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, _notes in interesting_values: + _assert_equals(fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, dtype) + + +def test_fp16(): + dtype = torch.float16 + interesting_values = dtype_to_interesting_values[dtype] + for fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, _notes in interesting_values: + _assert_equals(fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, dtype) + + +def test_float8_e4m3fn(): + dtype = torch.float8_e4m3fn + interesting_values = dtype_to_interesting_values[dtype] + for fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, _notes in interesting_values: + _assert_equals(fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, dtype) + + +def test_float8_e5m2(): + dtype = 
torch.float8_e5m2 + interesting_values = dtype_to_interesting_values[dtype] + for fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, _notes in interesting_values: + _assert_equals(fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, dtype) + + +def _sem_enc_to_fp32_val(s_enc, e_enc, m_enc, is_zero, is_denorm, exp_bias): + s_i = 1.0 if s_enc == "0" else -1.0 + if is_zero: + e_i = 0 + m_f = 0.0 + elif is_denorm: + e_i = int(e_enc, 2) - exp_bias + 1 + m_f = 0.0 + cur_pow_of_two = -1 + for m_bit in m_enc: + m_f += int(m_bit, 2) * pow(2, cur_pow_of_two) + cur_pow_of_two -= 1 + else: + e_i = int(e_enc, 2) - exp_bias + m_f = 1.0 + cur_pow_of_two = -1 + for m_bit in m_enc: + m_f += int(m_bit, 2) * pow(2, cur_pow_of_two) + cur_pow_of_two -= 1 + fp32 = s_i * (2**e_i) * m_f + return fp32 + + +def test_float4_e2m1_table(): + for ( + fp32_ref, + _formula, + s_enc, + e_enc, + m_enc, + _label, + ) in float4_e2m1_interesting_values: + is_zero = e_enc == "00" and m_enc == "0" + # normal vs denormal + is_denorm = e_enc == "00" and m_enc == "1" + # get exponent and mantissa + exp_bias = F4_E2M1_EXP_BIAS + fp32 = _sem_enc_to_fp32_val( + s_enc, e_enc, m_enc, is_zero, is_denorm, exp_bias + ) # noqa: E501 + assert abs(fp32_ref - fp32) < 1e-12 + + +def test_float6_e3m2_table(): + for ( + fp32_ref, + _formula, + s_enc, + e_enc, + m_enc, + _label, + ) in float6_e3m2_interesting_values: + is_zero = e_enc == "000" and m_enc == "00" + # normal vs denormal + is_denorm = e_enc == "000" and m_enc != "00" + # get exponent and mantissa + exp_bias = F6_E3M2_EXP_BIAS + fp32 = _sem_enc_to_fp32_val( + s_enc, e_enc, m_enc, is_zero, is_denorm, exp_bias + ) # noqa: E501 + assert abs(fp32_ref - fp32) < 1e-12 + + +def test_float6_e2m3_table(): + for ( + fp32_ref, + _formula, + s_enc, + e_enc, + m_enc, + _label, + ) in float6_e2m3_interesting_values: + is_zero = e_enc == "00" and m_enc == "000" + # normal vs denormal + is_denorm = e_enc == "00" and m_enc != "000" + # get exponent and mantissa + exp_bias = F6_E2M3_EXP_BIAS + fp32 = _sem_enc_to_fp32_val( + s_enc, e_enc, m_enc, is_zero, is_denorm, exp_bias + ) # noqa: E501 + assert abs(fp32_ref - fp32) < 1e-12 + + +# positive float4 vals, in increasing order: +# 0: 0 +# 1: 0.5 +# 2: 1.0 +# 3: 1.5 +# 4: 2.0 +# 5: 3.0 +# 6: 4.0 +# 7: 6.0 +# below we test pos and neg versions of all of these + + +def _test_fp4_case(f32_val, f32_val_ref, f4_enc_ref): + # 1. verify that a fp32 value gets quantized to correct fp4 encoding + # TODO test on cuda + f4_unpacked = f32_to_f4_unpacked(torch.tensor(f32_val)) + s_enc, e_enc, m_enc = get_sem_bits(f4_unpacked, bitwidth=4) + assert s_enc + e_enc + m_enc == f4_enc_ref + + # 2. verify that fp4 value gets dequantized to correct fp32 value + f32_dequantized = f4_unpacked_to_f32(f4_unpacked) + assert f32_val_ref == f32_dequantized.item() + + +def _test_fp4_cases(cases): + # test the exp and mantissa with both values of the sign bit + for s_enc in "0", "1": + s_i = 1.0 if s_enc == "0" else -1.0 + for val, val_ref, em_enc in cases: + _test_fp4_case(s_i * val, s_i * val_ref, s_enc + em_enc) + + +# note: below are written as individual test cases for easy command line +# filtering with pytest, i.e. "-k fp4_0_0" + +# Explanation of tie-to-even test cases: +# 1. read https://stackoverflow.com/q/8981913/ +# From above, tie-to-even rule: if GRS == 100, round up if bit before is a 1, # noqa: E501 +# and round down if it's a 0 +# +# 2. assume 1.mm...m for normals and 0.mm...m for denormals. Since +# fp4 has only one mantissa bit we are always rounding after that bit. 
So, +# G == 0 for fp4 denormal range, and G == 1 for fp4 normal range. +# +# 3. Therefore, when we have a tie (GRS == 100), we round down for fp4 denormals # noqa: E501 +# and round up for fp4 normals: +# 0.25 -> 0.0 (the only denormal case) +# 0.75 -> 1.0 +# 1.25 -> 1.0 +# 1.75 -> 2.0 +# 2.5 -> 2.0 +# 3.5 -> 4.0 +# 5.0 -> 4.0 + + +def test_fp4_0_0(): + cases = [ + (0.25, 0.0, "000"), # tie to even + (0.1, 0.0, "000"), + (0.0, 0.0, "000"), + # note: -0.1 is tested in the negative zero test + ] + _test_fp4_cases(cases) + + +def test_fp4_0_5(): + cases = [ + (0.6, 0.5, "001"), + (0.5, 0.5, "001"), + (0.4, 0.5, "001"), + ] + _test_fp4_cases(cases) + + +def test_fp4_1_0(): + cases = [ + (1.25, 1.0, "010"), # tie to even + (1.1, 1.0, "010"), + (1.0, 1.0, "010"), + (0.9, 1.0, "010"), + (0.75, 1.0, "010"), # tie to even + ] + _test_fp4_cases(cases) + + +def test_fp4_1_5(): + cases = [ + (1.6, 1.5, "011"), + (1.5, 1.5, "011"), + (1.4, 1.5, "011"), + ] + _test_fp4_cases(cases) + + +def test_fp4_2_0(): + cases = [ + (2.5, 2.0, "100"), # tie to even + (2.1, 2.0, "100"), + (2.0, 2.0, "100"), + (1.9, 2.0, "100"), + (1.75, 2.0, "100"), # tie to even + ] + _test_fp4_cases(cases) + + +def test_fp4_3_0(): + cases = [ + (3.1, 3.0, "101"), + (3.0, 3.0, "101"), + (2.9, 3.0, "101"), + ] + _test_fp4_cases(cases) + + +def test_fp4_4_0(): + cases = [ + (5.0, 4.0, "110"), # tie to even + (4.1, 4.0, "110"), + (4.0, 4.0, "110"), + (3.9, 4.0, "110"), + (3.5, 4.0, "110"), # tie to even + ] + _test_fp4_cases(cases) + + +def test_fp4_6_0(): + cases = [ + (6.1, 6.0, "111"), + (6.0, 6.0, "111"), + (5.9, 6.0, "111"), + ] + _test_fp4_cases(cases) + + +def test_fp4_pack_unpack(): + orig_vals = torch.Tensor([[0.0, 0.5, 4.0, -0.0], [-0.0, 1.0, -6.0, 3.0]]) + orig_vals_f4_unpacked = f32_to_f4_unpacked(orig_vals) + orig_vals_f4_packed = pack_uint4(orig_vals_f4_unpacked) + assert orig_vals_f4_packed.numel() == (orig_vals.numel() / 2) + orig_vals_f4_packed_unpacked = unpack_uint4(orig_vals_f4_packed) + orig_vals_dq = f4_unpacked_to_f32(orig_vals_f4_packed_unpacked) + assert torch.all(orig_vals_dq == orig_vals) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +def test_fp4_triton_unscaled_cast(): + packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda") + f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals)) + f32_triton = triton_f4_to_bf16(packed_vals).to(torch.float) + assert torch.all(torch.eq(f32_ref, f32_triton)) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +def test_fp4_triton_scaled_cast(): + size = (256,) + orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100 + mxtensor = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4) + + f32_ref = mxtensor.to_dtype(torch.float) + config.use_fp4_custom_triton_dequant_kernel = True + f32_triton = mxtensor.to_dtype(torch.float) + config.use_fp4_custom_triton_dequant_kernel = False + assert torch.all(torch.eq(f32_ref, f32_triton)) + + +@pytest.mark.parametrize("dtype_name", (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2)) +def test_fp6_values(dtype_name): + """ + The fp6 dtypes have 2**6 = 64 unique values each. The test + below tests the f32 -> f6 and f6 -> f32 cast for each value. + + TODO(future PR): also verify rounding tie-to-even works properly. 
+ """ + + for i in range(2**6): + t = torch.tensor(i, dtype=torch.uint8) + bits = get_bits(t.to(torch.int8)) + + # go from bits to f32 ref + if dtype_name == DTYPE_FP6_E2M3: + s_enc, e_enc, m_enc = bits[2], bits[3:5], bits[5:] + elif dtype_name == DTYPE_FP6_E3M2: + s_enc, e_enc, m_enc = bits[2], bits[3:6], bits[6:] + else: + raise AssertionError("unsupported") + s_i, e_i, m_f, special_value = sem_bits_to_sem_vals( + s_enc, e_enc, m_enc, dtype_name + ) + f32_ref = torch.tensor(sem_vals_to_f32(s_i, e_i, m_f, special_value)) + + # test cast to f6 + if dtype_name == DTYPE_FP6_E2M3: + f6 = f32_to_f6_e2m3_unpacked(f32_ref) + elif dtype_name == DTYPE_FP6_E3M2: + f6 = f32_to_f6_e3m2_unpacked(f32_ref) + else: + raise AssertionError("unsupported") + # test that the bits are equivalent to our starting point + torch.testing.assert_close(f6, t, rtol=0, atol=0) + + # test cast back to f32 + if dtype_name == DTYPE_FP6_E2M3: + f32 = f6_e2m3_unpacked_to_f32(f6) + elif dtype_name == DTYPE_FP6_E3M2: + f32 = f6_e3m2_unpacked_to_f32(f6) + else: + raise AssertionError("unsupported") + torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py new file mode 100644 index 0000000000..c62473a259 --- /dev/null +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -0,0 +1,212 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +import pytest + +import torch +import torch.nn as nn +from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES + +from torchao.prototype.mx_formats.mx_linear import ( + MXInferenceLinear, + MXLinear, + swap_linear_with_mx_inference_linear, + swap_linear_with_mx_linear, +) + +from torchao.prototype.mx_formats.utils import compute_error + +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + +# trying to outsmart flake8 +__has_cuda = torch.cuda.is_available() +IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) + +torch.manual_seed(2) + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("input_shape", [(2, 4), (1, 2, 4), (1, 1, 2, 4)]) +def test_linear_eager(elem_dtype, bias, input_shape): + """ + Smoke test for training linear module with mx weight + """ + grad_shape = list(input_shape) + grad_shape[-1] = 6 + + m = nn.Sequential( + nn.Linear(4, 6, bias=bias, device="cuda"), + ) + m_mx = copy.deepcopy(m) + block_size = 2 + swap_linear_with_mx_linear(m_mx, elem_dtype, block_size) + + x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() + x = copy.deepcopy(x_ref) + g = torch.randn(*grad_shape, device="cuda") + with torch.autocast("cuda", dtype=torch.bfloat16): + y_ref = m(x_ref) + y_mx = m_mx(x) + + y_ref.backward(g) + y_mx.backward(g) + + y_sqnr = compute_error(y_ref, y_mx) + w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad) + x_g_sqnr = compute_error(x_ref.grad, x.grad) + + if elem_dtype is torch.float8_e4m3fn: + assert y_sqnr >= 18.0 + assert w_g_sqnr >= 18.0 + assert x_g_sqnr >= 14.0 + else: + assert y_sqnr >= 8.0 + assert w_g_sqnr >= 10.0 + assert x_g_sqnr >= 8.0 + + +# 
TODO(future): enable compile support +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_activation_checkpointing(): + input_shape = (2, 4) + grad_shape = (2, 6) + elem_dtype = torch.float8_e4m3fn + + m = nn.Sequential( + nn.Linear(4, 6, bias=True, device="cuda"), + nn.Linear(6, 6, bias=True, device="cuda"), + ) + block_size = 2 + swap_linear_with_mx_linear(m, elem_dtype, block_size) + + x = torch.randn(*input_shape, device="cuda").requires_grad_() + g = torch.randn(*grad_shape, device="cuda") + y = torch.utils.checkpoint.checkpoint(m, x, use_reentrant=False) + y.backward(g) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +@pytest.mark.parametrize("bias", [False, True]) +def test_linear_compile(elem_dtype, bias): + """ + Verify that compile does not change numerics of MX linear fw + bw + """ + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + if not IS_CUDA_GE_89: + pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + input_shape = (2, 4) + grad_shape = (2, 6) + m_mx = nn.Sequential( + nn.Linear(4, 6, bias=bias, device="cuda"), + ) + block_size = 2 + swap_linear_with_mx_linear(m_mx, elem_dtype, block_size) + m_mx_c = copy.deepcopy(m_mx) + m_mx_c = torch.compile(m_mx_c, fullgraph=True) + + x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() + x = copy.deepcopy(x_ref) + g = torch.randn(*grad_shape, device="cuda") + + with torch.autocast("cuda", dtype=torch.bfloat16): + y_ref = m_mx(x_ref) + y = m_mx_c(x) + torch.testing.assert_close(y_ref, y, atol=0, rtol=0) + + y_ref.backward(g) + y.backward(g) + w_g_ref = m_mx[0].weight.grad + w_g = getattr(m_mx_c, "0").weight.grad + # TODO(future): investigate why we can't match with rtol=0 atol=0 + # after moving to torchao repo. Technically compile does not give + # bit exactness guarantees, but there also might be a bug lurking + # around. + torch.testing.assert_close(w_g_ref, w_g, atol=0.02, rtol=0.02) + + x_g_ref = x_ref.grad + x_g = x.grad + # TODO(future): investigate why we can't match with rtol=0 atol=0 + # after moving to torchao repo. Technically compile does not give + # bit exactness guarantees, but there also might be a bug lurking + # around. 
+ torch.testing.assert_close(x_g_ref, x_g, atol=0.02, rtol=0.02) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("input_shape", [(2, 4), (1, 2, 4), (1, 1, 2, 4)]) +def test_inference_linear(elem_dtype, bias, input_shape): + """ + Smoke test for inference linear module with mx weight + """ + m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16)) + m = m.cuda() + m_mx = copy.deepcopy(m) + block_size = 2 + swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) + + x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16) + y_ref = m(x) + y_mx = m_mx(x) + sqnr = compute_error(y_ref, y_mx) + if elem_dtype is torch.float8_e4m3fn: + assert sqnr >= 20.0 + else: + assert sqnr >= 11.0 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_inference_compile_simple(elem_dtype): + """ + Smoke test for inference compile + """ + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + if not IS_CUDA_GE_89: + pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) + m = m.cuda() + m_mx = copy.deepcopy(m) + block_size = 2 + swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) + m_mx = torch.compile(m_mx, fullgraph="true") + + x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16) + y_ref = m(x) + y_mx = m_mx(x) + sqnr = compute_error(y_ref, y_mx) + if elem_dtype is torch.float8_e4m3fn: + assert sqnr >= 20.0 + else: + assert sqnr >= 14.0 + + +def test_filter_fn(): + m1 = nn.Sequential( + nn.Linear(32, 32), + nn.Linear(32, 32), + ) + m2 = copy.deepcopy(m1) + filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731 + + swap_linear_with_mx_linear(m1, torch.float8_e4m3fn, 32, filter_fn) + assert type(m1[0]) == MXLinear + assert type(m1[1]) == torch.nn.Linear + + swap_linear_with_mx_inference_linear( + m2, torch.float8_e4m3fn, 32, filter_fn + ) # noqa: E501 + assert type(m2[0]) == MXInferenceLinear + assert type(m2[1]) == torch.nn.Linear diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py new file mode 100644 index 0000000000..15f34622ac --- /dev/null +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
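+
+# This file exercises MXTensor numerics: quantize/dequantize round-trips (SQNR
+# thresholds), NaN propagation through the e8m0 scale, transpose/view support,
+# and eager vs. torch.compile parity of the casts.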
+ +import pytest + +import torch +from torchao.prototype.mx_formats import config +from torchao.prototype.mx_formats.constants import ( + DTYPE_FP4, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, + SUPPORTED_ELEM_DTYPES, +) + +from torchao.prototype.mx_formats.custom_cast import pack_uint4 + +from torchao.prototype.mx_formats.mx_tensor import ( + E8M0_EXPONENT_NAN_VAL, + MXTensor, + to_dtype, +) + +from torchao.prototype.mx_formats.utils import compute_error + +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + +# trying to outsmart flake8 +__has_cuda = torch.cuda.is_available() +IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) + +torch.manual_seed(2) + +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501 + + # setup (currently do nothing) + + # tests will run here + yield + + # teardown + # avoid dynamo cache limit issues + torch._dynamo.reset() + + +def _test_mx(data_hp, elem_dtype, block_size): + data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size) + data_mx_dq = data_mx.to_dtype(data_hp.dtype) + + def assert_sqnr_gt_threshold(orig, new, threshold): + sqnr = compute_error(orig, new) + if torch.all(torch.isnan(sqnr)): + # if both operands are full of zeroes, sqnr is nan and this is ok + # test for this explicitly + assert torch.all(orig == 0) and torch.all(new == 0) + else: + assert sqnr >= threshold + + if elem_dtype is torch.float8_e4m3fn: + assert_sqnr_gt_threshold(data_hp, data_mx_dq, 20.0) + else: + assert_sqnr_gt_threshold(data_hp, data_mx_dq, 14.0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_hello_world(elem_dtype): + data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16) + block_size = 2 + _test_mx(data, elem_dtype, block_size) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_all_zeros(elem_dtype): + data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16) + block_size = 2 + _test_mx(data, elem_dtype, block_size) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_some_zeros(elem_dtype): + data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16) + data[0, :] = 0.0 + data[:, 2] = 0.0 + block_size = 2 + _test_mx(data, elem_dtype, block_size) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_exponent_nan_in(elem_dtype): + """ + If high precision block values has a NaN, the exponent block + value is set to is NaN + """ + tensor_hp = torch.tensor( + [float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16 + ) + block_size = 2 + tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) + assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL) + assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_exponent_nan_out(elem_dtype): + """ + If block exponent value is NaN, the MX 
tensor block value is NaN + """ + scale_e8m0_bits = torch.tensor( + [E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda" + ) + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + data_bits = torch.tensor( + [0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda" + ) # noqa: E501 + elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2): + data_bits = torch.tensor( + [0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda" + ) # noqa: E501 + elif elem_dtype == DTYPE_FP4: + data_bits = torch.tensor( + [0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda" + ) # noqa: E501 + data_bits = pack_uint4(data_bits) + else: + raise AssertionError("unsupported") + block_size = 2 + tensor_mx = MXTensor( + scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float + ) + tensor_hp = tensor_mx.to_dtype(torch.float) + assert torch.all(torch.isnan(tensor_hp[0:1])) + assert not torch.any(torch.isnan(tensor_hp[2:])) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_ranks(elem_dtype): + """ + The reshaping logic works for various ranks + """ + B = 2 + shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2)) + for s in shapes: + tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16) + _test_mx(tensor_hp, elem_dtype, B) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_block_sizes(elem_dtype): + """ + Smoke test for various block sizes + """ + for B in (1, 2, 32): + if B == 1 and elem_dtype == DTYPE_FP4: + pytest.skip("unsupported configuration") + tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16) + _test_mx(tensor_hp, elem_dtype, B) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +@pytest.mark.parametrize("fp4_triton", [False, True]) +def test_transpose(elem_dtype, fp4_triton): + """ + Verify that transposing an MX tensor works + """ + if elem_dtype != DTYPE_FP4 and fp4_triton: + pytest.skip("unsupported configuration") + + tensor_hp = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16) + block_size = 32 + tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) + config.use_fp4_custom_triton_dequant_kernel = fp4_triton + tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t() + config.use_fp4_custom_triton_dequant_kernel = False + + tensor_mx_t = tensor_mx.t() + config.use_fp4_custom_triton_dequant_kernel = fp4_triton + tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype) + config.use_fp4_custom_triton_dequant_kernel = False + + assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape + torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_cast_autograd(elem_dtype): + x = torch.arange(8, device="cuda").bfloat16().requires_grad_() + grad = torch.arange(8, device="cuda").bfloat16() * 0.5 + block_size = 8 + x_mx = MXTensor.to_mx(x, elem_dtype, block_size) + x_dq = x_mx.to_dtype(torch.bfloat16) + x_dq.backward(gradient=grad) + torch.testing.assert_close(grad, x.grad, atol=0, rtol=0) + + +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_view(elem_dtype): + x = torch.randn(1, 2, 4) + block_size = 2 + x_mx = MXTensor.to_mx(x, elem_dtype, block_size) + 
x_mx_2 = x_mx.view(2, 4) # noqa: F841 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +@pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("all_zeros", [False, True]) +def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): + """ + Verifies that compile does not change numerics of MX casts + """ + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + if not IS_CUDA_GE_89: + # separate ifs because flake8 is outsmarting me + pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + + shape = 4, 8 + if not all_zeros: + x = torch.randn(*shape, dtype=hp_dtype, device="cuda") + else: + x = torch.zeros(*shape, dtype=hp_dtype, device="cuda") + block_size = 2 + to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True) + + x_mx = MXTensor.to_mx(x, elem_dtype, block_size) + x_mx_c = to_mx_c(x, elem_dtype, block_size) + torch.testing.assert_close( + x_mx._scale_e8m0, + x_mx_c._scale_e8m0, + atol=0, + rtol=0, + ) + torch.testing.assert_close(x_mx._data, x_mx_c._data, atol=0, rtol=0) + + to_dtype_c = torch.compile(to_dtype, fullgraph=True) + + x_mx_dq = to_dtype( + x_mx._data, + x_mx._scale_e8m0, + x_mx._elem_dtype, + x_mx._block_size, + hp_dtype, # noqa: E501 + ) + x_mx_c_dq = to_dtype_c( + x_mx_c._data, + x_mx_c._scale_e8m0, + x_mx_c._elem_dtype, + x_mx_c._block_size, + hp_dtype, + ) + torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md new file mode 100644 index 0000000000..852e2a06d8 --- /dev/null +++ b/torchao/prototype/mx_formats/README.md @@ -0,0 +1,92 @@ +# MX formats with native PyTorch POC + +This is a POC of implementing the MX formats from the OCP spec (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) in native PyTorch. + +Note that the current version of the code is written for readability and +numerical correctness and not yet for optimal performance. We welcome +contributions on performance improvements. + +Note that there are no BC guarantees at the moment and we plan to evolve +this code as the hardware specifics of MX-accelerated matmuls become +known. + +# Current status + +## user API (subject to change) + +### MXTensor + +This is casts between fp32/bf16 and MX formats implemented in native PyTorch. + +```python +from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4 +x = torch.randn(32, 32, device='cuda') + +# elem_dtype can be torch.float8_e4m3fn, torch.float8_e5m2, DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4 +elem_dtype = torch.float8_e4m3fn +# block_size is 32 in the MX spec +block_size = 32 + +# high precision to mx +x_mx = MXTensor.to_mx(x, elem_dtype, block_size) + +# mx back to high precision +x_hp = x_mx.to_dtype(torch.float) +``` + +### MXLinear + +This is a module to do MX training, the MX matmul is currently emulated. + +```python +from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear + +m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() +elem_dtype = torch.float8_e4m3fn +block_size = 32 +swap_linear_with_mx_linear(m, elem_dtype, block_size) + +# training loop (not shown) +``` + +### MXInferenceLinear + +This is a module to do MX inference, weights are in MX and matmul is in high precision. 
+ +```python +from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear + +m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() +elem_dtype = torch.float8_e4m3fn +block_size = 32 +swap_linear_with_mx_inference_linear(m, elem_dtype, block_size) + +# do inference (not shown) +``` + +## accuracy status +* we match bitwise to other implementations of the OCP MX spec (code not in this repo), with a couple of edge cases left to resolve +* approximate numerics pass for `MXLinear` and `MXInferenceLinear` on sample inputs +* LLaMa 3 8B pretraining on 4 GPUs for 500 iterations shows that loss convergence is not meaningfully degraded (code not in this repo) + +## performance status + +### quant and dequant + +* we have a benchmark of quantizing and dequantizing mxfp8 and mxfp4 tensors with size (1, 4096, 11008) +* latest numbers: https://gist.github.com/vkuzo/83656e4a74777cfc0915de6b27be1ff6 + +## testing and benchmarking + +```bash +# numerical testing of custom fp4 and fp6 casts +pytest test/prototype/mx_formats/test_custom_cast.py +# testing of MXTensor +pytest test/prototype/mx_formats/test_mx_tensor.py +# testing of MXLinear and MXInferenceLinear +pytest test/prototype/mx_formats/test_mx_linear.py + +# run the quant and dequant benchmark +python torchao/prototype/mx_formats/benchmarks/bench_qdq.py +``` diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py new file mode 100644 index 0000000000..6fbfa8b7b5 --- /dev/null +++ b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +Benchmarking mx quantize/dequantize +""" + +from typing import Optional + +import fire +import tabulate +import torch +import torch.utils.benchmark as benchmark + +from torch.profiler import profile, ProfilerActivity +from torchao.prototype.mx_formats import config +from torchao.prototype.mx_formats.constants import ( # noqa: E501 + DTYPE_FP4, + SUPPORTED_ELEM_DTYPES, +) + +from torchao.prototype.mx_formats.mx_tensor import MXTensor + + +def benchmark_torch_function_in_microseconds(f, *args, **kwargs): + # Manual warmup + f(*args, **kwargs) + f(*args, **kwargs) + + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "f": f}, # noqa: E501 + ) + measurement = t0.blocked_autorange() + return measurement.mean * 1e6 + + +def run(profile_folder: Optional[str] = None): + headers = [ + "elem_dtype", + "use_fp4_custom_triton_dequant_kernel", + "q_time_us", + "q_mem_bw_tb_s", + "dq_time_us", + "dq_mem_bw_tb_s", + ] + results = [] + + data_hp = torch.randn(1, 4096, 11008, dtype=torch.bfloat16, device="cuda") + + for elem_dtype in SUPPORTED_ELEM_DTYPES: + for use_fp4_custom_triton_dequant_kernel in (False, True): + config.use_fp4_custom_triton_dequant_kernel = ( + use_fp4_custom_triton_dequant_kernel + ) + + if ( + elem_dtype != DTYPE_FP4 + and use_fp4_custom_triton_dequant_kernel # noqa: E501 + ): + # custom_triton_kernels only works for fp4 + continue + + print( + "elem_dtype", + elem_dtype, + "use_fp4_custom_triton_dequant_kernel", + use_fp4_custom_triton_dequant_kernel, + ) + + data_lp = MXTensor.to_mx(data_hp, elem_dtype, block_size=32) + + if not use_fp4_custom_triton_dequant_kernel: + quant = torch.compile(MXTensor.to_mx) + dequant = torch.compile(data_lp.to_dtype) + else: + # As of 2024-04, torch.compile didn't work with the + # handwritten triton kernel, + # crashed on tl.interleave: + # https://github.com/pytorch/pytorch/issues/123967 + # As of 2024-05-24, now there is message asking to convert to + # an opaque custom op: + # https://gist.github.com/vkuzo/0b0b90dca03bdb8e0446e4135644238a # noqa: E501 + # TODO(future): make this better + quant = MXTensor.to_mx + dequant = data_lp.to_dtype + + # warm up + quant(data_hp, elem_dtype, block_size=32) + res = dequant(torch.bfloat16) + + if profile_folder is not None: + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + ) as prof: + for _ in range(5): + quant(data_hp, elem_dtype, block_size=32) + dequant(torch.bfloat16) + prof.export_chrome_trace( + profile_folder + + f"/mx_qdq_{elem_dtype}_{use_fp4_custom_triton_dequant_kernel}.json" # noqa: E501 + ) + + q_execution_time_us = benchmark_torch_function_in_microseconds( + quant, data_hp, elem_dtype, block_size=32 + ) + dq_execution_time_us = benchmark_torch_function_in_microseconds( + dequant, torch.bfloat16 + ) + print(f"q time: {q_execution_time_us} us") + print(f"dq time: {dq_execution_time_us} us") + + # memory reads per element: + byte_per_stored_element = 1.0 # fp8 or 2xfp4 + byte_per_stored_exp_element = 1.0 # e8m0 + byte_per_dequantized_element = 2.0 # bfloat16 + mem_reads_writes_bytes = ( + # read raw data + (data_lp._data.numel() * byte_per_stored_element) + + + # read exponent + (data_lp._scale_e8m0.numel() * byte_per_stored_exp_element) + + + # write dequant + (res.numel() * byte_per_dequantized_element) + ) + # note: the above also works for quant, with reads/writes in + # reverse + + q_mem_bw_tb_s = (mem_reads_writes_bytes / 1e12) / ( + q_execution_time_us / 1e6 + ) + dq_mem_bw_tb_s = 
(mem_reads_writes_bytes / 1e12) / ( + dq_execution_time_us / 1e6 + ) + print(f"q mem bw: {q_mem_bw_tb_s} TB/s") + print(f"dq mem bw: {dq_mem_bw_tb_s} TB/s") + + results.append( + ( + elem_dtype, + use_fp4_custom_triton_dequant_kernel, + q_execution_time_us, + q_mem_bw_tb_s, + dq_execution_time_us, + dq_mem_bw_tb_s, + ) + ) + config.use_fp4_custom_triton_dequant_kernel = False + + torch._dynamo.reset() + + print(tabulate.tabulate(results, headers=headers, floatfmt=".2f")) + + +if __name__ == "__main__": + fire.Fire(run) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py new file mode 100644 index 0000000000..3e7e03d8f6 --- /dev/null +++ b/torchao/prototype/mx_formats/config.py @@ -0,0 +1,2 @@ +# If True, uses a custom triton kernel for fp4 dequantize +use_fp4_custom_triton_dequant_kernel = False diff --git a/torchao/prototype/mx_formats/constants.py b/torchao/prototype/mx_formats/constants.py new file mode 100644 index 0000000000..9189a8a39c --- /dev/null +++ b/torchao/prototype/mx_formats/constants.py @@ -0,0 +1,48 @@ +import torch + +# This is conceptually an enum of non-core dtypes +# if someone has time to verify torch.compile compatibility, it could be made +# into an enum +DTYPE_FP4 = "fp4_e2m1" +DTYPE_FP6_E3M2 = "fp6_e3m2" +DTYPE_FP6_E2M3 = "fp6_e2m3" + +# Supported element dtypes +SUPPORTED_ELEM_DTYPES = [ + torch.float8_e4m3fn, + torch.float8_e5m2, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, + DTYPE_FP4, +] + +F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 +F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max # 57344.0 + +F8E4M3_MAX_POW2 = 8 # 256 +F8E5M2_MAX_POW2 = 15 # 32768 +F6_E2M3_MAX_POW2 = 2 # 4 +F6_E3M2_MAX_POW2 = 4 # 16 +F4_E2M1_MAX_POW2 = 2 # 4 + +E8M0_EXPONENT_BIAS = 127 +E8M0_EXPONENT_NAN_VAL = 255 + +F32_EXP_BIAS = 127 +F6_E2M3_EXP_BIAS = 1 +F6_E3M2_EXP_BIAS = 3 +F4_E2M1_EXP_BIAS = 1 + +F32_MIN_NORMAL = 2 ** (-F32_EXP_BIAS + 1) + +F6_E2M3_MAX = 7.5 +F6_E2M3_MIN_NORMAL = 1.0 +F6_E2M3_MAX_INT = 31 # integer corresponding to 0b00011111 + +F6_E3M2_MAX = 28.0 +F6_E3M2_MIN_NORMAL = 0.25 +F6_E3M2_MAX_INT = 31 # integer corresponding to 0b00011111 + +F4_E2M1_MAX = 6.0 +F4_E2M1_MIN_NORMAL = 1.0 +F4_E2M1_MAX_INT = 7 diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py new file mode 100644 index 0000000000..60aaa336ba --- /dev/null +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -0,0 +1,713 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import struct + +import numpy as np + +import torch +from torch.utils._triton import has_triton + +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + +# TODO(future): if needed, make the below work on previous PyTorch versions, +# just need to hunt down the previous location of `libdevice`. An assert +# at the callsite prevents usage of this on unsupported versions. 
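+# Note: `libdevice` is only needed by the fp4 scaled dequant triton kernel
+# below, where `libdevice.pow` turns the e8m0 exponent into a bfloat16 scale.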
+if TORCH_VERSION_AFTER_2_4: + from torch._inductor.runtime.triton_helpers import libdevice + +from torchao.prototype.mx_formats.constants import ( + DTYPE_FP4, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, + E8M0_EXPONENT_BIAS, + E8M0_EXPONENT_NAN_VAL, + F32_EXP_BIAS, + F4_E2M1_EXP_BIAS, + F4_E2M1_MAX, + F4_E2M1_MAX_INT, + F4_E2M1_MIN_NORMAL, + F6_E2M3_EXP_BIAS, + F6_E2M3_MAX, + F6_E2M3_MAX_INT, + F6_E2M3_MIN_NORMAL, + F6_E3M2_EXP_BIAS, + F6_E3M2_MAX, + F6_E3M2_MAX_INT, + F6_E3M2_MIN_NORMAL, +) + + +def get_bits(x: torch.Tensor) -> str: + bits_per_byte = 8 + # Numpy has a nice function to get the string representation of binary. + # Since we are using ints as views of floats, need to specify the width + # to avoid numpy from using two's complement for negative numbers. + return np.binary_repr( + x.cpu().numpy(), width=x.element_size() * bits_per_byte + ) # noqa: E501 + + +EBITS_F32, MBITS_F32 = 8, 23 +EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1 +EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3 +EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2 + +DENORM_F32TOF4_EXP = ( + # exp bias conversion between formats + (F32_EXP_BIAS - F4_E2M1_EXP_BIAS) + # mantissa length difference between formats + + (MBITS_F32 - MBITS_F4_E2M1) + # add one to encoded exponent for denormalized numbers + + 1 +) +DENORM_F32TOF4_MASK_INT = DENORM_F32TOF4_EXP << MBITS_F32 +# reinterpret int32 as float32 in Python +# see https://stackoverflow.com/a/34446112/1058521 +DENORM_F32TOF4_MASK_FLOAT = struct.unpack( + "!f", struct.pack("!I", DENORM_F32TOF4_MASK_INT) +)[0] + +DENORM_F32TOF6_E2M3_EXP = ( + # exp bias conversion between formats + (F32_EXP_BIAS - F6_E2M3_EXP_BIAS) + # mantissa length difference between formats + + (MBITS_F32 - MBITS_F6_E2M3) + # add one to encoded exponent for denormalized numbers + + 1 +) +DENORM_F32TOF6_E2M3_MASK_INT = DENORM_F32TOF6_E2M3_EXP << MBITS_F32 +# reinterpret int32 as float32 in Python +# see https://stackoverflow.com/a/34446112/1058521 +DENORM_F32TOF6_E2M3_MASK_FLOAT = struct.unpack( + "!f", struct.pack("!I", DENORM_F32TOF6_E2M3_MASK_INT) +)[0] + +DENORM_F32TOF6_E3M2_EXP = ( + # exp bias conversion between formats + (F32_EXP_BIAS - F6_E3M2_EXP_BIAS) + # mantissa length difference between formats + + (MBITS_F32 - MBITS_F6_E3M2) + # add one to encoded exponent for denormalized numbers + + 1 +) +DENORM_F32TOF6_E3M2_MASK_INT = DENORM_F32TOF6_E3M2_EXP << MBITS_F32 +# reinterpret int32 as float32 in Python +# see https://stackoverflow.com/a/34446112/1058521 +DENORM_F32TOF6_E3M2_MASK_FLOAT = struct.unpack( + "!f", struct.pack("!I", DENORM_F32TOF6_E3M2_MASK_INT) +)[0] + +# +# magic value to add during the normal path +# TODO document this better +# + +# c++ code e5m2: +# f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF; +# 0xFFFFF is 1111 1111 1111 1111 1111, 20 ones, 20 = 23 - 3 = 23 - 2 - 1 + +# c++ code e4m3: +# f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; +# 0x7FFFF is 0111 1111 1111 1111 1111, 19 ones, 19 = 23 - 4 = 23 - 3 - 1 + +MAGIC_ADDER_F4_E2M1 = 0x1FFFFF # 21 ones +MAGIC_ADDER_F6_E2M3 = 0x7FFFF # 19 ones +MAGIC_ADDER_F6_E3M2 = 0xFFFFF # 20 ones + +# c++ code named vars +# f_bits += ((uint32_t)(f8_exp_bias - f32_exp_bias) << f32_mbits) + MAGIC_ADDER; # noqa: E501 + +SIGN_MASK_F4 = 0x8 # 1000 +SIGN_MASK_F6_E2M3 = 0x20 # 100000 +SIGN_MASK_F6_E3M2 = 0x20 # 100000 + +MANTISSA_MASK_F4 = 0x1 # 0001 +MANTISSA_MASK_F6_E2M3 = 0x7 # 000111 +MANTISSA_MASK_F6_E3M2 = 0x3 # 000011 + +ZERO_BITS_F32 = 0x0 +ZERO_POINT_FIVE_BITS_F32 = 0x3F000000 + + +def _f32_to_f4_or_f6_unpacked( + x, + max_normal, + min_normal, + denorm_mask_float, + 
denorm_mask_int, + ebits, + mbits, + exp_bias, + magic_adder, + max_int, + sign_mask, +): + """ + Input: torch.Tensor of dtype torch.float + Output: torch.Tensor of dtype torch.uint8, + fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding + fp6: bits 0-1 empty and bits 2-7 in the fp6_e2m3 or fp6_e3m2 encoding + + Note: there is no special values (NaN, inf) support in this code as the + OCP spec does not define special values for fp6 and fp4 dtypes. + + Code below is an adaptation of https://fburl.com/code/ciwofcg4 for f4/f6 + + Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501 + Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5 + """ + assert x.dtype == torch.float + + # save the sign + # Note that we have torch.uint32, but some ops like cpu bit shifts + # do not work on it. So, we stay in int32. + x = x.view(torch.int32) + sign = x & 0x80000000 + + # set everything to positive, will add sign back at the end + x = x ^ sign + + # TODO: can the branch floating point comparisons below be done without + # converting to float? probably but need to verify + x = x.view(torch.float) + + # rewrite saturate/denorm/norm branches without explicit data dependent + # control flow, to be more compiler friendly + saturate_mask = x >= max_normal + denormal_mask = torch.logical_and( + torch.logical_not(saturate_mask), x < min_normal + ) # noqa: E501 + normal_mask = torch.logical_not( + torch.logical_or(saturate_mask, denormal_mask) + ) # noqa: E501 + + # + # branch 1: saturate to max val - handled later in the code which combines + # the branches + # + + # + # branch 2: to conversion to denormal as well as rounding up to normal + # + denormal_x = x + denorm_mask_float + denormal_x = denormal_x.view(torch.int32) + denormal_x -= denorm_mask_int + denormal_x = denormal_x.to(torch.uint8) + + # + # branch 3: stay in normal range, adjust the exponent and round + # + normal_x = x.view(torch.int32) + # resulting mantissa is odd + mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1 + # update exponent, rounding bias part 1 + val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder + normal_x += val_to_add + # rounding bias part 2 + normal_x += mant_odd + # take the bits! + normal_x = normal_x >> (MBITS_F32 - mbits) + normal_x = normal_x.to(torch.uint8) + + # + # combine the branches + # + x = torch.full_like(x, max_int, dtype=torch.uint8) + x = torch.where(denormal_mask, denormal_x, x) + x = torch.where(normal_mask, normal_x, x) + + # add sign back + sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits) + sign_lp = sign_lp.to(torch.uint8) + # Right shift of a negative signed integer can fill the least significant + # bits with either 1s or 0s, depending on the implementation. 
Since PyTorch + # doesn't have an uint32 dtype, we mask out these bits to get just the + # f4 sign bit + sign_lp = sign_lp & sign_mask + x = x | sign_lp + + return x.to(torch.uint8) + + +def f32_to_f4_unpacked(x): + """ + Input: torch.Tensor of dtype torch.float + Output: torch.Tensor of dtype torch.uint8, with bits 0-3 empty and + bits 4-7 in fp4_e2m1 + """ + return _f32_to_f4_or_f6_unpacked( + x, + F4_E2M1_MAX, + F4_E2M1_MIN_NORMAL, + DENORM_F32TOF4_MASK_FLOAT, + DENORM_F32TOF4_MASK_INT, + EBITS_F4_E2M1, + MBITS_F4_E2M1, + F4_E2M1_EXP_BIAS, + MAGIC_ADDER_F4_E2M1, + F4_E2M1_MAX_INT, + SIGN_MASK_F4, + ) + + +def f32_to_f6_e2m3_unpacked(x): + """ + Input: torch.Tensor of dtype torch.float + Output: torch.Tensor of dtype torch.uint8, with bits 0-1 empty and + bits 2-7 in fp6_e2m3 + """ + return _f32_to_f4_or_f6_unpacked( + x, + F6_E2M3_MAX, + F6_E2M3_MIN_NORMAL, + DENORM_F32TOF6_E2M3_MASK_FLOAT, + DENORM_F32TOF6_E2M3_MASK_INT, + EBITS_F6_E2M3, + MBITS_F6_E2M3, + F6_E2M3_EXP_BIAS, + MAGIC_ADDER_F6_E2M3, + F6_E2M3_MAX_INT, + SIGN_MASK_F6_E2M3, + ) + + +def f32_to_f6_e3m2_unpacked(x): + """ + Input: torch.Tensor of dtype torch.float + Output: torch.Tensor of dtype torch.uint8, with bits 0-1 empty and + bits 2-7 in fp6_e3m2 + """ + return _f32_to_f4_or_f6_unpacked( + x, + F6_E3M2_MAX, + F6_E3M2_MIN_NORMAL, + DENORM_F32TOF6_E3M2_MASK_FLOAT, + DENORM_F32TOF6_E3M2_MASK_INT, + EBITS_F6_E3M2, + MBITS_F6_E3M2, + F6_E3M2_EXP_BIAS, + MAGIC_ADDER_F6_E3M2, + F6_E3M2_MAX_INT, + SIGN_MASK_F6_E3M2, + ) + + +def _f4_or_f6_unpacked_to_f32(x: torch.Tensor, lp_dtype_name: str): + """ + Input: torch.Tensor of dtype uint8, with bits 0-3 empty and bits 4-7 + containing an fp4_e2m1 encoding + Output: torch.Tensor of dtype fp32 with the dequantized value + + TODO(future): check if LUT for everything is faster than bit shifting, + especially for fp4. + """ + assert x.dtype == torch.uint8 + + if lp_dtype_name == DTYPE_FP4: + sign_mask = SIGN_MASK_F4 + ebits = EBITS_F4_E2M1 + mbits = MBITS_F4_E2M1 + exp_bias = F4_E2M1_EXP_BIAS + mantissa_mask = MANTISSA_MASK_F4 + elif lp_dtype_name == DTYPE_FP6_E2M3: + sign_mask = SIGN_MASK_F6_E2M3 + ebits = EBITS_F6_E2M3 + mbits = MBITS_F6_E2M3 + exp_bias = F6_E2M3_EXP_BIAS + mantissa_mask = MANTISSA_MASK_F6_E2M3 + elif lp_dtype_name == DTYPE_FP6_E3M2: + sign_mask = SIGN_MASK_F6_E3M2 + ebits = EBITS_F6_E3M2 + mbits = MBITS_F6_E3M2 + exp_bias = F6_E3M2_EXP_BIAS + mantissa_mask = MANTISSA_MASK_F6_E3M2 + else: + raise AssertionError(f"unsupported lp_dtype_name {lp_dtype_name}") + + # save the sign + sign_lp = x & sign_mask + + # set everything to positive, will add sign back at the end + x_pos = x ^ sign_lp + + # + # 1. Calculate zero mask + # + zero_mask = x_pos == 0 + + # + # 2. Calculate the denormal path mask + # + denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0)) + + # + # 3. Calculate the normal path + # + + # calculate the new exponent and shift it to bits 2:9 of the result + exp_biased_lp = x_pos >> mbits + exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS + exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32 + + # shift the mantissa to bits 10:32 of the result + mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32) + mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits) + result = exp_biased_f32 | mantissa_f32 + + # + # 4. Add the zero and denormal casts to the already casted normal path + # + result[zero_mask] = ZERO_BITS_F32 + # Note: for now the denormal path cast is written for readability and + # numerical correctness. 
There is likely a way to optimize the performance, + # I just haven't had time to look into it. + if lp_dtype_name == DTYPE_FP4: + result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32 + + elif lp_dtype_name == DTYPE_FP6_E2M3: + # Only 7 possible values, just do a LUT + # Note: calculate the booleans first because we are modifying + # this variable inplace. + is_val1 = mantissa_lp_int32 == 1 + is_val2 = mantissa_lp_int32 == 2 + is_val3 = mantissa_lp_int32 == 3 + is_val4 = mantissa_lp_int32 == 4 + is_val5 = mantissa_lp_int32 == 5 + is_val6 = mantissa_lp_int32 == 6 + is_val7 = mantissa_lp_int32 == 7 + mantissa_lp_int32[is_val1] = 0x3E000000 # 0.125 + mantissa_lp_int32[is_val2] = 0x3E800000 # 0.25 + mantissa_lp_int32[is_val3] = 0x3EC00000 # 0.375 + mantissa_lp_int32[is_val4] = 0x3F000000 # 0.5 + mantissa_lp_int32[is_val5] = 0x3F200000 # 0.625 + mantissa_lp_int32[is_val6] = 0x3F400000 # 0.75 + mantissa_lp_int32[is_val7] = 0x3F600000 # 0.875 + result = torch.where(denormal_mask, mantissa_lp_int32, result) + + elif lp_dtype_name == DTYPE_FP6_E3M2: + # Only 3 possible values, just do a LUT + # Note: calculate the booleans first because we are modifying + # this variable inplace. + is_val1 = mantissa_lp_int32 == 1 + is_val2 = mantissa_lp_int32 == 2 + is_val3 = mantissa_lp_int32 == 3 + mantissa_lp_int32[is_val1] = 0x3D800000 # 0.0625 + mantissa_lp_int32[is_val2] = 0x3E000000 # 0.125 + mantissa_lp_int32[is_val3] = 0x3E400000 # 0.1875 + result = torch.where(denormal_mask, mantissa_lp_int32, result) + else: + raise AssertionError(f"unsupported lp_dtype_name {lp_dtype_name}") + + # add sign back + sign_f32 = sign_lp.to(torch.int32) << ( + MBITS_F32 - mbits + EBITS_F32 - ebits + ) # noqa: E501 + result = result | sign_f32 + + return result.view(torch.float) + + +def f4_unpacked_to_f32(x: torch.Tensor): + """ + Input: torch.Tensor of dtype uint8, with bits 0-3 empty and bits 4-7 + containing an fp4_e2m1 encoding + Output: torch.Tensor of dtype fp32 with the dequantized value + """ + return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP4) + + +def f6_e2m3_unpacked_to_f32(x: torch.Tensor): + """ + Input: torch.Tensor of dtype uint8, with bits 0-1 empty and bits 2-7 + containing an fp6_e3m2 encoding + Output: torch.Tensor of dtype fp32 with the dequantized value + """ + return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP6_E2M3) + + +def f6_e3m2_unpacked_to_f32(x: torch.Tensor): + """ + Input: torch.Tensor of dtype uint8, with bits 0-1 empty and bits 2-7 + containing an fp6_e3m2 encoding + Output: torch.Tensor of dtype fp32 with the dequantized value + """ + return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP6_E3M2) + + +if has_triton(): + import triton + import triton.language as tl + + @triton.jit + def _fp4_packed_to_bf16(x_packed): + """ + Input: a tensor of packed fp4 values + Output: a tensor of bfloat16 values + """ + + # low-bits: original location 0:3 + # high-bits: original location 4:7 + x_low_bits = x_packed >> 4 + x_high_bits = x_packed & 0xF + x = tl.interleave(x_low_bits, x_high_bits) + + # cast logic below + # output = x_unpacked.to(tl.float32) + + # save the sign + sign_f4 = x & SIGN_MASK_F4 + + # set everything to positive, will add sign back at the end + x_pos = x ^ sign_f4 + + # Special case zero + zero_mask = x_pos == 0 + + # There is only one denormal value in fp4: s001, which is 0.5 in f32 + # Special case it. + # TODO(later): will it be faster to repeat this for all 8 positive + # values instead of the bit manipulations? 
+ denormal_mask = x_pos == 1 + + # calculate the new exponent and shift it to bits 2:9 of the result + exp_biased_f4 = x_pos >> MBITS_F4_E2M1 + exp_biased_f32 = exp_biased_f4 - F4_E2M1_EXP_BIAS + F32_EXP_BIAS + exp_biased_f32 = exp_biased_f32.to(tl.int32) << MBITS_F32 + + # shift the mantissa to bits 10:32 of the result + mantissa_f4 = x_pos & MANTISSA_MASK_F4 + mantissa_f32 = mantissa_f4.to(tl.int32) << (MBITS_F32 - MBITS_F4_E2M1) + output = mantissa_f32 + + # combine the pieces + result = exp_biased_f32 | mantissa_f32 + # result[zero_mask] = ZERO_BITS_F32 + result = tl.where(zero_mask, ZERO_BITS_F32, result) + # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32 + result = tl.where(denormal_mask, ZERO_POINT_FIVE_BITS_F32, result) + + # add sign back + sign_f32 = sign_f4.to(tl.int32) << ( + MBITS_F32 - MBITS_F4_E2M1 + EBITS_F32 - EBITS_F4_E2M1 + ) + result = result | sign_f32 + + # The bit shifting above is for float32, so for now we + # bitcast to float32 and then regular cast to bfloat16 + # TODO(later): it should be pretty easy to cast directly to bf16, just + # need to adjust the mbits/ebits/special values. Perf impact is likely + # to be small as we would not be chaning memory access patterns. + output = result.to(tl.float32, bitcast=True) + output = output.to(tl.bfloat16) + return output + + @triton.jit + def triton_f4_to_bf16_kernel( + x_ptr, + output_ptr, + n_elements_in, + BLOCK_SIZE_IN: tl.constexpr, + ): + pid = tl.program_id(axis=0) + n_elements_out = n_elements_in * 2 + BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2 + + block_start_in = pid * BLOCK_SIZE_IN + offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN) + + mask_in = offsets_in < n_elements_in + + # packed uint8 + x_packed = tl.load(x_ptr + offsets_in, mask=mask_in) + output = _fp4_packed_to_bf16(x_packed) + + # set up output offsets + block_start_out = pid * BLOCK_SIZE_OUT + offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT) + mask_out = offsets_out < n_elements_out + + tl.store(output_ptr + offsets_out, output, mask=mask_out) + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_IN": 128}), + triton.Config({"BLOCK_SIZE_IN": 256}), + triton.Config({"BLOCK_SIZE_IN": 512}), + triton.Config({"BLOCK_SIZE_IN": 1024}), + triton.Config({"BLOCK_SIZE_IN": 2048}), + ], + key=["n_elements_in"], + ) + @triton.jit + def triton_f4_to_scaled_bf16_kernel( + x_ptr, + s_ptr, + output_ptr, + n_elements_in, + mx_block_size: tl.constexpr, + BLOCK_SIZE_IN: tl.constexpr, + ): + pid = tl.program_id(axis=0) + n_elements_out = n_elements_in * 2 + n_elements_s = n_elements_out // 32 + + BLOCK_SIZE_S: tl.constexpr = BLOCK_SIZE_IN // 16 + BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2 + + block_start_in = pid * BLOCK_SIZE_IN + offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN) + mask_in = offsets_in < n_elements_in + # packed uint8 + x_packed = tl.load(x_ptr + offsets_in, mask=mask_in) + output = _fp4_packed_to_bf16(x_packed) + + # load scale + block_start_s = pid * BLOCK_SIZE_S + offsets_s = block_start_s + tl.arange(0, BLOCK_SIZE_S) + mask_s = offsets_s < n_elements_s + s = tl.load(s_ptr + offsets_s, mask=mask_s) + + # create the scale in bf16 + s_offset = s.to(tl.int16) - E8M0_EXPONENT_BIAS + s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16) + s_fp = tl.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan")) + + # multiply output by scale + # TODO(later): see if manipulating the exponent instead of fp + # multiplication is going to give a significant speedup + output = tl.reshape( + output, (BLOCK_SIZE_OUT // 
mx_block_size, mx_block_size) + ) # noqa: E501 + s_fp = tl.reshape(s_fp, (BLOCK_SIZE_S // 1, 1)) + output = output * s_fp + output = tl.reshape(output, (BLOCK_SIZE_OUT,)) + + # set up output offsets + block_start_out = pid * BLOCK_SIZE_OUT + offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT) + mask_out = offsets_out < n_elements_out + + tl.store(output_ptr + offsets_out, output, mask=mask_out) + +else: + + def triton_f4_to_bf16_kernel( + x_ptr, + output_ptr, + n_elements_in, + BLOCK_SIZE_IN, + ): + raise AssertionError("unsupported without triton") + + def triton_f4_to_scaled_bf16_kernel( + x_ptr, + s_ptr, + output_ptr, + n_elements_in, + mx_block_size, + BLOCK_SIZE_IN, + ): + raise AssertionError("unsupported without triton") + + +def triton_f4_to_bf16(x: torch.Tensor): + """ + Input: a tensor of packed fp4 values + Output: a tensor of bfloat16 values + + Note: this function is only used in testing, so we can test + the numerical correctness of the cast without the scaling. + """ + new_shape = (*x.shape[:-1], x.shape[-1] * 2) + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda + n_elements_in = x.numel() + grid = lambda meta: ( # noqa: E731 + triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]), + ) # noqa: E731,E501 + triton_f4_to_bf16_kernel[grid](x, output, n_elements_in, BLOCK_SIZE_IN=512) + return output + + +def triton_f4_to_scaled_bf16( + x: torch.Tensor, + s_e8m0: torch.Tensor, + mx_block_size: int, +): + """ + Input: a tensor of packed fp4 values, and a scale in e8m0 format. The block + size is currently assumed to be 32. + Output: a tensor of bfloat16 values, multiplied by the encoded scale + """ + assert TORCH_VERSION_AFTER_2_4, "unsupported" + new_shape = (*x.shape[:-1], x.shape[-1] * 2) + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda + n_elements_in = x.numel() + grid = lambda meta: ( # noqa: E731 + triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]), + ) + triton_f4_to_scaled_bf16_kernel[grid]( + x, s_e8m0, output, n_elements_in, mx_block_size + ) + return output + + +# pack/unpack code copy-pasted from +# https://github.com/pytorch-labs/ao/blob/main/torchao/dtypes/uint4.py + + +def down_size(size): + assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" + return (*size[:-1], size[-1] // 2) + + +def up_size(size): + return (*size[:-1], size[-1] * 2) + + +def unpack_uint4(uint8_data) -> torch.Tensor: + """Get the original weight from the normalized float weight format""" + assert uint8_data.is_contiguous() + + shape = uint8_data.shape + + # since we are using uint8 we will decode 2 entries per byte + # Shift elements down 4 and select out the bottom 4 bits + # + # Note: known slow with triton + # * currently generates two kernels with a cat in between + # * after https://github.com/pytorch/pytorch/pull/123278 lands I + # verified that we get a single triton kernel, but that is even slower + # than the two kernels before this PR + # * TODO add a microbenchmark of just the cast and profile this + first_elements = (uint8_data >> 4).to(torch.uint8) + second_elements = (uint8_data & 0b1111).to(torch.uint8) + unpacked = torch.stack([first_elements, second_elements], dim=-1).view( + up_size(shape) + ) + + # trying Bert Maher's suggestion + # 2024-04-04: this works in unit tests but is broken on LLaMa 7B FFN with + # ptxas /tmp/tmp84wp7lea.ptx, line 227; error : Unexpected instruction types 
specified for 'sub' # noqa: E501 + # which seems to be the same issue as https://github.com/pytorch/pytorch/issues/118589 # noqa: E501 + # TODO(later): try removing subtractions from our cast to see if we can work around # noqa: E501 + # shift_tensor = torch.tensor([4, 0], dtype=torch.uint8, device=uint8_data.device) # noqa: E501 + # unpacked = (uint8_data.reshape(-1)[::, None] >> shift_tensor) & 0b1111 + # unpacked = unpacked.view(up_size(shape)) + + return unpacked + + +def pack_uint4(uint8_data) -> torch.Tensor: + # converting to uint8 for operations + shape = uint8_data.shape + assert shape[-1] % 2 == 0 + uint8_data = uint8_data.contiguous().view(-1) + return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) diff --git a/torchao/prototype/mx_formats/fp_formats.py b/torchao/prototype/mx_formats/fp_formats.py new file mode 100644 index 0000000000..2dc518add0 --- /dev/null +++ b/torchao/prototype/mx_formats/fp_formats.py @@ -0,0 +1,550 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +A helper script to summarize the key numerical values of various floating +point formats relevant to the MX spec. +""" + +import math +from typing import Tuple + +import tabulate + +import torch + +from torchao.prototype.mx_formats.constants import ( + DTYPE_FP4, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, +) + +from torchao.prototype.mx_formats.custom_cast import get_bits + +dtype_to_bitwidth = { + torch.float: 32, + torch.bfloat16: 16, + torch.float16: 16, + torch.float8_e4m3fn: 8, + torch.float8_e5m2: 8, + DTYPE_FP6_E3M2: 6, + DTYPE_FP6_E2M3: 6, +} +dtype_to_sem_len = { + torch.float: (1, 8, 23), + torch.bfloat16: (1, 8, 7), + torch.float16: (1, 5, 10), + torch.float8_e4m3fn: (1, 4, 3), + torch.float8_e5m2: (1, 5, 2), + # the line below is currently representing fp4 with bits 0:3 empty and + # bits 4:7 containing the fp4 encoding + # TODO(future): clean this up + torch.uint8: (1, 2, 1), +} +# bias = 2 ** (exp_bitwidth - 1) - 1 +dtype_to_exp_bias = { + torch.float: 127, + torch.bfloat16: 127, + torch.float16: 15, + torch.float8_e4m3fn: 7, + torch.float8_e5m2: 15, + DTYPE_FP6_E2M3: 1, + DTYPE_FP6_E3M2: 3, +} +dtype_to_int_dtype = { + torch.float: torch.int32, + torch.float16: torch.int16, + torch.bfloat16: torch.int16, + torch.float8_e4m3fn: torch.int8, + torch.float8_e5m2: torch.int8, + # for fp4 + # TODO(future): clean it up + torch.uint8: torch.uint8, +} + +# format: +# { +# dtype: [ +# [ +# ref_f32_value, sign_encoding, exp_encoding, mantissa_encoding, +# description, +# ], +# ..., +# ], +# ..., +# } +dtype_to_interesting_values = { + torch.float: [ + # zero and neg zero + (0.0, "0", "0" * 8, "0" * 23, "zero"), + (-0.0, "1", "0" * 8, "0" * 23, "zero_neg"), + # special values + (float("nan"), "0", "1" * 8, "1" + "0" * 22, "nan"), + (float("inf"), "0", "1" * 8, "0" * 23, "inf"), + (float("-inf"), "1", "1" * 8, "0" * 23, "inf_neg"), + # values below verified with from https://www.h-schmidt.net/FloatConverter/IEEE754.html # noqa: E501 + # largest normal + ( + 3.402823466385288598117042e38, + "0", + "1" * 7 + "0", + "1" * 23, + "largest_norm", + ), # noqa: E501 + ( + -3.402823466385288598117042e38, + "1", + "1" * 7 + "0", + "1" * 23, + "largest_norm_neg", + ), + # smallest normal + ( + 1.175494350822287507968737e-38, + "0", + "0" * 7 + "1", + "0" * 23, + "smallest_norm", + ), # noqa: E501 + ( + -1.175494350822287507968737e-38, + "1", + "0" * 
7 + "1", + "0" * 23, + "smallest_norm_neg", + ), + # largest denormal + ( + 1.175494210692441075487029e-38, + "0", + "0" * 8, + "1" * 23, + "largest_denorm", + ), # noqa: E501 + ( + -1.175494210692441075487029e-38, + "1", + "0" * 8, + "1" * 23, + "largest_denorm_neg", + ), # noqa: E501 + # smallest denormal + ( + 1.401298464324817070923730e-45, + "0", + "0" * 8, + "0" * 22 + "1", + "smallest_denorm", + ), + ( + -1.401298464324817070923730e-45, + "1", + "0" * 8, + "0" * 22 + "1", + "smallest_denorm_neg", + ), + # positive and negative value + (30.0, "0", "10000011", "1" * 3 + "0" * 20, "random_pos"), + (-24.0, "1", "10000011", "1" + "0" * 22, "random_neg"), + ], + torch.bfloat16: [ + # zero and neg zero + (0.0, "0", "0" * 8, "0" * 7, "zero"), + (-0.0, "1", "0" * 8, "0" * 7, "zero_neg"), + # special values + (float("nan"), "0", "1" * 8, "1" + "0" * 6, "nan"), + (float("inf"), "0", "1" * 8, "0" * 7, "inf"), + (float("-inf"), "1", "1" * 8, "0" * 7, "inf_neg"), + # values below checked with TODO + # largest normal + (3.38953e38, "0", "1" * 7 + "0", "1" * 7, "largest_norm"), + (-3.38953e38, "1", "1" * 7 + "0", "1" * 7, "largest_norm_neg"), + # smallest normal + (1.17549e-38, "0", "0" * 7 + "1", "0" * 7, "smallest_norm"), + (-1.17549e-38, "1", "0" * 7 + "1", "0" * 7, "smallest_norm_neg"), + # largest denormal + (1.16631e-38, "0", "0" * 8, "1" * 7, "largest_denorm"), + (-1.16631e-38, "1", "0" * 8, "1" * 7, "largest_denorm_neg"), + # smallest denormal + (9.18355e-41, "0", "0" * 8, "0" * 6 + "1", "smallest_denorm"), + (-9.18355e-41, "1", "0" * 8, "0" * 6 + "1", "smallest_denorm_neg"), + # positive and negative value + (30.0, "0", "10000011", "1" * 3 + "0" * 4, "random_pos"), + (-24.0, "1", "10000011", "1" + "0" * 6, "random_neg"), + ], + torch.float16: [ + # zero and neg zero + (0.0, "0", "0" * 5, "0" * 10, "zero"), + (-0.0, "1", "0" * 5, "0" * 10, "zero_neg"), + # special values + (float("nan"), "0", "1" * 5, "1" + "0" * 9, "nan"), + (float("inf"), "0", "1" * 5, "0" * 10, "inf"), + (float("-inf"), "1", "1" * 5, "0" * 10, "inf_neg"), + # values below checked with https://en.wikipedia.org/wiki/Half-precision_floating-point_format # noqa: E501 + # largest normal + (65504, "0", "1" * 4 + "0", "1" * 10, "largest_normal"), + (-65504, "1", "1" * 4 + "0", "1" * 10, "largest_normal_neg"), + # smallest normal + (0.00006103515625, "0", "0" * 4 + "1", "0" * 10, "smallest_normal"), + ( + -0.00006103515625, + "1", + "0" * 4 + "1", + "0" * 10, + "smallest_normal_neg", + ), # noqa: E501 + # largest denormal + (0.000060975552, "0", "0" * 5, "1" * 10, "largest_denorm"), + (-0.000060975552, "1", "0" * 5, "1" * 10, "largest_denorm_neg"), + # smallest denormal + (0.000000059604645, "0", "0" * 5, "0" * 9 + "1", "smallest_denorm"), + ( + -0.000000059604645, + "1", + "0" * 5, + "0" * 9 + "1", + "smallest_denorm_neg", + ), # noqa: E501 + # positive and negative value + (30.0, "0", "10011", "1" * 3 + "0" * 7, "random_pos"), + (-24.0, "1", "10011", "1" + "0" * 9, "random_neg"), + ], + torch.float8_e4m3fn: [ + # zero and neg zero + (0.0, "0", "0000", "000", "zero"), + (-0.0, "1", "0000", "000", "zero_neg"), + # special values + # note: no pos or neg inf + (float("nan"), "0", "1111", "111", "nan"), + # values below checked with https://arxiv.org/pdf/2209.05433.pdf, Table 1 # noqa: E501 + # largest normal + (448.0, "0", "1111", "110", "largest_normal"), + (-448.0, "1", "1111", "110", "largest_normal_neg"), + # smallest normal + (2**-6, "0", "0001", "000", "smallest_normal"), + (-(2**-6), "1", "0001", "000", 
"smallest_normal_neg"), + # largest denormal + (0.875 * 2**-6, "0", "0000", "111", "largest_denormal"), + (-0.875 * 2**-6, "1", "0000", "111", "largest_denormal_neg"), + # smallest denormal + (2**-9, "0", "0000", "001", "smallest_denormal"), + (-(2**-9), "1", "0000", "001", "smallest_denormal_neg"), + # positive and negative value + (30.0, "0", "1011", "111", "random_pos"), + (-24.0, "1", "1011", "100", "random_neg"), + ], + torch.float8_e5m2: [ + # zero and neg zero + (0.0, "0", "00000", "00", "zero"), + (-0.0, "1", "00000", "00", "zero_neg"), + # special values + (float("nan"), "0", "11111", "11", "nan"), + (float("inf"), "0", "11111", "00", "inf"), + (float("-inf"), "1", "11111", "00", "inf_neg"), + # values below checked with https://arxiv.org/pdf/2209.05433.pdf, Table 1 # noqa: E501 + # largest normal + (57344.0, "0", "11110", "11", "largest_normal"), + (-57344.0, "1", "11110", "11", "largest_normal_neg"), + # smallest normal + (2**-14, "0", "00001", "00", "smallest_normal"), + (-(2**-14), "1", "00001", "00", "smallest_normal_neg"), + # largest denormal + (0.75 * 2**-14, "0", "00000", "11", "largest_denormal"), + (-0.75 * 2**-14, "1", "00000", "11", "largest_denormal_neg"), + # smallest denormal + (2**-16, "0", "00000", "01", "smallest_denormal"), + (-(2**-16), "1", "00000", "01", "smallest_denormal_neg"), + # positive and negative value + (32.0, "0", "10100", "00", "random_pos"), + (-24.0, "1", "10011", "10", "random_neg"), + ], +} + +# values for fp4_e2m1, as defined in the OCP spec for MXFP4 +# other than the sign, there are only 8 values, so just create +# the table by hand +# formula norm: sign * (2 ** (exp - 1)) * 1.x +# formula denorm: sign * (2 ** (exp - 1 + 1)) * 0.x +# format: val, formula, s, e, m, val, note +float4_e2m1_interesting_values = [ + (0, "1.0 * 2^0 * 0.0", "0", "00", "0", "zero"), + # same as largest denormal, there is only one + ( + 0.5, + "1.0 * 2^0 * 0.5", + "0", + "00", + "1", + "smallest_denormal", + ), # 2**0 * 0.5 # noqa: E501 + (1.0, "1.0 * 2^0 * 1.0", "0", "01", "0", "smallest_normal"), # 2**0 * 1.0 + (1.5, "1.0 * 2^0 * 1.5", "0", "01", "1", "val3"), # 2**0 * 1.5 + (2.0, "1.0 * 2^1 * 1.0", "0", "10", "0", "val4"), # 2**1 * 1.0 + (3.0, "1.0 * 2^1 * 1.5", "0", "10", "1", "val5"), # 2**1 * 1.5 + (4.0, "1.0 * 2^2 * 1.0", "0", "11", "0", "val6"), # 2**2 * 1.0 + (6.0, "1.0 * 2^2 * 1.5", "0", "11", "1", "largest_normal"), # 2**2 * 1.5 +] +float4_e2m1_neg = [] +for fp32_ref, formula, _s, e, m, label in float4_e2m1_interesting_values: + float4_e2m1_neg.append( + [-1 * fp32_ref, "-" + formula, "1", e, m, label + "_neg"] + ) # noqa: E501 +float4_e2m1_interesting_values.extend(float4_e2m1_neg) +del float4_e2m1_neg + +# https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, section 5.3.2 # noqa: E501 +float6_e3m2_interesting_values = [ + (0, "1.0 * 2^-2 * 0.0", "0", "000", "00", "zero"), + (0.0625, "1.0 * 2^-2 * 0.25", "0", "000", "01", "smallest_denormal"), + (0.1875, "1.0 * 2^-2 * 0.75", "0", "000", "11", "largest_denormal"), + (0.25, "1.0 * 2^-2 * 1.0", "0", "001", "00", "smallest_normal"), + (28.0, "1.0 * 2^4 * 1.75", "0", "111", "11", "largest_normal"), +] +float6_e3m2_neg = [] +for fp32_ref, formula, _s, e, m, label in float6_e3m2_interesting_values: + float6_e3m2_neg.append( + [-1 * fp32_ref, "-" + formula, "1", e, m, label + "_neg"] + ) # noqa: E501 +float6_e3m2_interesting_values.extend(float6_e3m2_neg) +del float6_e3m2_neg + +# https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, section 
5.3.2 # noqa: E501 +float6_e2m3_interesting_values = [ + (0, "1.0 * 2^0 * 0.0", "0", "00", "000", "zero"), + (0.125, "1.0 * 2^0 * 0.125", "0", "00", "001", "smallest_denormal"), + (0.875, "1.0 * 2^0 * 0.875", "0", "00", "111", "largest_denormal"), + (1.0, "1.0 * 2^0 * 1.0", "0", "01", "000", "smallest_normal"), + (7.5, "1.0 * 2^2 * 1.875", "0", "11", "111", "largest_normal"), +] +float6_e2m3_neg = [] +for fp32_ref, formula, _s, e, m, label in float6_e2m3_interesting_values: + float6_e2m3_neg.append( + [ + -1 * fp32_ref, + "-" + formula, + "1", + e, + m, + label + "_neg", + ] + ) +float6_e2m3_interesting_values.extend(float6_e2m3_neg) +del float6_e2m3_neg + + +def _assert_equals(fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, dtype): + # test going from float to encoding + x = torch.tensor(fp_ref, dtype=dtype) + bitwidth = dtype_to_bitwidth[dtype] + s_enc, e_enc, m_enc = get_sem_bits(x, bitwidth=bitwidth) + assert s_enc_ref == s_enc + assert e_enc_ref == e_enc, f"{e_enc_ref} != {e_enc}" + assert m_enc_ref == m_enc, f"{m_enc_ref} != {m_enc}" + + # test going from encoding to float + s_i, e_i, m_f, special_value = sem_bits_to_sem_vals( + s_enc, + e_enc, + m_enc, + dtype, + ) + fp = sem_vals_to_f32(s_i, e_i, m_f, special_value) + assert_same(fp_ref, fp) + + +def get_sem_bits(x: torch.Tensor, bitwidth: int) -> Tuple[str, str, str]: + """ + Input: a tensor with a single element of the target element dtype + - for PT core dtypes, that dtype (fp32, fp16, fp8_e4m3, etc) + - for fp4_e2m1, fp6_e3m2, fp6_e2m3, not supported in this function + Output: bit strings for sign, exponent, mantissa encodings of the input + """ + assert x.numel() == 1 + s_len, e_len, m_len = dtype_to_sem_len[x.dtype] + + new_dtype = dtype_to_int_dtype[x.dtype] + x = x.view(new_dtype) + np_res = get_bits(x) + if bitwidth == 4: + # TODO(future): clean up this fp4 codepath + offset = 4 + s, e, m = ( + np_res[offset], + np_res[offset + s_len : (offset + s_len + e_len)], # noqa: E203 + np_res[(offset + s_len + e_len) :], # noqa: E203 + ) + else: + s, e, m = ( + np_res[0], + np_res[s_len : (s_len + e_len)], # noqa: E203 + np_res[(s_len + e_len) :], # noqa: E203 + ) + assert len(s) == s_len + assert len(e) == e_len + assert len(m) == m_len + return s, e, m + + +def exp_encoding_to_exp(exp_bit_str: str, dtype): + """ + Input: bit string of exponent for dtype + Output: integer representation of exponent + """ + exp_biased = int(exp_bit_str, 2) + exp_bias = dtype_to_exp_bias[dtype] + exp_unbiased = exp_biased - exp_bias + + # for denormalized values, increment exponent back + # up by one + if all(b == "0" for b in exp_bit_str): + exp_unbiased += 1 + + return exp_unbiased + + +def sem_bits_to_sem_vals(s_enc, e_enc, m_enc, dtype): + """ + Input: encodings of sign, exponent, mantissa for dtype + Output: integer sign, integer exponent, float32 mantissa, special value + + Supported dtypes: PT core dtypes and fp6_e3m2 and fp6_e2m3 + Not supported dtypes: fp4 + + If special value is filled out, sem are none + If sem are filled out, special value is none + """ + sign = 1 if s_enc == "0" else -1 + + # handle special values + if all(bit == "1" for bit in e_enc): + dtypes = ( + torch.float32, + torch.bfloat16, + torch.float16, + torch.float8_e5m2, + ) + if dtype in dtypes: + if all(bit == "0" for bit in m_enc): + if s_enc == "0": + return None, None, None, float("inf") + else: + return None, None, None, float("-inf") + else: + return None, None, None, float("nan") + elif dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2): + # no special values in f6 dtypes + pass 
+ else: + assert dtype is torch.float8_e4m3fn + # 1. float8_e4m3fn does not have infinity + # 2. float8_e4m3fn only sets {s}.{1111}.{111} for nan + if all(b == "1" for b in e_enc + m_enc): + return None, None, None, float("nan") + + exponent = exp_encoding_to_exp(e_enc, dtype) + + is_zero = all(b == "0" for b in e_enc + m_enc) + is_denormal = (not is_zero) and all(b == "0" for b in e_enc) + is_normal = not is_zero and not is_denormal + + if is_zero: + return sign, exponent, 0.0, None + + mantissa = 1.0 if is_normal else 0.0 + cur_pow_2 = -1 + for m_bit in m_enc: + mantissa += int(m_bit) * pow(2, cur_pow_2) + cur_pow_2 -= 1 + return sign, exponent, mantissa, None + + +def sem_vals_to_f32(s_i, e_i, m_f, special_value): + """ + Input: integer sign, integer exponent, float32 mantissa, special value + Output: float32 value + """ + if special_value is not None: + return special_value + f = s_i * pow(2, e_i) * m_f + return f + + +def sem_vals_to_formula(s_i, e_i, m_f, special_value): + """ + Input: integer sign, integer exponent, float32 mantissa, special value + Output: formula to get the float32 value + """ + if special_value is not None: + return special_value + return f"{s_i} * 2^{e_i} * {m_f}" + + +def assert_same(fp1, fp2): + if math.isnan(fp1): + assert math.isnan(fp2) + elif math.isinf(fp1): + if fp1 > 0: + assert math.isinf(fp2) and fp2 > 0 + else: + assert math.isinf(fp2) and fp2 < 0 + else: + assert (abs(fp2 - fp1) / (fp1 + 1e-20)) - 1 < 1e-12, f"{fp2} != {fp1}" + + +def run(dtype): + print("dtype", dtype) + + headers = ["orig_val", "formula", "s_enc", "e_enc", "m_enc", "note"] + results = [] + + if dtype == DTYPE_FP4: + results = float4_e2m1_interesting_values + elif dtype == DTYPE_FP6_E3M2: + results = float6_e3m2_interesting_values + elif dtype == DTYPE_FP6_E2M3: + results = float6_e2m3_interesting_values + else: + interesting_values = dtype_to_interesting_values[dtype] + for row in interesting_values: + fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, notes = row + + # test that things still work + _assert_equals(fp_ref, s_enc_ref, e_enc_ref, m_enc_ref, dtype) + + # create the formula + s_i, e_i, m_f, special_value = sem_bits_to_sem_vals( + s_enc_ref, e_enc_ref, m_enc_ref, dtype + ) + formula = sem_vals_to_formula(s_i, e_i, m_f, special_value) + + # create the table row + results.append( + [ + fp_ref, + formula, + s_enc_ref, + e_enc_ref, + m_enc_ref, + notes, + ] + ) + + print(tabulate.tabulate(results, headers=headers)) + print("\n") + + +if __name__ == "__main__": + for dtype in ( + torch.float, + torch.bfloat16, + torch.float16, + torch.float8_e4m3fn, + torch.float8_e5m2, + DTYPE_FP6_E3M2, + DTYPE_FP6_E2M3, + DTYPE_FP4, + ): + run(dtype) diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py new file mode 100644 index 0000000000..c429eb57d4 --- /dev/null +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Defines the UX for converting a model to use mx weights + +For now, this is a module swap for speed of iteration. + +Eventually we plan to move this to a tensor subclass weight wrapper for +inference, and to a tensor subclass weight wrapper + module hooks for training. 
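To make the module-swap UX described above concrete, here is a hedged usage sketch, not part of the patch. It assumes a CUDA device and uses the MXFP8 element dtype with the spec's block size of 32; `swap_linear_with_mx_linear` and `swap_linear_with_mx_inference_linear` are the wrappers defined later in this file:

```python
import copy

import torch

from torchao.prototype.mx_formats.mx_linear import (
    swap_linear_with_mx_inference_linear,
    swap_linear_with_mx_linear,
)

m = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda()
m_inf = copy.deepcopy(m)

# training: activations, weights and grads are cast to MX around each emulated matmul
swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
x = torch.randn(16, 128, device="cuda", requires_grad=True)
m(x).sum().backward()

# inference: the weight is pre-quantized to MX once at swap time
swap_linear_with_mx_inference_linear(m_inf, torch.float8_e4m3fn, block_size=32)
with torch.no_grad():
    y = m_inf(x)
```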
+""" + +import torch +import torch.nn.functional as F + +from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx + + +@torch._dynamo.allow_in_graph +class NoopFwToMXBw(torch.autograd.Function): + """ + Forward: no-op + Backward: cast grad to MX + """ + + @staticmethod + def forward(ctx, x, elem_dtype, block_size): + ctx.elem_dtype = elem_dtype + ctx.block_size = block_size + return x + + @staticmethod + def backward(ctx, g): + scale, data = to_mx(g, ctx.elem_dtype, ctx.block_size) + return ( + MXTensor(scale, data, ctx.elem_dtype, ctx.block_size, g.dtype), + None, + None, + ) + + +class MXLinear(torch.nn.Linear): + """ + Linear layer with the compute happening in emulate MX. Currently the MX + matmul is emulated since there is no hardware support yet. Activations, + weights and grads are casted to MX and back to high precision for each + matmul. + """ + + @classmethod + @torch.no_grad() + def from_float(cls, mod, elem_dtype, block_size): + mod.__class__ = MXLinear + mod.elem_dtype = elem_dtype + mod.block_size = block_size + return mod + + def forward(self, x): + x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size) + w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size) + y = F.linear(x_mx, w_mx, self.bias) + y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size) + return y + + +class MXInferenceLinear(torch.nn.Linear): + """ + Inference version of MXLinear, with the weight pre-quantized to MX. + """ + + @classmethod + @torch.no_grad() + def from_float(cls, mod, elem_dtype, block_size): + with torch.device("meta"): + super_kwargs = { + "in_features": mod.in_features, + "out_features": mod.out_features, + "bias": False, + } + new_mod = cls(**super_kwargs) + # TODO(future PR): set to new_mod.weight directly, will need to work + # through some errors + new_mod.weight_mx = MXTensor.to_mx( + mod.weight.t().contiguous(), elem_dtype, block_size=block_size + ).t() + new_mod.bias = mod.bias + new_mod.elem_dtype = elem_dtype + return new_mod + + @torch.no_grad() + def forward(self, x): + w_hp = self.weight_mx.to_dtype(x.dtype) + y = F.linear(x, w_hp, self.bias) + return y + + +def replace_with_custom_fn_if_matches_filter( + model, replacement_fn, filter_fn, cur_fqn="" +) -> None: + """ + For each `child` in `model`, replaces it with `replacement_fn(child)` + if `filter_fn(child)` is `True` + """ + name_to_child = dict(model.named_children()) + for name, child in name_to_child.items(): + if cur_fqn == "": + new_fqn = name + else: + new_fqn = f"{cur_fqn}.{name}" + if filter_fn(child, new_fqn): + new_child = replacement_fn(child) + setattr(model, name, new_child) + else: + replace_with_custom_fn_if_matches_filter( + child, replacement_fn, filter_fn, new_fqn + ) + + +def _is_linear(mod, fqn): + return isinstance(mod, torch.nn.Linear) + + +def swap_linear_with_mx_linear(model, elem_dtype, block_size, filter_fn=None): + if filter_fn is None: + combined_filter_fn = _is_linear + else: + + def __fn(mod, fqn): + return _is_linear(mod, fqn) and filter_fn(mod, fqn) + + combined_filter_fn = __fn + replace_with_custom_fn_if_matches_filter( + model, + lambda mod: MXLinear.from_float(mod, elem_dtype, block_size), + combined_filter_fn, + ) + + +def swap_linear_with_mx_inference_linear( + model, + elem_dtype, + block_size, + filter_fn=None, +): + if filter_fn is None: + combined_filter_fn = _is_linear + else: + + def __fn(mod, fqn): + return _is_linear(mod, fqn) and filter_fn(mod, fqn) + + combined_filter_fn = __fn + replace_with_custom_fn_if_matches_filter( + model, + lambda mod: 
MXInferenceLinear.from_float(mod, elem_dtype, block_size), + combined_filter_fn, + ) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py new file mode 100644 index 0000000000..365329cf14 --- /dev/null +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict + +import torch +from torch.utils._pytree import tree_map + +from torchao.prototype.mx_formats.constants import DTYPE_FP4 +from torchao.prototype.mx_formats.mx_tensor import ( # noqa: E501 + MXTensor, + tensor_size_hp_to_fp4x2, +) + +aten = torch.ops.aten + +MX_OPS_TABLE: Dict[Any, Any] = {} + + +def implements(aten_ops): + """Register aten ops to the mx op table""" + + def decorator(func): + for op in aten_ops: + MX_OPS_TABLE[op] = func + return func + + return decorator + + +@implements([aten.detach.default]) +def mx_desugar_op(aten_op, args, kwargs=None): + old = args[0] + new_data = aten_op(old._data, *args[1:], **kwargs) + new = MXTensor( + old._scale_e8m0, + new_data, + old._elem_dtype, + old._block_size, + old._orig_dtype, + ) + return new + + +@implements([aten.mm.default, aten.matmul.default]) +def mx_mm(aten_op, args, kwargs=None): + a = args[0] + b = args[1] + assert isinstance(a, MXTensor) and isinstance(b, MXTensor) + a_hp = a.to_dtype(a._orig_dtype) + b_hp = b.to_dtype(b._orig_dtype) + res = aten_op(a_hp, b_hp) + return res + + +@implements([aten.addmm.default]) +def mx_addmm(aten_op, args, kwargs=None): + a = args[0] + b = args[1] + c = args[2] + assert isinstance(b, MXTensor) and isinstance(c, MXTensor) + b_hp = b.to_dtype(b._orig_dtype) + c_hp = c.to_dtype(c._orig_dtype) + res = aten_op(a, b_hp, c_hp) + return res + + +@implements([aten.t.default]) +def mx_t(aten_op, args, kwargs=None): + # For now, only transpose(input, 0, 1) is supported. + old = args[0] + new = MXTensor( + old._scale_e8m0, + old._data.t(), + old._elem_dtype, + old._block_size, + old._orig_dtype, + ) + return new + + +@implements([aten.sum.dim_IntList]) +def mx_cast_up_op(aten_op, args, kwargs=None): + """Be careful with this function, this is a "fallback" op that + casts the output of the op to the original precision. And performs the op. + + We currently need this to support the backward for admmm bias. + "addmm" -> out + "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" + """ + + def unwrap(x): + if isinstance(x, MXTensor): + return x.to_dtype(x._orig_dtype) + return x + + new_args = tree_map(unwrap, args) + new_kwargs = tree_map(unwrap, kwargs) + return aten_op(*new_args, **new_kwargs) + + +@implements([aten.view.default]) +def mx_view_op(aten_op, args, kwargs=None): + data = args[0]._data + new_size = args[1] + if args[0]._elem_dtype == DTYPE_FP4: + # special case fp4 as we pack two elements per byte + new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) + new_data = aten_op(data, new_size, *args[2:], **kwargs) + return MXTensor( + args[0]._scale_e8m0, + new_data, + args[0]._elem_dtype, + args[0]._block_size, + args[0]._orig_dtype, + ) + + +@implements([aten._to_copy.default]) +def autocast_to_copy(aten_op, args, kwargs=None): + """This gets called when running matmul under autocast + when the input is a MXTensor, presenting as a fp32 + tensor. 
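As a concrete illustration of the autocast path this handler serves, the hedged sketch below (not part of the patch; CUDA and bf16 autocast are assumptions) runs an `MXLinear` under autocast. Both the activation and the weight are MXTensors advertising float32, autocast requests a bf16 copy of each via `aten._to_copy`, and this handler only relabels the original dtype so that the emulated matmul dequantizes to bf16:

```python
import torch

from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda()
swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)

x = torch.randn(8, 64, device="cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
    y = m(x)
print(y.dtype)  # torch.bfloat16, if the assumptions above hold
```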
+ """ + assert isinstance(args[0], MXTensor) + # print('before', args[0], args[0].dtype, args[0]._orig_dtype) + assert ( + len(kwargs) == 1 and "dtype" in kwargs + ), "Only support dtype kwarg for autocast" + assert kwargs["dtype"] in { + torch.float16, + torch.bfloat16, + }, "Only support floating point conversion for autocast w/ MXTensor" + res = MXTensor( + args[0]._scale_e8m0, + args[0]._data, + args[0]._elem_dtype, + args[0]._block_size, + kwargs["dtype"], + ) + # print('after', res, res.dtype, res._orig_dtype) + return res diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py new file mode 100644 index 0000000000..e8dd80ae0c --- /dev/null +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -0,0 +1,411 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Defines the tensor subclasses to represent the MX format spec from +https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + +Exponent E8M0 encoding details (OCP spec section 5.4.1): + * bias: 127 + * supported exponent range: -127 to 127 + * infinities: N/A + * NaN: 11111111 + * Zeros: N/A +""" + +from typing import Dict, Union + +import torch + +import torchao.prototype.mx_formats.config as config +from torchao.prototype.mx_formats.constants import ( + DTYPE_FP4, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, + E8M0_EXPONENT_BIAS, + E8M0_EXPONENT_NAN_VAL, + F32_MIN_NORMAL, + F4_E2M1_MAX, + F4_E2M1_MAX_POW2, + F6_E2M3_MAX, + F6_E2M3_MAX_POW2, + F6_E3M2_MAX, + F6_E3M2_MAX_POW2, + F8E4M3_MAX, + F8E4M3_MAX_POW2, + F8E5M2_MAX, + F8E5M2_MAX_POW2, + SUPPORTED_ELEM_DTYPES, +) + +from torchao.prototype.mx_formats.custom_cast import ( + f32_to_f4_unpacked, + f32_to_f6_e2m3_unpacked, + f32_to_f6_e3m2_unpacked, + f4_unpacked_to_f32, + f6_e2m3_unpacked_to_f32, + f6_e3m2_unpacked_to_f32, + pack_uint4, + triton_f4_to_scaled_bf16, + unpack_uint4, +) + + +def to_mx( + data_hp: torch.Tensor, + elem_dtype: Union[torch.dtype, str], + block_size: int, +): + """ + Takes a high precision tensor and converts to MX scale and raw data, in + naive layout (scale and raw data are separate tensors). + """ + + assert data_hp.dtype in ( + torch.bfloat16, + torch.float, + ), f"{data_hp.dtype} is not supported yet" + assert data_hp.numel() % block_size == 0, "unsupported" + assert data_hp.is_contiguous(), "unsupported" + assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported" + + # calculate the scale in e8m0 format + + orig_shape = data_hp.shape + data_hp = data_hp.reshape(-1, block_size) + + # find max value of the data + max_abs = torch.amax(torch.abs(data_hp), 1) + + # Add an epsilon to prevent the log2 function call for returning -inf + # where the values are zero. + eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype) + + # Find largest power of 2 less than or equal to max_abs. 
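A worked example of the scale calculation that the next few lines implement may help; this is a sketch assuming `F8E4M3_MAX_POW2 == 8`, i.e. that 2^8 is the largest power of two not exceeding the e4m3 max of 448:

```python
import math

max_abs = 10.0                                       # per-block amax of the input
largest_p2 = math.floor(math.log2(max_abs))          # 3
target_max_pow2 = 8                                  # assumed value of F8E4M3_MAX_POW2
scale_e8m0_unbiased = largest_p2 - target_max_pow2   # -5
scale_e8m0_biased = scale_e8m0_unbiased + 127        # 122, the byte that gets stored
scale_fp = 2.0 ** scale_e8m0_unbiased                # 0.03125
print(max_abs / scale_fp)                            # 320.0, within the e4m3 max of 448
```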
+ largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps)) + + # Set X to be the largest power-of-two less than or equal to + # max_abs(v), divided by the largest power of two representable + # in the element data type + if elem_dtype == torch.float8_e4m3fn: + target_max_pow2 = F8E4M3_MAX_POW2 + elif elem_dtype == torch.float8_e5m2: + target_max_pow2 = F8E5M2_MAX_POW2 + elif elem_dtype == DTYPE_FP6_E2M3: + target_max_pow2 = F6_E2M3_MAX_POW2 + elif elem_dtype == DTYPE_FP6_E3M2: + target_max_pow2 = F6_E3M2_MAX_POW2 + elif elem_dtype == DTYPE_FP4: + target_max_pow2 = F4_E2M1_MAX_POW2 + else: + raise AssertionError("unsupported") + scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2 + + # Clamp to exponents that can be represented in e8m0 + scale_e8m0_unbiased = torch.clamp( + scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + ) + + # Create the biased e8m0 representation and cast it to 8 bits + scale_e8m0_biased = scale_e8m0_unbiased + E8M0_EXPONENT_BIAS + scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8) + + # Conversion to torch.uint8 sets NaN values to 0, fix this by + # explicitly setting known NaN values to 255 + scale_e8m0_biased = torch.where( + torch.isnan(scale_e8m0_unbiased), + E8M0_EXPONENT_NAN_VAL, + scale_e8m0_biased, + ) + + # For now, calculate the scale in floating point. + # TODO(future) audit if there is a need to bit shift exponents instead. + scale_fp = torch.pow( + torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device), + scale_e8m0_unbiased, + ) + + # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the + # float32 denormal range. For now, manually adjust the fp scale. This is + # relevant if all of the incoming block values are zeroes. + # See https://github.com/pytorch/pytorch/issues/125557 for details. + # Note: it would be more correct to set the minimum to 2**-127, but this + # does not work in triton either as it looks like subnormal value handling + # has some gaps. So, for now just set to the minimum normal value. + scale_fp = torch.clamp(scale_fp, min=F32_MIN_NORMAL) + + # scale and saturated cast the data elements to max of target dtype + if elem_dtype == torch.float8_e4m3fn: + max_pos = F8E4M3_MAX + elif elem_dtype == torch.float8_e5m2: + max_pos = F8E5M2_MAX + elif elem_dtype == DTYPE_FP6_E2M3: + max_pos = F6_E2M3_MAX + elif elem_dtype == DTYPE_FP6_E3M2: + max_pos = F6_E3M2_MAX + elif elem_dtype == DTYPE_FP4: + max_pos = F4_E2M1_MAX + else: + raise AssertionError("unsupported") + data_lp = torch.clamp( + data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos + ) + data_lp = data_lp.reshape(orig_shape) + + # cast to target dtype + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + data_lp = data_lp.to(elem_dtype) + elif elem_dtype == DTYPE_FP6_E2M3: + data_lp = f32_to_f6_e2m3_unpacked(data_lp) + elif elem_dtype == DTYPE_FP6_E3M2: + data_lp = f32_to_f6_e3m2_unpacked(data_lp) + elif elem_dtype == DTYPE_FP4: + data_lp = f32_to_f4_unpacked(data_lp) + data_lp = pack_uint4(data_lp) + else: + raise AssertionError("unsupported") + + return scale_e8m0_biased, data_lp + + +def get_fp_scale(scale_e8m0): + s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS + # TODO(later): it would be nice if there was a way to do the 2^x operation + # in PyTorch without creating a tensor of twos + two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device) + # pow(two, s_offset) can be out of range of floating point formats. 
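For reference, the e8m0 decode this function performs is simply 2^(x - 127), with the all-ones encoding (255) reserved for NaN, per the E8M0 notes in the module docstring. A minimal scalar sketch follows; the helper name is hypothetical and not part of the patch:

```python
E8M0_EXPONENT_BIAS = 127     # per the OCP spec, section 5.4.1
E8M0_EXPONENT_NAN_VAL = 255

def e8m0_to_float(x: int) -> float:
    """Decode one e8m0 byte: a biased exponent with no sign and no mantissa."""
    if x == E8M0_EXPONENT_NAN_VAL:
        return float("nan")
    return 2.0 ** (x - E8M0_EXPONENT_BIAS)

print(e8m0_to_float(127))  # 1.0
print(e8m0_to_float(122))  # 0.03125 == 2**-5
print(e8m0_to_float(0))    # 2**-127, in the float32 denormal range
```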
+ # TODO(later): handle this for float16 if we decide to support float16 + # scales. + s_fp = torch.pow(two, s_offset) + + # If a block exponent was 255, set values of that block to NaN + s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan")) + + return s_fp + + +def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype): + orig_shape = data_lp.shape + is_transposed = not data_lp.is_contiguous() + # if the underlying data is transposed, convert to row major before + # unpacking and unscaling + if is_transposed: + data_lp = data_lp.t() + assert data_lp.is_contiguous() + orig_shape = (orig_shape[1], orig_shape[0]) + + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + data_hp = data_lp.to(target_dtype) + elif elem_dtype == DTYPE_FP6_E2M3: + data_hp = f6_e2m3_unpacked_to_f32(data_lp) + data_hp = data_hp.to(target_dtype) + elif elem_dtype == DTYPE_FP6_E3M2: + data_hp = f6_e3m2_unpacked_to_f32(data_lp) + data_hp = data_hp.to(target_dtype) + elif elem_dtype == DTYPE_FP4: + if config.use_fp4_custom_triton_dequant_kernel: + data_hp_rescaled = triton_f4_to_scaled_bf16( + data_lp, + scale_e8m0, + block_size, + ) + if is_transposed: + data_hp_rescaled = data_hp_rescaled.t() + return data_hp_rescaled.to(target_dtype) + else: + # fp4 + f4_unpacked = unpack_uint4(data_lp) + # for now we only have a cast to f32 + # TODO(future PR): add cast directly to bf16 + f32 = f4_unpacked_to_f32(f4_unpacked) + data_hp = f32.to(target_dtype) + # manually adjust shape to account for the unpacking + # TODO(future PR): clean up the shape code and remove the hack + # below + orig_shape = (*orig_shape[:-1], orig_shape[-1] * 2) + else: + raise AssertionError("unsupported") + + data_hp = data_hp.reshape(-1, block_size) + s_fp = get_fp_scale(scale_e8m0).reshape(-1, 1).to(target_dtype) + data_hp = data_hp * s_fp + data_hp = data_hp.reshape(orig_shape) + + # if we converted to row-major before unscaling convert back + if is_transposed: + data_hp = data_hp.t() + + return data_hp + + +def tensor_size_hp_to_fp4x2(orig_size, is_contiguous): + new_size = orig_size + if is_contiguous: + new_size = [*list(new_size[:-1]), new_size[-1] // 2] + else: + new_size = [new_size[0] // 2, *list(new_size[1:])] + return new_size + + +def tensor_size_fp4x2_to_hp(orig_size, is_contiguous): + new_size = orig_size + if is_contiguous: + new_size = [*list(new_size[:-1]), new_size[-1] * 2] + else: + new_size = [new_size[0] * 2, *list(new_size[1:])] + return new_size + + +@torch._dynamo.allow_in_graph +class ToMXConstrFunc(torch.autograd.Function): + """ + Differentiable cast to MX, no-op in backward + """ + + @staticmethod + def forward(ctx, data_hp, elem_dtype, block_size): + scale_e8m0_biased, data_lp = to_mx(data_hp, elem_dtype, block_size) + return MXTensor( + scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype + ) + + @staticmethod + def backward(ctx, g): + return g, None, None + + +@torch._dynamo.allow_in_graph +class FromMXConstrFunc(torch.autograd.Function): + """ + Differentiable cast from MX, no-op in backward + """ + + @staticmethod + def forward(ctx, tensor_lp, target_dtype): + return to_dtype( + tensor_lp._data, + tensor_lp._scale_e8m0, + tensor_lp._elem_dtype, + tensor_lp._block_size, + target_dtype, + ) + + @staticmethod + def backward(ctx, g): + return g, None, None + + +class MXTensor(torch.Tensor): + def __new__( + cls, + scale_e8m0_bits, + data_bits, + elem_dtype, + block_size, + orig_dtype, + ): + new_size = data_bits.size() + if elem_dtype == DTYPE_FP4: + # set the tensor 
size to what it would be without 2x4 packing + new_size = tensor_size_fp4x2_to_hp( + new_size, + data_bits.is_contiguous(), + ) + self = torch.Tensor._make_wrapper_subclass( + cls, + new_size, + dtype=orig_dtype, + device=data_bits.device, + ) + assert scale_e8m0_bits.dtype == torch.uint8, "unsupported" + assert len(scale_e8m0_bits.shape) == 1, "unsupported" + assert data_bits.dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint8, + ), "unsupported" + if elem_dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, + ): + target_numel = scale_e8m0_bits.numel() * block_size + elif elem_dtype == DTYPE_FP4: + assert data_bits.dtype is torch.uint8 # fp4 + target_numel = scale_e8m0_bits.numel() * block_size / 2 + else: + raise AssertionError("unsupported") + if not issubclass( + torch._subclasses.fake_tensor.FakeTensor, + type(data_bits), + ): + # this check is sometimes broken for FakeTensor + # TODO investigate + assert ( + target_numel == data_bits.numel() + ), f"{target_numel} != {data_bits.numel()}" + + # `_scale_e8m0` has rank 1 and applies to a row-major memory layout of + # `_data` + self._scale_e8m0 = scale_e8m0_bits + self._data = data_bits + self._elem_dtype = elem_dtype + self._block_size = block_size + self._orig_dtype = orig_dtype + return self + + def __repr__(self): + # TODO better elem dtype print for fp4 + return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self._data}, d_hp: {self.to_dtype(self._orig_dtype)}" # noqa: E501 + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # avoid circular dependency + from torchao.prototype.mx_formats.mx_ops import MX_OPS_TABLE + + if func in MX_OPS_TABLE: + return MX_OPS_TABLE[func](func, args, kwargs) + + raise NotImplementedError(f"{func} not implemented") + + def to_dtype(self, target_dtype): + return FromMXConstrFunc.apply(self, target_dtype) + + @staticmethod + @torch._dynamo.allow_in_graph + def to_mx( + data_hp: torch.Tensor, + elem_dtype: Union[torch.dtype, str], + block_size: int, + ): + return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size) + + def __tensor_flatten__(self): + ctx = { + "_elem_dtype": self._elem_dtype, + "_block_size": self._block_size, + "_orig_dtype": self._orig_dtype, + } + return ["_scale_e8m0", "_data"], ctx + + @staticmethod + def __tensor_unflatten__( + inner_tensors: Dict, + metadata, + outer_size, + outer_stride, + ): + return MXTensor( + inner_tensors["_scale_e8m0"], + inner_tensors["_data"], + metadata["_elem_dtype"], + metadata["_block_size"], + metadata["_orig_dtype"], + ) + + # Do not force the MXTensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py new file mode 100644 index 0000000000..3bda663c1c --- /dev/null +++ b/torchao/prototype/mx_formats/utils.py @@ -0,0 +1,7 @@ +import torch + + +def compute_error(x, y): + Ps = torch.norm(x) # noqa: TOR101 + Pn = torch.norm(x - y) # noqa: TOR101 + return 20 * torch.log10(Ps / Pn)
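`compute_error` above returns a signal-to-quantization-noise ratio in decibels, 20 * log10(||x|| / ||x - y||); higher means a more faithful round trip. A hedged end-to-end sketch (not part of the patch, CUDA assumed) that quantizes a tensor to MXFP8 and measures the reconstruction error:

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import compute_error

x = torch.randn(128, 256, device="cuda")
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, block_size=32)
x_dq = x_mx.to_dtype(torch.float)  # dequantize back to float32
print(compute_error(x, x_dq))      # SQNR in dB, higher is better
```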