From 565770cb5dacc0fbf26e96c31588524cad7bddc1 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 6 Nov 2024 15:39:28 +0000 Subject: [PATCH] review comments Signed-off-by: Lucas Wilkinson --- .../vllm_numeric_conversion.cuh | 11 ++-- csrc/quantization/machete/generate.py | 52 ++++++++++--------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh index fa44753c00fcb..a7dc05b426fae 100644 --- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -31,11 +31,14 @@ struct InterleavedNumericArrayConverter { static result_type convert(source_type const& source) { if (cute::elect_one_sync()) { if constexpr (std::is_same_v) { - printf(" %s <= %s (N = %d, IlvBlkLayout = void)\n", nameof_v, - nameof_v, N); + printf( + "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n", + nameof_v, nameof_v, N); } else { - printf(" %s <= %s (N = %d, size(IlvBlkLayout{}) = %d)\n", nameof_v, - nameof_v, N, size(IlvBlkLayout{})); + printf( + "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not " + "implemented\n", + nameof_v, nameof_v, N, size(IlvBlkLayout{})); } __brkpt(); } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 1d8d03ece9da4..f9b33221d2a5c 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -105,7 +105,7 @@ TORCH_CHECK_NOT_IMPLEMENTED( false, "machete_mm(..) is not implemented for " "a_type=", args.A.scalar_type(), - ", b_type=", args.b_type.str(), + ", b_type=", args.btype.str(), ", out_type=", out_type, ", with_group_scale_type=", maybe_g_scales_type ? toString(*maybe_g_scales_type) : "None", @@ -525,40 +525,42 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): impl_configs = [] - GPTQ_kernel_types = list((TypeConfig( - a=a, - b=b, - b_group_scale=a, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.void, - b_token_scale=DataType.void, - out=a, - accumulator=DataType.f32, - ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) - for a in (DataType.f16, DataType.bf16))) + GPTQ_kernel_type_configs = list( + TypeConfig( + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.void, + b_token_scale=DataType.void, + out=a, + accumulator=DataType.f32, + ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16)) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(GPTQ_kernel_types, + for x in zip(GPTQ_kernel_type_configs, itertools.repeat(get_unique_schedules(default_heuristic)), itertools.repeat(default_heuristic)) ] - AWQ_kernel_types = list((TypeConfig( - a=a, - b=b, - b_group_scale=a, - b_group_zeropoint=a, - b_channel_scale=DataType.void, - b_token_scale=DataType.void, - out=a, - accumulator=DataType.f32, - ) for b in (DataType.u4, DataType.u8) - for a in (DataType.f16, DataType.bf16))) + AWQ_kernel_type_configs = list( + TypeConfig( + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=a, + b_channel_scale=DataType.void, + b_token_scale=DataType.void, + out=a, + accumulator=DataType.f32, + ) for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16)) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(AWQ_kernel_types, + for x in zip(AWQ_kernel_type_configs, itertools.repeat(get_unique_schedules(default_heuristic)), itertools.repeat(default_heuristic)) ]