Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FP6-LLM clean up (again) #339

Merged
merged 27 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7fabc8f
override load from state dict
gau-nernst May 30, 2024
1c08568
fix prefix
gau-nernst May 31, 2024
d89e9da
migrate to mx primitive
gau-nernst Jun 1, 2024
6f84293
remove unneeded code
gau-nernst Jun 1, 2024
571910b
comment out test
gau-nernst Jun 1, 2024
73b8354
Merge branch 'pytorch:main' into improve_fp6_llm_linear
gau-nernst Jun 1, 2024
4e2964a
remove
gau-nernst Jun 2, 2024
adefee8
add rounding test for f6_e3m2
gau-nernst Jun 2, 2024
f8268f0
update tests
gau-nernst Jun 2, 2024
ebbff67
remove openmp flag
gau-nernst Jun 2, 2024
25e4be7
update benchmark script
gau-nernst Jun 2, 2024
21dfc60
test negative number
gau-nernst Jun 3, 2024
64e24f7
remove qtorch dep
gau-nernst Jun 3, 2024
6d6f5dd
fix type casting
gau-nernst Jun 3, 2024
474ebc2
add view
gau-nernst Jun 3, 2024
e64fbac
Merge branch 'main' into improve_fp6_llm_linear
gau-nernst Jun 4, 2024
3231039
Merge branch 'main' into improve_fp6_llm_linear
gau-nernst Jun 4, 2024
d7ec248
Merge branch 'pytorch:main' into improve_fp6_llm_linear
gau-nernst Jun 5, 2024
86562c2
Merge branch 'pytorch:main' into improve_fp6_llm_linear
gau-nernst Jun 7, 2024
60c8e6a
Merge branch 'main' into improve_fp6_llm_linear
gau-nernst Jun 9, 2024
509217c
fix strange pytest behavior
gau-nernst Jun 9, 2024
11dcba3
only skip tests requiring PyTorch 2.4
gau-nernst Jun 9, 2024
6f8e7e9
remove weight loading magic
gau-nernst Jun 9, 2024
62412fb
Merge branch 'main' into improve_fp6_llm_linear
gau-nernst Jun 10, 2024
f454f4d
fix typing tuple
gau-nernst Jun 10, 2024
fa38572
fix list typing
gau-nernst Jun 10, 2024
156063a
Merge branch 'main' into improve_fp6_llm_linear
gau-nernst Jun 10, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,3 @@ tabulate # QOL for printing tables to stdout

# Custom CUDA Extensions
ninja

# for FP6-LLM (can be removed once we remove fp16_to_fp6_original())
qtorch
2 changes: 0 additions & 2 deletions docs/source/api_ref_dtypes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ torchao.dtypes

to_nf4
UInt4Tensor
to_float6_e3m2
from_float6_e3m2

..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@ def get_extensions():
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = ["-fopenmp"]
extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-fopenmp",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
Expand Down
134 changes: 0 additions & 134 deletions test/dtypes/test_float6_e3m2.py

This file was deleted.

29 changes: 27 additions & 2 deletions test/prototype/mx_formats/test_custom_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

torch.manual_seed(0)

Expand Down Expand Up @@ -322,6 +320,7 @@ def test_fp4_pack_unpack():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4")
def test_fp4_triton_unscaled_cast():
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
Expand All @@ -331,6 +330,7 @@ def test_fp4_triton_unscaled_cast():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4")
def test_fp4_triton_scaled_cast():
size = (256,)
orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
Expand Down Expand Up @@ -386,3 +386,28 @@ def test_fp6_values(dtype_name):
else:
raise AssertionError("unsupported")
torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0)


@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")),
]
)
@pytest.mark.parametrize(
"f32_val,f6_e3m2_enc",
[
(29.0, 0b011111), # normal round down
(26.0, 0b011110), # normal round to nearest even
(0.1251, 0b000010), # subnormal round down
(0.0314, 0b000001), # subnormal round up
(0.03, 0b000000), # underflow
]
)
def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):
f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(f32_val, device=device))
assert f6_e3m2_unpacked.item() == f6_e3m2_enc

