
Add FP16Act-FP6Weight Linear #223

Merged
merged 44 commits into main from fp6 on May 14, 2024
Changes from all commits
44 commits
c742a0e
add files from fp6_llm
gau-nernst May 7, 2024
963bfa6
Merge branch 'pytorch:main' into fp6
gau-nernst May 7, 2024
4eb8be6
try to port weight packing first
gau-nernst May 7, 2024
8608664
rename
gau-nernst May 7, 2024
b7c7b28
rename fp6 weight packing
gau-nernst May 7, 2024
3c9aac7
add fp16act_fp6weight_linear
gau-nernst May 7, 2024
031379a
fix function def
gau-nernst May 7, 2024
75ca602
Merge branch 'main' into fp6
msaroufim May 7, 2024
c436c43
delete duplicate file
gau-nernst May 8, 2024
12823fe
move weight quant file
gau-nernst May 8, 2024
9180fef
rename
gau-nernst May 8, 2024
1b24424
add pytorch interface for fp6 weight dequant
gau-nernst May 8, 2024
2671c9c
add fake_fp6 to fp6
gau-nernst May 8, 2024
e61be51
move weight_quant to csrc/cuda due to cuda_fp16.h dependency
gau-nernst May 8, 2024
21acfd1
add fake_fp6_to_fp6 test
gau-nernst May 8, 2024
67fd6f8
add test for fp16act_fp6weight_linear
gau-nernst May 8, 2024
084b7e4
add test for fp6_weight_dequant
gau-nernst May 8, 2024
6d2fc3e
Fp6WeightOnlyQuantizedLinearWeight (not working yet)
gau-nernst May 8, 2024
68f2415
skip some tests, since the functions are not built w/o CUDA
gau-nernst May 8, 2024
0c78635
Merge branch 'main' into fp6
gau-nernst May 9, 2024
5989599
add the original test
gau-nernst May 9, 2024
92dfde4
implement transpose and clone so that F.linear will work
gau-nernst May 9, 2024
da1421b
remove print
gau-nernst May 9, 2024
fecd0cc
Merge branch 'main' into fp6
gau-nernst May 9, 2024
a0a53a0
remove dequantize
gau-nernst May 9, 2024
079e16b
add notes and some rename
gau-nernst May 9, 2024
06e8438
typo
gau-nernst May 9, 2024
ca45274
small cleanup
gau-nernst May 9, 2024
7a0f6e2
improve tensor subclass and add test (which is failing for torch-comp…
gau-nernst May 9, 2024
320827e
add note
gau-nernst May 9, 2024
74b8094
add note
gau-nernst May 9, 2024
c8d47c3
add qtorch as dev requirement
gau-nernst May 9, 2024
e08ba6a
update error message
gau-nernst May 9, 2024
b090b4b
add __repr__ and fix transposed issue
gau-nernst May 9, 2024
f6f93c3
add fp6 perplexity test
gau-nernst May 9, 2024
b857645
rename variables
gau-nernst May 10, 2024
e69e6db
Merge branch 'main' into fp6
gau-nernst May 10, 2024
f0eba1a
remove subclass
gau-nernst May 10, 2024
8f1ef8d
add correctness test
gau-nernst May 10, 2024
cb05d30
remove unwanted changes
gau-nernst May 10, 2024
56aefc6
add apache 2.0 notice
gau-nernst May 10, 2024
7d3a5b1
add benchmark script
gau-nernst May 10, 2024
08a95ac
add note about FP6 kernel
gau-nernst May 14, 2024
a8b4dd3
relax tolerance
gau-nernst May 14, 2024
82 changes: 82 additions & 0 deletions benchmarks/benchmark_fp6.py
@@ -0,0 +1,82 @@
import torch
import torchao
from torch.utils.benchmark import Timer
import pandas as pd
from tqdm import tqdm


def benchmark(m, k, n, splitK):
    # Randomly initialize each byte. The highest value for randint() is set to the max value of uint32_t.
    fp6_weight = torch.randint(4294967295, (n, k // 16 * 3)).to(torch.int)
    fp16_scale = torch.rand(n).half() + 0.5
    fp16_activation = torch.rand(m, k).half() + 0.5

    fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
    act_cuda = fp16_activation.cuda()
    weight_cuda = fp6_weight_packed.cuda()
    scale_cuda = fp16_scale.cuda()

    # need to do this since Timer cannot see torchao
    def fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK):
        return torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

    fp6_output = fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)

    fp6_measurement = Timer(
        stmt="fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)",
        globals=locals(),
    ).blocked_autorange()

    fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
    fp16_output = act_cuda @ fp16_weight.T

    fp16_measurement = Timer(
        stmt="act_cuda @ fp16_weight.T",
        globals=locals(),
    ).blocked_autorange()

    # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
    # doesn't seem to be the right way to check for correctness
    correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp6_latency (ms)": fp6_measurement.median * 1000,
        "fp16_latency (ms)": fp16_measurement.median * 1000,
        "speedup (d/s)": fp16_measurement.median / fp6_measurement.median,
        "correct": correct,
    }


if __name__ == "__main__":
    # from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (10240, 8192, 57344, 8192)

    results = []

    # splitK can be tuned based on m, k, n
    for m, splitK_vals in tqdm([
        (1, (5, 6, 7, 6)),
        (2, (5, 6, 7, 6)),
        (4, (5, 6, 7, 6)),
        (8, (5, 6, 7, 6)),
        # (16, (5, 6, 7, 6)),
        # (64, (5, 6, 7, 6)),
        # (128, (5, 3, 3, 3)),
        # (256, (4, 3, 2, 3)),
        # (512, (2, 5, 2, 4)),
        (1024, (1, 2, 1, 2)),
        (2048, (1, 1, 1, 1)),
        (4096, (1, 1, 1, 1)),
        # (8192, (1, 1, 1, 1)),
        # (16384, (1, 1, 1, 1)),
    ]):
        for n, k, splitK in zip(n_vals, k_vals, splitK_vals):
            results.append(benchmark(m, n, k, splitK))

    df = pd.DataFrame(results)
    df.to_csv("fp6_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
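
Note: the (n, k // 16 * 3) weight shape in the benchmark follows from the FP6 packing: each output channel stores k 6-bit values, and 6·k bits occupy 6k/32 = 3k/16 32-bit words. A short sketch, illustrative only and not part of the diff, that checks this arithmetic for the benchmarked shapes:

# Why fp6_weight has shape (n, k // 16 * 3): each output channel stores k
# 6-bit weights packed into 32-bit words, i.e. 6*k / 32 = 3*k / 16 int32 values.
# Illustrative sketch only; not part of the PR.
def packed_int32_per_row(k: int) -> int:
    assert k % 16 == 0, "k must be a multiple of 16 for the 6-bit values to pack evenly"
    return k * 6 // 32

for k in (8192, 28672):
    assert packed_int32_per_row(k) == k // 16 * 3
    print(k, "->", packed_int32_per_row(k), "int32 words per output channel")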
4 changes: 2 additions & 2 deletions setup.py
@@ -63,10 +63,10 @@ def get_extensions():

    this_dir = os.path.dirname(os.path.curdir)
    extensions_dir = os.path.join(this_dir, "torchao", "csrc")
-   sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
+   sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-   cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+   cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True))

    if use_cuda:
        sources += cuda_sources
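The switch from "*.cpp" / "*.cu" to "**" patterns with recursive=True is what lets the build pick up sources in nested csrc subdirectories; the new FP6 kernels live under torchao/csrc/cuda/fp6_llm/. A minimal illustration of the difference, not part of the diff:

import glob
import os

extensions_dir = os.path.join("torchao", "csrc")
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")

# Non-recursive: only matches files directly under torchao/csrc/cuda/.
flat = glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))

# Recursive: "**" also descends into subdirectories such as cuda/fp6_llm/.
nested = glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)

print(sorted(set(nested) - set(flat)))  # sources that only the recursive pattern picks up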
93 changes: 93 additions & 0 deletions test/test_ops.py
@@ -4,6 +4,7 @@
import torchao
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
import unittest
from parameterized import parameterized


# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
@@ -42,6 +43,98 @@ def test_nms(self):
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils)

    def _create_fp6_inputs(self, BS: int, OC: int, IC: int):
        # Randomly initialize each byte. The highest value for randint() is set to the max value of uint32_t.
        fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
        fp16_scale = torch.rand(OC).half() + 0.5
        fp16_activation = torch.rand(BS, IC).half() + 0.5
        return fp6_weight, fp16_scale, fp16_activation

    def test_prepack_fp6_weight(self):
        OC = 256
        IC = 256
        fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC)

        # smoke test
        torchao.ops.prepack_fp6_weight(fp6_weight)

        # comprehensive testing
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_fp16_to_fp6(self):
        OC = 256
        IC = 256

        # in this fp6, we use 3 bits for exponent and 2 bits for mantissa
        # also, we don't have nan/inf
        fp6_absmax = 28.0  # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11
        fp6_absmin = 0.0625  # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number)
        fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
        fp16_weight.clip_(-fp6_absmax, fp6_absmax)
        fp16_weight[fp16_weight.abs() < fp6_absmin] = 0

        # smoke test
        torchao.ops.fp16_to_fp6(fp16_weight)

        # comprehensive testing
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_fp16act_fp6weight_linear(self):
        BS = 2
        OC = 256
        IC = 256
        splitK = 1
        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
        act_cuda = fp16_activation.cuda()
        weight_cuda = fp6_weight_packed.cuda()
        scale_cuda = fp16_scale.cuda()

        # smoke test
        torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

        # comprehensive testing
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_fp6_weight_dequant(self):
        OC = 256
        IC = 256
        fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC)

        # smoke test
        torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale)

        # comprehensive testing
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils)

    # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
    @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
        act_cuda = fp16_activation.cuda()
        weight_cuda = fp6_weight_packed.cuda()
        scale_cuda = fp16_scale.cuda()

        results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

        fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
        results_fp16 = act_cuda @ fp16_weight.T

        error = (results_fp6 - results_fp16).abs()
        relative_error = error / results_fp16.abs()
        assert relative_error.mean() < 1e-2


if __name__ == "__main__":
    unittest.main()
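
The fp6_absmax and fp6_absmin constants in test_fp16_to_fp6 follow from the E3M2 layout used here (1 sign bit, 3 exponent bits with bias 3, 2 mantissa bits, no inf/nan). A small decoding sketch, illustrative only and not part of the PR, that reproduces both constants:

def decode_fp6_e3m2(bits: int) -> float:
    """Decode a 6-bit E3M2 value (1 sign, 3 exponent, 2 mantissa bits, bias 3, no inf/nan)."""
    sign = -1.0 if (bits >> 5) & 0b1 else 1.0
    exp = (bits >> 2) & 0b111
    man = bits & 0b11
    if exp == 0:  # subnormal: 2 ** (1 - bias) * (man / 4)
        return sign * 2.0 ** (1 - 3) * (man / 4)
    return sign * 2.0 ** (exp - 3) * (1 + man / 4)

assert decode_fp6_e3m2(0b0_111_11) == 28.0    # fp6_absmax from the test above
assert decode_fp6_e3m2(0b0_000_01) == 0.0625  # smallest positive subnormal (fp6_absmin)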
90 changes: 90 additions & 0 deletions torchao/csrc/cuda/fp6_llm/configs.h
@@ -0,0 +1,90 @@
// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/configs.h

#ifndef CONFIGS_H
#define CONFIGS_H

//#define DEBUG_MODE
#define PIPELINE_LEVEL_GMEM 2
#define PIPELINE_LEVEL_SMEM 2 // only support 2

/************************ Hardware Parameters ************************/
#define WARP_SIZE 32
#define REG_BIT_WIDTH 32
// mma: M=16 K=16 N=8
#define MMA_8 8
#define MMA_16 16
// for memory access
#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ...
#define BIT_WIDTH_PER_HALF 16 // Half precision: FP16

/******************** Register Allocation For GEMM ********************/
#define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation
/********************** Memory Padding Parameters **********************/
// Eliminating bank-conflict
#define PADDING_BYTES_16 16 // Padding 16 bytes each column
#define PADDING_SHARED_MEM_FOR_B_8 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B
#define PADDING_SHARED_MEM_FOR_C_4 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C
/************************* WARP Tiling part-1 *************************/
#define WARP_ROW_MMA_TENSORS 4
#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64
#define WARP_K_MMA_TENSORS 4
#define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64
template<int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>
struct TilingConfig {
// Depending on "n" dimension of the GEMM
static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_;
static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_;
static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_;
/************************* WARP Tiling part-2 *************************/
static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8;
/*************************Thread Block Tiling *************************/
static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS;
static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
static constexpr int TILE_K = WARP_K;
/********************** #Thread per Thread Block **********************/
static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;
static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE;
/******************************* Others *******************************/
static constexpr int SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2
static constexpr int SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4
};

/************************ General Config for FP6-LLM **********************/
#define WEIGHT_FRAG1_BIT_WIDTH 2
#define WEIGHT_FRAG2_BIT_WIDTH 4
#define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH+WEIGHT_FRAG2_BIT_WIDTH) // 6
//#define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // QuantGroupSize: 4*64 = 256
/*************************** 64*64 Weights of A WARP *************************/
#define WEIGHT_PER_UNIT (WARP_M*WARP_K) // 64*64
#define SMEM_SIZE_IN_BYTES_PER_WARP_A1 (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/8) // 1024 Bytes #doubleBuffer not taken into consideration
#define SMEM_SIZE_IN_BYTES_PER_WARP_A2 (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/8) // 2048 Bytes #doubleBuffer not taken into consideration
#define SMEM_SIZE_A1_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A1*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Triple-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A = 8 KB.
#define SMEM_SIZE_A2_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A2*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Triple-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A = 16 KB.
/******************** Global Memory Layout For QUANTIZED DATA ******************/
#define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/128) // 64
#define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/128) // 128
/******************** Register Allocation For QUANTIZED DATA ******************/
#define WEIGHT_PER_THREAD (WEIGHT_PER_UNIT/WARP_SIZE) // 128
#define REG_PER_THREAD_2BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*2) // 8
#define REG_PER_THREAD_4BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*4) // 16
/******************** Register Allocation For QUANT Scales ******************/
#define WARP_REG_QUANT_SCALE 4 // 8 rows per thread -> 8 FP16 scales -> 4 registers
#define WARP_REG_QUANT_SCALE_DISTRIBUTED 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for each thread



#endif // CONFIGS_H
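
The derived constants in TilingConfig are simple products of the warp-level parameters defined above. A short Python sketch, using an assumed instantiation TilingConfig<4, 1, 8> purely for illustration (not a configuration taken from the kernels), that mirrors the constexpr arithmetic:

# Mirror of the constexpr arithmetic in TilingConfig; the template parameters
# below are an assumed example, not taken from the FP6-LLM kernel launch code.
WARP_SIZE, MMA_8, MMA_16 = 32, 8, 16
WARP_ROW_MMA_TENSORS = WARP_K_MMA_TENSORS = 4
WARP_M = WARP_ROW_MMA_TENSORS * MMA_16  # 64
WARP_K = WARP_K_MMA_TENSORS * MMA_16    # 64
PADDING_BYTES_16, PIPELINE_LEVEL_GMEM = 16, 2

# Assumed instantiation: TilingConfig<4, 1, 8>
BLOCK_ROW_WARPS, BLOCK_COL_WARPS, WARP_COL_MMA_TENSORS = 4, 1, 8
WARP_N = WARP_COL_MMA_TENSORS * MMA_8                         # 64
TILE_M = WARP_M * BLOCK_ROW_WARPS                             # 256
TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS       # 64
TILE_K = WARP_K                                               # 64
BLOCK_THREADS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS * WARP_SIZE # 128 threads per block
SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM  # 20480 bytes (sizeof(half)=2, double buffer)
print(TILE_M, TILE_N, TILE_K, BLOCK_THREADS, SMEM_SIZE_B_TILE)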