support Marlin W4A8 kernel
HandH1998 committed Nov 14, 2024
1 parent 39f16f4 commit 2690ff4
Showing 16 changed files with 2,832 additions and 9 deletions.
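This commit adds a W4A8 (4-bit weight, 8-bit activation) Marlin QQQ GEMM kernel together with the quantization plumbing, tests, and benchmarks around it. As a quick orientation, the sketch below shows how the new layout is applied through the existing quantize_ API, mirroring the usage added in the tests and in generate.py further down; the model shape and inputs are illustrative only.

import torch
from torch import nn

from torchao.dtypes import MarlinQQQLayout
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)
from torchao.quantization.quant_primitives import MappingType

# Marlin QQQ targets CUDA with fp16 linear layers
model = nn.Sequential(nn.Linear(8192, 8192)).half().cuda()

# W4A8: dynamic symmetric int8 activations, symmetric int4 weights in the Marlin QQQ layout
quantize_(
    model,
    int8_dynamic_activation_int4_weight(
        group_size=128,  # or -1 for per-channel weight scales
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=MarlinQQQLayout(),
    ),
)

out = model(torch.randn(16, 8192, dtype=torch.float16, device="cuda"))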
64 changes: 64 additions & 0 deletions benchmarks/benchmark_marlin_qqq.py
@@ -0,0 +1,64 @@
import torch
import pandas as pd
from torchao.utils import benchmark_torch_function_in_microseconds
from torchao.ops import marlin_qqq_gemm
from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq
from tqdm import tqdm


def get_problem(m, n, k, groupsize=-1):
    if groupsize == -1:
        groupsize = k
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
    B = torch.randint(low=-(2**31), high=2**31, size=(k, n), device=dev)
    s_tok = torch.ones((m, 1), dtype=torch.float, device=dev)
    if groupsize == k:
        s_group = torch.tensor([], dtype=torch.half, device=dev)
    else:
        s_group = torch.ones((k // groupsize, n), dtype=torch.half, device=dev)
    s_channel = torch.ones((1, n), dtype=torch.float, device=dev)
    B, s_group, s_channel = pack_to_marlin_qqq(
        B, s_group, s_channel, num_bits=4, group_size=groupsize
    )
    qqq_workspace = marlin_qqq_workspace(n)
    return A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace


def benchmark(m: int, k: int, n: int, group_size: int):
    A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace = get_problem(
        m, n, k, group_size
    )

    fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
    marlin_qqq_w4a8_time = benchmark_torch_function_in_microseconds(
        marlin_qqq_gemm, A, B, s_tok, s_channel, s_group, qqq_workspace, m, n, k
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "group_size": group_size,
        "fp16_latency (us)": fp16_time,
        "marlin_qqq_w4a8_latency (us)": marlin_qqq_w4a8_time,
        "speedup (d/s)": fp16_time / marlin_qqq_w4a8_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for group_size in tqdm([-1, 128]):
        for m in tqdm([1 << i for i in range(10)]):
            for n, k in zip(n_vals, k_vals):
                results.append(benchmark(m, k, n, group_size))

    df = pd.DataFrame(results)
    df.to_csv("marlin_qqq_w4a8_llm_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
129 changes: 129 additions & 0 deletions test/quantization/test_marlin_qqq.py
@@ -0,0 +1,129 @@
import copy

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.dtypes import MarlinQQQLayout
from torchao.quantization.marlin_qqq import (
    pack_to_marlin_qqq,
    unpack_from_marlin_qqq,
)
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_and_quantize_affine_qqq,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


class MarlinQQQ(TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(0)

        self.input = torch.randn((64, 32, 8192), dtype=torch.float16, device="cuda")
        self.model = (
            nn.Sequential(
                nn.Linear(8192, 21504),
                nn.Linear(21504, 8192),
                nn.ReLU(),
                nn.Linear(8192, 21504),
                nn.Linear(21504, 8192),
            )
            .half()
            .cuda()
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq(self):
        output_ref = self.model(self.input)
        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq_compile(self):
        model_copy = copy.deepcopy(self.model)
        model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
        output_ref = model_copy(self.input)

        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            modelq.forward = torch.compile(modelq.forward, fullgraph=True)
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_pack_unpack_equivalence(self):
        num_bits = 4
        shape = (11008, 4096)

        w = torch.rand(shape, dtype=torch.float16, device="cuda")

        for group_size in [-1, 128]:
            # Quantize weights
            q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
                w, num_bits, group_size
            )

            q_w = q_w.t()
            s_group = s_group.t()
            s_channel = s_channel.t()

            # Test pack/unpack equivalence
            q_w_comp, packed_s_group, packed_s_channel = pack_to_marlin_qqq(
                q_w, s_group, s_channel, num_bits, group_size
            )
            unpacked_q_w, unpacked_s_group, unpacked_s_channel = unpack_from_marlin_qqq(
                q_w_comp,
                packed_s_group,
                packed_s_channel,
                q_w.shape,
                num_bits,
                group_size,
            )

            assert torch.equal(
                q_w, unpacked_q_w
            ), "Unpacked weights do not match original weights"
            assert torch.equal(
                s_channel, unpacked_s_channel
            ), "Unpacked s_channel does not match original s_channel"
            assert torch.equal(
                s_group, unpacked_s_group
            ), "Unpacked s_group does not match original s_group"


if __name__ == "__main__":
    run_tests()
109 changes: 109 additions & 0 deletions test/test_ops.py
@@ -13,6 +13,11 @@
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
from torchao.dtypes.floatx import from_scaled_tc_floatx
from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
from torchao.quantization.marlin_qqq import (
    marlin_qqq_workspace,
    pack_to_marlin_qqq,
)
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq
import pytest

if is_fbcode():
@@ -426,5 +431,109 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
)


MARLIN_QQQ_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
MARLIN_QQQ_K_CHUNKS = [128]
MARLIN_QQQ_N_CHUNKS = [64, 128, 256]
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
]
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(
    itertools.product(
        MARLIN_QQQ_BATCH_SIZE,
        MARLIN_QQQ_K_CHUNKS,
        MARLIN_QQQ_N_CHUNKS,
        MARLIN_QQQ_SUPPORTED_NUM_BITS,
        MARLIN_QQQ_SUPPORTED_GROUP_SIZES,
        MNK_FACTORS,
    )
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
    MARLIN_TEST_PARAMS,
    ids=str,
)
def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
    int8_traits = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = torch.randn(
        (batch_size, size_m, size_k), dtype=torch.float16, device="cuda"
    )
    b_weight = torch.rand((size_n, size_k), dtype=torch.float16, device="cuda")

    # Reshape input into 2D tensor
    input_2d = a_input.view(-1, a_input.shape[-1])
    a_input_in, a_input_out = input_2d.shape

    # Quantize activations
    s_a = (
        input_2d.abs()
        .max(dim=-1, keepdim=True)[0]
        .div(int8_traits.max)
        .to(torch.float32)
    )
    q_a = (
        (input_2d / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
    )

    # Quantize weights
    q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq(
        b_weight, num_bits, group_size
    )
    q_w = q_w.t()
    s_group = s_group.t()
    s_channel = s_channel.t()
    w_ref = w_ref.t()
    marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq(
        q_w, s_group, s_channel, num_bits, group_size
    )

    workspace = marlin_qqq_workspace(size_n)

    # Obtains reference output
    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
    output_ref = output_ref.reshape(a_input.shape[:-1] + (size_n,))

    fn_inputs = (
        q_a,
        marlin_qqq_q_w,
        s_a,
        marlin_qqq_s_channel,
        marlin_qqq_s_group,
        workspace,
        a_input_in,
        size_n,
        a_input_out,
    )
    output = torchao.ops.marlin_qqq_gemm(*fn_inputs)
    output = output.reshape(a_input.shape[:-1] + (size_n,))

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04

    # Performs opcheck
    test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
    opcheck(
        torch.ops.torchao.marlin_qqq_gemm,
        fn_inputs,
        test_utils=test_utils,
    )


if __name__ == "__main__":
    run_tests()
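For readers who want to call the kernel directly, the flow exercised by test_marlin_qqq above condenses to the sketch below. The problem size is illustrative; everything else uses only the helpers introduced or imported in this commit.

import torch
import torchao
from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq

m, k, n = 16, 4096, 4096  # illustrative problem size
num_bits, group_size = 4, 128

x = torch.randn((m, k), dtype=torch.float16, device="cuda")
w = torch.rand((n, k), dtype=torch.float16, device="cuda")

# Per-token symmetric int8 activation quantization (same recipe as the test above)
int8_traits = torch.iinfo(torch.int8)
s_x = x.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(torch.float32)
q_x = (x / s_x).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)

# Quantize the weight, then repack it (and its scales) into the Marlin QQQ format
q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq(w, num_bits, group_size)
marlin_w, marlin_s_group, marlin_s_channel = pack_to_marlin_qqq(
    q_w.t(), s_group.t(), s_channel.t(), num_bits, group_size
)
workspace = marlin_qqq_workspace(n)

# Fused W4A8 GEMM: returns an fp16 output of shape (m, n)
y = torchao.ops.marlin_qqq_gemm(
    q_x, marlin_w, s_x, marlin_s_channel, marlin_s_group, workspace, m, n, k
)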
20 changes: 17 additions & 3 deletions torchao/_models/llama/generate.py
@@ -14,6 +14,7 @@
 import torch._dynamo.config
 import torch._inductor.config
 from torchao.utils import get_model_size_in_bytes
+from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 def device_sync(device):
@@ -211,6 +212,7 @@ def main(
         int8_weight_only,
         int8_dynamic_activation_int8_weight,
         int4_weight_only,
+        int8_dynamic_activation_int4_weight,
         fpx_weight_only,
         uintx_weight_only,
         autoquant,
@@ -235,8 +237,20 @@
             assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
             quantize_(model, int4_weight_only(group_size=group_size))
         if "marlin" in quantization:
-            from torchao.dtypes import MarlinSparseLayout
-            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
+            if "qqq" in quantization:
+                from torchao.dtypes import MarlinQQQLayout
+                quantize_(
+                    model,
+                    int8_dynamic_activation_int4_weight(
+                        group_size=128,
+                        mapping_type=MappingType.SYMMETRIC,
+                        act_mapping_type=MappingType.SYMMETRIC,
+                        layout=MarlinQQQLayout(),
+                    ),
+                )
+            else:
+                from torchao.dtypes import MarlinSparseLayout
+                quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         if "embed-int8wo" in quantization:
@@ -474,7 +488,7 @@ def callback(x):
         help=(
             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
             +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
-            +'embed-int8wo'
+            +'embed-int8wo, marlin_qqq'
         )
     )
     parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")