f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device))
assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)
27 changes: 17 additions & 10 deletions test/quantization/test_fp6_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
parametrize,
run_tests,
)
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm
from torchao.ops import prepack_fp6_weight
from torchao.quantization.fp6_llm import (
to_tc_float6_e3m2,
from_tc_float6_e3m2,
_to_tc_float6_e3m2_ref,
Fp6LlmLinear,
convert_fp6_llm,
)
from torchao.prototype.mx_formats.custom_cast import f6_e3m2_unpacked_to_f32, f32_to_f6_e3m2_unpacked


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
Expand All @@ -20,9 +25,9 @@ class TestFp6LlmLinear(TestCase):
def test_to_tc_float6_e3m2_correctness(self, device):
x = torch.randn(256, 64, device=device)

expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8)
expected = _to_tc_float6_e3m2_ref(x)
actual = to_tc_float6_e3m2(x)
torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1))
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
def test_to_tc_float6_e3m2_compile(self, device):
Expand All @@ -35,18 +40,20 @@ def test_to_tc_float6_e3m2_compile(self, device):
@parametrize("device", _DEVICES)
def test_from_tc_float6_e3m2_correctness(self, device):
x = torch.randn(256, 64, device=device)
x = from_float6_e3m2(to_float6_e3m2(x)) # quantize and dequantize so that the values are exactly representable in FP6

actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape)
# quantize and dequantize so that the values are exactly representable in FP6
x = f6_e3m2_unpacked_to_f32(f32_to_f6_e3m2_unpacked(x))

actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x))
torch.testing.assert_close(actual, x)

@parametrize("device", _DEVICES)
def test_from_tc_float6_e3m2_compile(self, device):
M, N = 256, 64
x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device)
x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device)

expected = from_tc_float6_e3m2(x, M, N)
actual = torch.compile(from_tc_float6_e3m2)(x, M, N)
expected = from_tc_float6_e3m2(x)
actual = torch.compile(from_tc_float6_e3m2)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Expand Down
76 changes: 10 additions & 66 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.optests import opcheck
import torchao
from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.quantization.fp6_llm import from_tc_float6_e3m2
import unittest
from parameterized import parameterized
import pytest
Expand All @@ -18,94 +18,38 @@
@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
class TestOps(TestCase):
def _create_tensors_with_iou(self, N, iou_thresh):
# force last box to have a pre-defined iou with the first box
# let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
# then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
# we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
# Adjust the threshold upward a bit with the intent of creating
# at least one box that exceeds (barely) the threshold and so
# should be suppressed.
boxes = torch.rand(N, 4) * 100
boxes[:, 2:] += boxes[:, :2]
boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
iou_thresh += 1e-5
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
scores = torch.rand(N)
return boxes, scores

def _create_fp6_inputs(self, BS: int, OC: int, IC: int):
def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
# Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t.
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
fp16_scale = torch.rand(OC).half() + 0.5
fp16_activation = torch.rand(BS, IC).half() + 0.5
return fp6_weight, fp16_scale, fp16_activation

def test_prepack_fp6_weight(self):
OC = 256
IC = 256
fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC)

# smoke test
torchao.ops.prepack_fp6_weight(fp6_weight)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16_to_fp6_original(self):
OC = 256
IC = 256
fp16_weight = torch.randn((OC, IC), dtype=torch.float16)

# the original FP16->FP6 kernel checks for overflow/underflow
fp16_weight.clip_(-28.0, 28.0)
fp16_weight[fp16_weight.abs() < 0.0625] = 0.0

# smoke test
torchao.ops.fp16_to_fp6_original(fp16_weight)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils)
return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16act_fp6weight_linear(self):
BS = 2
OC = 256
IC = 256
splitK = 1
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
act_cuda = fp16_activation.cuda()
weight_cuda = fp6_weight_packed.cuda()
scale_cuda = fp16_scale.cuda()
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

# smoke test
torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)
opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)

# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
@parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
act_cuda = fp16_activation.cuda()
weight_cuda = fp6_weight_packed.cuda()
scale_cuda = fp16_scale.cuda()
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = act_cuda @ fp16_weight.cuda().T
fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = fp16_activation @ fp16_weight.T

error = (results_fp6 - results_fp16).abs()
relative_error = error / results_fp16.abs()
Expand Down
File renamed without changes.
Loading
Loading