diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
index 201dd403270f3..5dafe445e2b46 100644
--- a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
+++ b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
@@ -105,7 +105,6 @@ void weight_permute_gpu(const GPUContext& dev_ctx,
         input_data, output_data, numel, total_k, total_n);
   }
 }
-
 template <typename T, int VectorSize = 8>
 __global__ void per_channel_quant_gpu(const T* weight_data,
                                       int8_t* quanted_weight_data,
@@ -161,6 +160,7 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
     }
   }
 }
+
 template <typename T>
 void weight_quant_gpu(const GPUContext& dev_ctx,
                       const T* weight_data,
@@ -174,15 +174,8 @@ void weight_quant_gpu(const GPUContext& dev_ctx,
   constexpr int kBlockSize = 64;
   constexpr int kWarpNum = kBlockSize / kWarpSize;
   constexpr int kVectorSize = 128 / sizeof(T) / 8;
-  PADDLE_ENFORCE_EQ(total_n % kVectorSize,
-                    0,
-                    phi::errors::PreconditionNotMet(
-                        "Currently, weight_quant_gpu kernel only support n "
-                        "with multiple of %d, please use",
-                        kVectorSize));
   int vec_total_n = total_n / kVectorSize;
-  int kGridSize =
-      max((vec_total_n + kBlockSize - 1) / kBlockSize, static_cast<int>(1));
+  int kGridSize = max(vec_total_n / kBlockSize, static_cast<int>(1));
   per_channel_quant_gpu<T, kVectorSize><<<kGridSize, kBlockSize>>>(
       weight_data, quanted_weight_data, scale_data, total_k, vec_total_n);
 }
diff --git a/test/quantization/test_weight_only_linear.py b/test/quantization/test_weight_only_linear.py
index c7bbc1c658267..81f84f138e70b 100644
--- a/test/quantization/test_weight_only_linear.py
+++ b/test/quantization/test_weight_only_linear.py
@@ -399,47 +399,5 @@ def test_weightonly_linear_backward(self):
         np.testing.assert_allclose(quant_x.grad, x.grad, rtol=1e-3, atol=1e-3)
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
-class WeightOnlyLinearTestCase11(WeightOnlyLinearTestCase):
-    def config(self):
-        super().config()
-        self.dtype = 'float16'
-        self.weight_dtype = "int8"
-        self.in_features = 128
-        self.out_features = 288
-
-
-@unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
-class WeightOnlyLinearTestCase12(WeightOnlyLinearTestCase):
-    def config(self):
-        super().config()
-        self.dtype = 'float16'
-        self.bias = False
-        self.weight_dtype = "int8"
-        self.in_features = 128
-        self.out_features = 288
-
-
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
-class WeightOnlyLinearTestCase13(WeightOnlyLinearTestCase):
-    def config(self):
-        super().config()
-        self.dtype = 'bfloat16'
-        self.weight_dtype = "int8"
-        self.in_features = 128
-        self.out_features = 288
-
-
 if __name__ == '__main__':
     unittest.main()
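
For reviewers: the one functional change here is the launch-bound computation in `weight_quant_gpu`, which goes back from ceiling division to floor division. Below is a small standalone sketch of how the two formulas differ. It is not part of the patch: the `samples` values are hypothetical, `uint16_t` stands in for a 16-bit element type such as `float16`, and the comment about the tail assumes the kernel bounds-checks its column index, which is not visible in these hunks.

```cpp
// Standalone illustration of the grid-size change above; not Paddle code.
// For a 16-bit T, kVectorSize = 128 / sizeof(T) / 8 = 8, i.e. the number of
// T elements that fit in one 128-bit vector load.
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int kBlockSize = 64;
  constexpr int kVectorSize = 8;  // 128 / sizeof(uint16_t) / 8

  const int samples[] = {4096, 4160, 256};  // hypothetical total_n values
  for (int total_n : samples) {
    int vec_total_n = total_n / kVectorSize;
    // Restored form: floor division. Launches fewer threads than column
    // vectors whenever vec_total_n is larger than kBlockSize but not a
    // multiple of it (e.g. total_n = 4160 below).
    int grid_floor = std::max(vec_total_n / kBlockSize, 1);
    // Removed form: ceiling division. Always launches at least vec_total_n
    // threads, leaving any tail to an in-kernel bounds check.
    int grid_ceil = std::max((vec_total_n + kBlockSize - 1) / kBlockSize, 1);
    std::printf(
        "total_n=%4d vec_total_n=%3d floor: %d blocks (%4d threads), "
        "ceil: %d blocks (%4d threads)\n",
        total_n, vec_total_n, grid_floor, grid_floor * kBlockSize, grid_ceil,
        grid_ceil * kBlockSize);
  }
  return 0;
}
```

Note that both variants clamp the grid to at least one block, so small shapes (e.g. total_n = 256 above, where vec_total_n < kBlockSize) are fully covered either way; the formulas only diverge on large, non-multiple shapes.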