diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index 326ec02ca3cd7..b3f5b62086609 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -144,14 +144,14 @@ struct ScaledEpilogueBias
   using ScaleB = typename SUPER::ScaleB;
 
   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiplies, ElementD, ElementD,
+      cutlass::multiplies, float, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
   using EVTCompute0 =
       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
 
   using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiply_add, ElementD, ElementD,
+      cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
   using BiasDescriptor =
diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index 39de444be342d..d8e6d27b82484 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -2,7 +2,7 @@
 Run `pytest tests/kernels/test_cutlass.py`.
 """
 
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 import torch
@@ -27,12 +27,27 @@ def to_int8(tensor: torch.Tensor):
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
 
 
+def baseline_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: Type[torch.dtype],
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    output = (scale_a * (scale_b * (torch.mm(
+        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
+    if bias is not None:
+        output = output + bias
+
+    return output
+
+
 def cutlass_fp8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
-                            bias: bool,
+                            use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -43,23 +58,19 @@ def cutlass_fp8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1
 
-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None
 
-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(out_dtype)
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
-    assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
+    assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
 
 
 def cutlass_int8_gemm_helper(m: int,
@@ -67,7 +78,7 @@ def cutlass_int8_gemm_helper(m: int,
                              k: int,
                              per_token_act_quant: bool,
                              per_out_channel_weight_quant: bool,
-                             bias: bool,
+                             use_bias: bool,
                              out_dtype: Type[torch.dtype] = torch.bfloat16,
                              device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -78,22 +89,19 @@ def cutlass_int8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1
 
-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
 
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None
+
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(dtype=out_dtype)
 
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
 
@@ -102,12 +110,12 @@ def cutlass_int8_gemm_helper(m: int,
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                          per_out_ch: bool, bias: bool):
-    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                          per_out_ch: bool, use_bias: bool):
+    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
 @pytest.mark.parametrize("m", [512, 222, 33, 1])
@@ -115,70 +123,70 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                           per_out_ch: bool, bias: bool):
-    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                           per_out_ch: bool, use_bias: bool):
+    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                         out_dtype: Type[torch.dtype],
-                                        bias: bool):
+                                        use_bias: bool):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
-                                       bias: bool):
+                                       use_bias: bool):
     cutlass_fp8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
-                            bias,
+                            use_bias,
                             out_dtype=out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool, device: str):
-    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
+                                  use_bias: bool, device: str):
+    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
                             torch.bfloat16, device)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool, device: str):
+                                   use_bias: bool, device: str):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=torch.bfloat16,
                              device=device)
 
@@ -190,25 +198,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
 # kernel must handle any M thrown at it.
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool):
+                                  use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
-            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)
+            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
+                                    use_bias)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool):
+                                   use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
             cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
-                                     bias)
+                                     use_bias)
 
 
 # Test working with a subset of A and B
@@ -229,9 +238,11 @@ def test_cutlass_subset():
                                 scale_a,
                                 scale_b,
                                 out_dtype=torch.bfloat16)
 
-    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
-                        scale_b *
-                        b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)
 
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
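
Not part of the diff: a rough PyTorch sketch of the numeric behaviour the epilogue change targets. In the Sm90Compute nodes above, the second and third template arguments are the node's output and compute element types, so switching them from ElementD to float keeps the scale multiplications in fp32 and rounds only once when the result is stored as ElementD. That is also how the new baseline_scaled_mm reference computes its answer, which is why the fp8 test can tighten atol from 1e-1 to 5e-2. The sizes and the old_style/new_style names below are illustrative only, not taken from the PR.

import torch

torch.manual_seed(0)
m, n, k = 16, 32, 64
acc = torch.randn((m, k)) @ torch.randn((k, n))  # fp32 accumulator from the mainloop
scale_a = torch.randn((m, 1), dtype=torch.float32)
scale_b = torch.randn((1, n), dtype=torch.float32)
bias = torch.rand((n, ), dtype=torch.bfloat16) * 10

# Old epilogue style (ElementD compute type): intermediates rounded to bf16.
old_style = scale_a.to(torch.bfloat16) * (scale_b.to(torch.bfloat16) *
                                          acc.to(torch.bfloat16)) + bias

# New epilogue style (float compute type): scales applied in fp32, single cast
# at the end, matching baseline_scaled_mm from the test diff.
new_style = (scale_a * (scale_b * acc)).to(torch.bfloat16) + bias

# The difference is the extra rounding error the old compute types introduced.
print((old_style.float() - new_style.float()).abs().max())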