Commit: format
LucasWilkinson committed Oct 2, 2024
1 parent 76a5c3d commit 261b1c2
Showing 3 changed files with 62 additions and 60 deletions.
117 changes: 60 additions & 57 deletions benchmarks/kernels/benchmark_machete.py
@@ -1,17 +1,17 @@
import argparse
import copy
import os
import itertools
import math
import os
import pickle as pkl
import time
import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from itertools import product
from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple

import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

@@ -35,6 +35,7 @@
if NVTX_PROFILE:
import nvtx


def terse_type_name(dt):
return {
torch.bfloat16: "bf16",
@@ -60,6 +61,7 @@ class BenchmarkTensors:
w_ch_s: Optional[torch.Tensor]
w_tok_s: Optional[torch.Tensor]


@dataclass
class TypeConfig:
act_type: torch.dtype
@@ -70,6 +72,7 @@ class TypeConfig:
channel_scale_type: Optional[torch.dtype]
token_scale_type: Optional[torch.dtype]


def rand_data(shape, dtype=torch.float16, scale=1):
if dtype.is_floating_point:
return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
@@ -78,11 +81,11 @@ def rand_data(shape, dtype=torch.float16, scale=1):


def quantize_and_pack(atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
assert wtype.is_integer(), "TODO: support floating point weights"

w_ref, w_q, w_s, w_zp = quantize_weights(
@@ -97,19 +100,18 @@ def quantize_and_pack(atype: torch.dtype,
return w_ref, w_q, w_s, w_zp


def create_bench_tensors(shape: Tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int]) -> List[BenchmarkTensors]:
def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
group_size: Optional[int]) -> List[BenchmarkTensors]:
m, n, k = shape

# we want to make sure that weights don't fit into L2 cache between runs so
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
# so we target total weight size > 2*50mb
num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
(k * n * types.weight_type.size_bits))

a = rand_data((m, k), types.act_type, scale=5)

benchmark_tensors: List[BenchmarkTensors] = []
for _ in range(num_weights):
w = rand_data((k, n), types.act_type, scale=5)
@@ -133,19 +135,18 @@ def create_bench_tensors(shape: Tuple[int, int, int],
rand_data((n,), types.channel_scale_type)
w_tok_s = None if types.token_scale_type is None else\
rand_data((m,), types.token_scale_type)

benchmark_tensors.append(
BenchmarkTensors(w_ref=w_ref,
a=a,
w_q=w_q_packed,
wtype=types.weight_type,
w_g_s=w_s,
w_g_zp=w_zp,
group_size=group_size,
w_ch_s=w_ch_s,
w_tok_s=w_tok_s)
)

a=a,
w_q=w_q_packed,
wtype=types.weight_type,
w_g_s=w_s,
w_g_zp=w_zp,
group_size=group_size,
w_ch_s=w_ch_s,
w_tok_s=w_tok_s))

return benchmark_tensors


@@ -178,20 +179,19 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
if bt.w_g_zp is None:
w_zp = torch.empty(0, dtype=torch.int, device=device)
else:
w_zp = marlin_zero_points(
bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits)
w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
bt.w_ref.shape[1], bt.wtype.size_bits)

if bt.group_size is None:
w_s = torch.tensor([], device="cuda", dtype=torch.half)
else:
w_s = marlin_permute_scales(
bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size)
w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
bt.w_ref.shape[1], bt.group_size)

sort_indices = torch.empty(0, dtype=torch.int, device=device)
g_idx = torch.empty(0, dtype=torch.int, device=device)
w_q = ops.gptq_marlin_repack(
bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1],
bt.wtype.size_bits)
w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
bt.w_ref.shape[1], bt.wtype.size_bits)

if bt.a.dtype.is_floating_point:
assert bt.w_ch_s is None
@@ -213,19 +213,20 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
else:
assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8

if bt.w_ch_s is not None:
s_ch = bt.w_ch_s.to(torch.float32)
else:
else:
s_ch = torch.ones(bt.w_ref.shape[1],
dtype=torch.float32,
device=device)
dtype=torch.float32,
device=device)

if bt.w_tok_s is not None:
s_tok = bt.w_tok_s.to(torch.float32)
else:
s_tok = torch.ones(
bt.a.shape[0], dtype=torch.float32, device=device)
s_tok = torch.ones(bt.a.shape[0],
dtype=torch.float32,
device=device)

fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
b_q_weight=w_q,
@@ -244,7 +245,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
out_type=torch.dtype,
schedule=None) -> Callable:
w_q = bt.w_q.t().contiguous().t() # make col major
w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
None if bt.w_g_s is None else bt.w_g_s.dtype)

w_g_zp = bt.w_g_zp
@@ -272,7 +273,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors,

def bench_fns(label: str, sub_label: str, description: str,
fns: List[Callable]):

min_run_time = 1 if not NVTX_PROFILE else 0.1
res = TBenchmark.Timer(
stmt="""
@@ -288,16 +289,17 @@ def bench_fns(label: str, sub_label: str, description: str,
).blocked_autorange(min_run_time=min_run_time)

if NVTX_PROFILE:
with nvtx.annotate("mm-bench"):
with nvtx.annotate(f"{label}|{sub_label}|{description}"):
fns[0]()
with nvtx.annotate("mm-bench"), nvtx.annotate(
f"{label}|{sub_label}|{description}"):
fns[0]()

return res


_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None


def bench(types: TypeConfig,
group_size: int,
m: int,
@@ -322,7 +324,6 @@ def bench(types: TypeConfig,
if types.token_scale_type is not None:
name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"


timers = []
# pytorch impl
timers.append(
@@ -333,11 +334,12 @@ def bench(types: TypeConfig,

if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
timers.append(
bench_fns(label, sub_label,
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
[cutlass_scaled_mm_create_bench_fn(bt)
for bt in benchmark_tensors
]))
bench_fns(
label, sub_label,
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
cutlass_scaled_mm_create_bench_fn(bt)
for bt in benchmark_tensors
]))

if types.act_type != torch.float8_e4m3fn:
timers.append(
@@ -353,12 +355,14 @@ def bench(types: TypeConfig,
]))

if sweep_schedules:
global _SWEEP_SCHEDULES_RESULTS

print("Finding best schedule for machete")
best = None
best_schedule = None
schedules = ops.machete_supported_schedules(
a_type=types.act_type,
b_type=types.weight_type,
a_type=types.act_type,
b_type=types.weight_type,
group_scales_type=types.group_scale_type,
group_zeros_type=types.group_zero_type,
token_scales_type=types.token_scale_type,
@@ -412,7 +416,6 @@ def print_timers(timers: List[TMeasurement]):


def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:

types = TypeConfig(
act_type=args.act_type,
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
@@ -423,7 +426,7 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
channel_scale_type=args.channel_scale_type,
token_scale_type=args.token_scale_type,
)

results: List[TMeasurement] = []
for m, k, n in MKNs:
timers = bench(types,
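
Editor's note: for context on the timing idiom this benchmark wraps in bench_fns, the sketch below shows the torch.utils.benchmark Timer / blocked_autorange pattern on its own. It is a minimal illustration, not code from this commit; torch.mm stands in for the machete/marlin closures and the 4096-sized shapes are arbitrary (a CUDA GPU is assumed).

# Minimal sketch (not from this commit) of the pattern bench_fns uses: build a
# list of no-argument closures, then time them with torch.utils.benchmark so
# CUDA synchronization and iteration counts are handled automatically.
import torch
import torch.utils.benchmark as TBenchmark

def time_matmul_example():
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    fns = [lambda: torch.mm(a, b)]  # stand-in for the machete/marlin closures

    # blocked_autorange picks the iteration count itself and returns a
    # Measurement, the same object the benchmark collects into `timers`.
    return TBenchmark.Timer(
        stmt="""
        for fn in fns:
            fn()
        """,
        globals={"fns": fns},
        label="example",
        sub_label="MKN=(4096x4096x4096)",
        description="torch.mm (fp16)",
    ).blocked_autorange(min_run_time=1)
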
2 changes: 1 addition & 1 deletion tests/kernels/test_machete_gemm.py
@@ -122,6 +122,7 @@ class Tensors:
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)


def rand_data(shape, dtype=torch.float16, scale=1):
if dtype.is_floating_point:
return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
@@ -363,7 +364,6 @@ def test_machete_cuda_graph():
a = rand_data((m, k), torch.float16)
b = rand_data((k, n), torch.float16)
wtype = scalar_types.uint4b8
atype = torch.float16
stype = torch.float16
group_size = 128
zero_points = False
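
Editor's note: the test touched above, test_machete_cuda_graph, exercises the machete GEMM under CUDA graph capture. As a rough, hedged illustration of that general pattern (not code from this repository), a plain torch.mm capture/replay looks like the sketch below; the shapes and function name are made up for the example.

# Hedged sketch of CUDA graph capture/replay around a GEMM, analogous to what
# a CUDA-graph test checks for a custom kernel. torch.mm is a stand-in.
import torch

def capture_and_replay_gemm(m=64, k=128, n=256):
    a = torch.randn(m, k, device="cuda", dtype=torch.float16)
    b = torch.randn(k, n, device="cuda", dtype=torch.float16)

    # Warm up on a side stream so capture does not record one-time init work.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        out = torch.mm(a, b)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = torch.mm(a, b)

    g.replay()  # re-runs the captured kernels into the same output buffer
    torch.cuda.synchronize()
    return out
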
3 changes: 1 addition & 2 deletions vllm/_custom_ops.py
@@ -432,8 +432,7 @@ def machete_gemm_fake(
return torch.empty((m, n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
a_type: torch.dtype,
def machete_prepack_B_fake(b_q_weight: torch.Tensor, a_type: torch.dtype,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)
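
Editor's note: the machete_prepack_B_fake registration above supplies a shape-only "fake" implementation so the custom op can be traced (for example by torch.compile) without running the real kernel. The sketch below shows the same mechanism for a hypothetical op; it assumes PyTorch 2.4+ and the op name mylib::scale_rows is invented for illustration.

# Hedged sketch: a custom op plus a fake (meta) implementation, mirroring how
# machete_prepack_B_fake only reports output shape/dtype via torch.empty_like.
import torch

@torch.library.custom_op("mylib::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Real implementation: scale each row of x by the matching entry of s.
    return x * s.unsqueeze(-1)

@scale_rows.register_fake
def _(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Fake implementation: no compute, just the output's shape and dtype.
    return torch.empty_like(x)
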
