diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
index a33e2660d760e..8fce76eb52f9b 100644
--- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
@@ -910,13 +910,16 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
       // than better compute utilization
       thread_k = 128;
       thread_m = 128;
-    } else if (prob_n <= 256) {
+    } else {
       thread_k = 64;
       thread_m = 256;
-    } else {
-      thread_k = 32;
-      thread_m = 512;
     }
+    // Also had
+    //   if prob_n > 256
+    //     thread_k = 32;
+    //     thread_m = 512;
+    // but this is broken,
+    // TODO(Lucas, Alex M): figure out why
   }
 
   int thread_k_blocks = thread_k / 32;  // 2:4 version with m16n8k32 instruction
@@ -1079,6 +1082,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Verify A device and strides
   TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
   TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
+  TORCH_CHECK(a.dtype() == torch::kFloat16,
+              "A is not float16, currently only float16 is supported");
 
   // Verify B device and strides
   TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
@@ -1091,6 +1096,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Verify scales device and strides
   TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
   TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat16,
+              "b_scales is not float16, currently only float16 is supported");
 
   // Alloc C matrix
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py
index 3899ad1a325cf..5e047f4b099f1 100644
--- a/tests/kernels/test_marlin_gemm.py
+++ b/tests/kernels/test_marlin_gemm.py
@@ -50,6 +50,8 @@
     (13, 17, 67),
     (26, 37, 13),
     (67, 13, 11),
+    (257, 13, 11),
+    (658, 13, 11),
 ]
 
 DTYPES = [torch.float16, torch.bfloat16]