[Bugfix] Fix compute datatype for cutlass 3.x epilogues #5931

Merged
3 commits merged on Jun 28, 2024

4 changes: 2 additions & 2 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -144,14 +144,14 @@ struct ScaledEpilogueBias
   using ScaleB = typename SUPER::ScaleB;
 
   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiplies, ElementD, ElementD,
+      cutlass::multiplies, float, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
   using EVTCompute0 =
       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
 
   using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiply_add, ElementD, ElementD,
+      cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
   using BiasDescriptor =
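The change above switches the epilogue's compute type from ElementD to float: with ElementCompute = ElementD, the ScaleB * Accum product and the following multiply-add were evaluated and rounded in the output dtype (typically fp16 or bf16), whereas with float the intermediates stay in fp32 and only the final store converts. The snippet below is a rough PyTorch approximation of the fused epilogue D = scale_a * (scale_b * accum) + bias, not the actual CUTLASS Sm90 EVT code; tensor names, shapes, and magnitudes are made up purely for illustration.

```python
# Rough PyTorch model of the epilogue D = scale_a * (scale_b * accum) + bias.
# This is NOT the CUTLASS Sm90 EVT implementation; it only illustrates why the
# compute dtype of the fusion nodes matters.
import torch

torch.manual_seed(0)
out_dtype = torch.bfloat16
accum = torch.randn(256, 256) * 50        # stands in for the fp32 GEMM accumulator
scale_a = torch.rand(256, 1) / 10         # per-token scale (column vector)
scale_b = torch.rand(1, 256) / 10         # per-output-channel scale (row vector)
bias = torch.randn(256, dtype=out_dtype)  # bias already in the output dtype

# Old behavior: both fusion nodes compute (and therefore round) in ElementD.
old = scale_a.to(out_dtype) * (scale_b.to(out_dtype) * accum.to(out_dtype)) + bias
# Fixed behavior: intermediates stay in float32; only the final store converts.
new = (scale_a * (scale_b * accum) + bias.float()).to(out_dtype)

print((old - new).abs().max())  # nonzero: the intermediate dtype changes the result
```

The updated fp8 test below tightens the absolute tolerance from 1e-1 to 5e-2, which is easier to satisfy once the intermediates are kept in fp32.
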
125 changes: 68 additions & 57 deletions tests/kernels/test_cutlass.py
@@ -2,7 +2,7 @@
 
 Run `pytest tests/kernels/test_cutlass.py`.
 """
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 import torch
@@ -27,12 +27,27 @@ def to_int8(tensor: torch.Tensor):
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
 
 
+def baseline_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: Type[torch.dtype],
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    output = (scale_a * (scale_b * (torch.mm(
+        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
+    if bias is not None:
+        output = output + bias
+
+    return output
+
+
 def cutlass_fp8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
-                            bias: bool,
+                            use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -43,31 +58,27 @@ def cutlass_fp8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1
 
-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None
 
-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(out_dtype)
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
-    assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
+    assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
 
 
 def cutlass_int8_gemm_helper(m: int,
                              n: int,
                              k: int,
                              per_token_act_quant: bool,
                              per_out_channel_weight_quant: bool,
-                             bias: bool,
+                             use_bias: bool,
                              out_dtype: Type[torch.dtype] = torch.bfloat16,
                              device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -78,22 +89,19 @@ def cutlass_int8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1
 
-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
 
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None
 
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(dtype=out_dtype)
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
 

@@ -102,83 +110,83 @@ def cutlass_int8_gemm_helper(m: int,
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                          per_out_ch: bool, bias: bool):
-    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                          per_out_ch: bool, use_bias: bool):
+    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
 @pytest.mark.parametrize("m", [512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                           per_out_ch: bool, bias: bool):
-    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                           per_out_ch: bool, use_bias: bool):
+    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                         out_dtype: Type[torch.dtype],
-                                        bias: bool):
+                                        use_bias: bool):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
-                                       bias: bool):
+                                       use_bias: bool):
     cutlass_fp8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
-                            bias,
+                            use_bias,
                             out_dtype=out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool, device: str):
-    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
+                                  use_bias: bool, device: str):
+    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
                             torch.bfloat16, device)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool, device: str):
+                                   use_bias: bool, device: str):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=torch.bfloat16,
                              device=device)
@@ -190,25 +198,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
 # kernel must handle any M thrown at it.
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool):
+                                  use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
-            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)
+            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
+                                    use_bias)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool):
+                                   use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
             cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
-                                     bias)
+                                     use_bias)
 
 
 # Test working with a subset of A and B
@@ -229,9 +238,11 @@ def test_cutlass_subset():
                                 scale_a,
                                 scale_b,
                                 out_dtype=torch.bfloat16)
-    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
-                        scale_b *
-                        b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)
 
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)

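For readers skimming the test changes: the new baseline_scaled_mm reference broadcasts the per-token scale (shape (m, 1)) and per-output-channel scale (shape (1, n)) over the fp32 matmul, converts to out_dtype, and then adds the optional bias in that dtype. The standalone sketch below mirrors that computation in plain PyTorch; the inputs are random stand-ins rather than actual quantized fp8/int8 tensors, and the shapes are illustrative.

```python
import torch

m, n, k = 4, 8, 16
a = torch.randn(m, k)        # stand-in for the quantized activation (fp8/int8 in the tests)
b = torch.randn(k, n)        # stand-in for the quantized weight
scale_a = torch.randn(m, 1)  # per-token scale, broadcast across output columns
scale_b = torch.randn(1, n)  # per-output-channel scale, broadcast across output rows
bias = torch.randn(n, dtype=torch.bfloat16)

# Same structure as baseline_scaled_mm: scale the fp32 product, convert to the
# output dtype, then add the (optional) bias in that dtype.
ref = (scale_a * (scale_b * torch.mm(a.float(), b.float()))).to(torch.bfloat16)
ref = ref + bias

print(ref.shape)  # torch.Size([4, 8])
```

Passing bias=None (the use_bias=False case) simply skips the final addition, matching how the helpers now forward the bias argument straight to ops.cutlass_scaled_mm.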