Fix w8a8 benchmark and add Llama-3-8B (vllm-project#5562)
comaniac authored Jun 17, 2024
1 parent 845a3f2 · commit e2b85cf
Showing 2 changed files with 19 additions and 8 deletions.
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py (21 changes: 13 additions & 8 deletions)
@@ -46,7 +46,7 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
 # impl


-def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
+def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                     scale_b: torch.tensor,
                     out_dtype: torch.dtype) -> torch.tensor:
     return torch.mm(a, b)
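
Aside: the renamed helper deliberately ignores its scale and out_dtype arguments and just runs an unscaled matmul, which is why both the int8 and fp8 benchmark paths can share it as a bf16 baseline (see the hunks below). As a rough standalone sketch, not the script's actual bench_fn plumbing and with invented shapes and labels, timing such a baseline with torch.utils.benchmark could look like this:

# Illustrative sketch only: shapes, labels, and plumbing are assumptions.
import torch
import torch.utils.benchmark as TBenchmark

m, k, n = 16, 4096, 6144  # hypothetical GEMM problem size
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
b = torch.randn((k, n), device="cuda", dtype=torch.bfloat16)

timer = TBenchmark.Timer(
    stmt="torch.mm(a, b)",
    globals={"torch": torch, "a": a, "b": b},
    label="gemm",
    sub_label=f"m={m}, k={k}, n={n}",
    description="pytorch_bf16_bf16_bf16_matmul-no-scales",
)
print(timer.blocked_autorange(min_run_time=1))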
@@ -115,7 +115,7 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     timers.append(
         bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                  b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
-                 torch.bfloat16, label, sub_label, pytorch_i8_impl,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
                  "pytorch_bf16_bf16_bf16_matmul-no-scales"))

     # cutlass impl
@@ -136,6 +136,13 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,

     timers = []

+    # pytorch impl w. bf16
+    timers.append(
+        bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
+                 b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
+                 "pytorch_bf16_bf16_bf16_matmul-no-scales"))
+
     # pytorch impl: bf16 output, without fp8 fast accum
     timers.append(
         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
@@ -160,14 +167,12 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,

     # cutlass impl: bf16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.bfloat16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_bf16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
     # cutlass impl: fp16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.float16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_fp16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
     return timers
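
Note on the scale fix above: per-tensor scale tensors that stay on the CPU cannot be combined with CUDA inputs inside a single GPU op, which is presumably why the benchmark broke; the fix keeps scale_a and scale_b on the GPU. A tiny self-contained illustration of the general failure mode (not this benchmark's code):

# Illustrative only: a 1-element CPU tensor mixed with a CUDA tensor errors out.
import torch

a = torch.randn(16, 32, device="cuda")
scale_cpu = torch.tensor([0.5])          # shape-[1] tensor on the CPU
scale_gpu = scale_cpu.to(device="cuda")  # what the fixed code effectively passes

try:
    _ = a * scale_cpu                    # CPU scale against a CUDA input
except RuntimeError as err:
    print("device mismatch:", err)

_ = a * scale_gpu                        # fine: everything on one device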


benchmarks/cutlass_benchmarks/weight_shapes.py (6 changes: 6 additions & 0 deletions)
@@ -22,6 +22,12 @@
         ([4096, 22016], 1),
         ([11008, 4096], 0),
     ],
+    "meta-llama/Llama-3-8b": [
+        ([4096, 6144], 1),
+        ([4096, 4096], 0),
+        ([4096, 28672], 1),
+        ([14336, 4096], 0),
+    ],
     "meta-llama/Llama-2-13b-hf": [
         ([5120, 15360], 1),
         ([5120, 5120], 0),
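
For reference, the new entries follow from Llama-3-8B's published configuration (hidden_size 4096, intermediate_size 14336, 32 attention heads, 8 KV heads): the four [K, N] pairs correspond to the fused qkv projection, the attention output projection, the fused gate/up projection, and the MLP down projection, and the trailing 0/1 mirrors the convention of the neighboring entries (by all appearances the tensor-parallel split dimension). A quick arithmetic check:

# Sanity-check the Llama-3-8B [K, N] entries from the model configuration.
hidden = 4096
num_heads, num_kv_heads = 32, 8
head_dim = hidden // num_heads                  # 128
intermediate = 14336

qkv_n = hidden + 2 * num_kv_heads * head_dim    # 4096 + 2048 = 6144
gate_up_n = 2 * intermediate                    # 28672, gate and up fused

assert [hidden, qkv_n] == [4096, 6144]          # fused qkv_proj
assert [hidden, hidden] == [4096, 4096]         # o_proj
assert [hidden, gate_up_n] == [4096, 28672]     # fused gate_up_proj
assert [intermediate, hidden] == [14336, 4096]  # down_proj
print("all Llama-3-8B shapes check out")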
