From b2e983186394f7d3dc1d14cb669df038a996f6c9 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Fri, 29 Aug 2025 03:54:24 -0700 Subject: [PATCH 01/11] Add fastdiv, use it in modulo and use modulo in rms_norm_f32 Fastdiv is much faster way to do integer division, which was identified as bottleneck in rms_norm_f32 --- ggml/src/ggml-cuda/common.cuh | 27 +++ ggml/src/ggml-cuda/norm.cu | 309 +++++++++++++++++++++++++--------- 2 files changed, 254 insertions(+), 82 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 85bc9e933bca5..aa5e1f67ca0c4 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -563,6 +563,33 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) { + // compute L = ceil(log2(d)); + L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); +} + +static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) { + // Compute high 32 bits of n * mp + uint32_t hi = __umulhi(n, mp); + // Apply the formula + return (hi + n) >> L; +} + +static __device__ __forceinline__ uint32_t modulo(uint32_t n, uint32_t divisor, int mp, uint32_t L) { + return n - fastdiv(n, mp, L) * divisor; +} + typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v); static __device__ __forceinline__ float get_alibi_slope( diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index d5157d958b717..a1e2ecaa636af 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -105,29 +105,45 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } template -static __global__ void rms_norm_f32(const float * x, float * dst, - const int ncols, - const int64_t stride_row, - const int64_t stride_channel, - const int64_t stride_sample, - const float eps, - const float * mul = nullptr, - const int64_t mul_stride_row = 0, - const int64_t mul_stride_channel = 0, - const int64_t mul_stride_sample = 0, - const int mul_ncols = 0, - const int mul_nrows = 0, - const int mul_nchannels = 0, - const int mul_nsamples = 0, - const float * add = nullptr, - const int64_t add_stride_row = 0, - const int64_t add_stride_channel = 0, - const int64_t add_stride_sample = 0, - const int add_ncols = 0, - const int add_nrows = 0, - const int add_nchannels = 0, - const int add_nsamples = 0) { - +static __global__ void rms_norm_f32(const float * x, + float * dst, + const int ncols, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const float eps, + const float * mul = nullptr, + const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, + const int64_t mul_stride_sample = 0, + const uint32_t mul_ncols = 0, + const uint32_t mul_nrows = 0, + const uint32_t mul_nchannels = 0, + const uint32_t mul_nsamples = 0, + const uint32_t mp_mul_cols = 0, + const uint32_t L_mul_cols = 0, + const uint32_t mp_mul_rows = 0, + const uint32_t L_mul_rows = 0, + const uint32_t mp_mul_channels = 0, + const uint32_t L_mul_channels = 0, + const uint32_t mp_mul_samples = 0, + const uint32_t L_mul_samples = 0, + const float * add = nullptr, + const int64_t add_stride_row = 0, + const int64_t add_stride_channel = 0, + const int64_t add_stride_sample = 0, + const uint32_t add_ncols = 0, + const uint32_t add_nrows = 0, + const uint32_t add_nchannels = 0, + const uint32_t add_nsamples = 0, + const uint32_t mp_add_cols = 0, + const uint32_t L_add_cols = 0, + const uint32_t mp_add_rows = 0, + const uint32_t L_add_rows = 0, + const uint32_t mp_add_channels = 0, + const uint32_t L_add_channels = 0, + const uint32_t mp_add_samples = 0, + const uint32_t L_add_samples = 0) { const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -142,16 +158,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst, dst += ((sample*nchannels + channel)*nrows + row)*ncols; if constexpr (do_multiply) { - const int mul_row = row % mul_nrows; - const int mul_channel = channel % mul_nchannels; - const int mul_sample = sample % mul_nsamples; - mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row; + const uint32_t mul_row = modulo(row, mul_nrows, mp_mul_rows, L_mul_rows); + const uint32_t mul_channel = modulo(channel, mul_nchannels, mp_mul_channels, L_mul_channels); + const uint32_t mul_sample = modulo(sample, mul_nsamples, mp_mul_samples, L_mul_samples); + mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row; } if constexpr (do_add) { - const int add_row = row % add_nrows; - const int add_channel = channel % add_nchannels; - const int add_sample = sample % add_nsamples; + const int add_row = modulo(row, add_nrows, mp_add_rows, L_add_rows); + const int add_channel = modulo(channel, add_nchannels, mp_add_channels, L_add_channels); + const int add_sample = modulo(sample, add_nsamples, mp_add_samples, L_add_samples); add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row; } @@ -182,12 +198,12 @@ static __global__ void rms_norm_f32(const float * x, float * dst, for (int col = tid; col < ncols; col += block_size) { if constexpr (do_multiply && do_add) { - const int mul_col = col % mul_ncols; - const int add_col = col % add_ncols; - dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; + const int mul_col = modulo(col, mul_ncols, mp_mul_cols, L_mul_cols); + const int add_col = modulo(col, add_ncols, mp_add_cols, L_add_cols); + dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; } else if constexpr (do_multiply) { - const int mul_col = col % mul_ncols; - dst[col] = scale * x[col] * mul[mul_col]; + const int mul_col = modulo(col, mul_ncols, mp_mul_cols, L_mul_cols); + dst[col] = scale * x[col] * mul[mul_col]; } else { dst[col] = scale * x[col]; } @@ -362,69 +378,198 @@ static void rms_norm_f32_cuda( } } -static void rms_norm_mul_f32_cuda(const float * x, - const float * mul, - const float * add, - float * dst, - const int ncols, - const int nrows, - const int nchannels, - const int nsamples, - const int64_t stride_row, - const int64_t stride_channel, - const int64_t stride_sample, - const int64_t mul_stride_row, - const int64_t mul_stride_channel, - const int64_t mul_stride_sample, - const int mul_ncols, - const int mul_nrows, - const int mul_nchannels, - const int mul_nsamples, - const int64_t add_stride_row, - const int64_t add_stride_channel, - const int64_t add_stride_sample, - const int add_ncols, - const int add_nrows, - const int add_nchannels, - const int add_nsamples, - const float eps, - cudaStream_t stream) { +static void rms_norm_mul_f32_cuda(const float * x, + const float * mul, + const float * add, + float * dst, + const int ncols, + const int nrows, + const int nchannels, + const int nsamples, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const int64_t mul_stride_row, + const int64_t mul_stride_channel, + const int64_t mul_stride_sample, + const uint32_t mul_ncols, + const uint32_t mul_nrows, + const uint32_t mul_nchannels, + const uint32_t mul_nsamples, + const int64_t add_stride_row, + const int64_t add_stride_channel, + const int64_t add_stride_sample, + const uint32_t add_ncols, + const uint32_t add_nrows, + const uint32_t add_nchannels, + const uint32_t add_nsamples, + const float eps, + cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (mul == nullptr) { rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); return; } if (add == nullptr) { + uint32_t mp_mul_cols, L_mul_cols; + init_fastdiv_values(mul_ncols, mp_mul_cols, L_mul_cols); + uint32_t mp_mul_rows, L_mul_rows; + init_fastdiv_values(mul_nrows, mp_mul_rows, L_mul_rows); + uint32_t mp_mul_channels, L_mul_channels; + init_fastdiv_values(mul_nchannels, mp_mul_channels, L_mul_channels); + uint32_t mp_mul_samples, L_mul_samples; + init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, - ncols, stride_row, stride_channel, stride_sample, eps, - mul, mul_stride_row, mul_stride_channel, mul_stride_sample, - mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + rms_norm_f32<<>>(x, + dst, + ncols, + stride_row, + stride_channel, + stride_sample, + eps, + mul, + mul_stride_row, + mul_stride_channel, + mul_stride_sample, + mul_ncols, + mul_nrows, + mul_nchannels, + mul_nsamples, + mp_mul_cols, + L_mul_cols, + mp_mul_rows, + L_mul_rows, + mp_mul_channels, + L_mul_channels, + mp_mul_samples, + L_mul_samples); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<>>(x, dst, - ncols, stride_row, stride_channel, stride_sample, eps, - mul, mul_stride_row, mul_stride_channel, mul_stride_sample, - mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + rms_norm_f32<1024, true><<>>(x, + dst, + ncols, + stride_row, + stride_channel, + stride_sample, + eps, + mul, + mul_stride_row, + mul_stride_channel, + mul_stride_sample, + mul_ncols, + mul_nrows, + mul_nchannels, + mul_nsamples, + mp_mul_cols, + L_mul_cols, + mp_mul_rows, + L_mul_rows, + mp_mul_channels, + L_mul_channels, + mp_mul_samples, + L_mul_samples); } } else { + uint32_t mp_mul_cols, L_mul_cols; + init_fastdiv_values(mul_ncols, mp_mul_cols, L_mul_cols); + uint32_t mp_mul_rows, L_mul_rows; + init_fastdiv_values(mul_nrows, mp_mul_rows, L_mul_rows); + uint32_t mp_mul_channels, L_mul_channels; + init_fastdiv_values(mul_nchannels, mp_mul_channels, L_mul_channels); + uint32_t mp_mul_samples, L_mul_samples; + init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples); + + uint32_t mp_add_cols, L_add_cols; + init_fastdiv_values(add_ncols, mp_add_cols, L_add_cols); + uint32_t mp_add_rows, L_add_rows; + init_fastdiv_values(add_nrows, mp_add_rows, L_add_rows); + uint32_t mp_add_channels, L_add_channels; + init_fastdiv_values(add_nchannels, mp_add_channels, L_add_channels); + uint32_t mp_add_samples, L_add_samples; + init_fastdiv_values(add_nsamples, mp_add_samples, L_add_samples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, - ncols, stride_row, stride_channel, stride_sample, eps, - mul, mul_stride_row, mul_stride_channel, mul_stride_sample, - mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, - add, add_stride_row, add_stride_channel, add_stride_sample, - add_ncols, add_nrows, add_nchannels, add_nsamples); + rms_norm_f32<<>>(x, + dst, + ncols, + stride_row, + stride_channel, + stride_sample, + eps, + mul, + mul_stride_row, + mul_stride_channel, + mul_stride_sample, + mul_ncols, + mul_nrows, + mul_nchannels, + mul_nsamples, + mp_mul_cols, + L_mul_cols, + mp_mul_rows, + L_mul_rows, + mp_mul_channels, + L_mul_channels, + mp_mul_samples, + L_mul_samples, + add, + add_stride_row, + add_stride_channel, + add_stride_sample, + add_ncols, + add_nrows, + add_nchannels, + add_nsamples, + mp_add_cols, + L_add_cols, + mp_add_rows, + L_add_rows, + mp_add_channels, + L_add_channels, + mp_add_samples, + L_add_samples); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><<>>(x, dst, - ncols, stride_row, stride_channel, stride_sample, eps, - mul, mul_stride_row, mul_stride_channel, mul_stride_sample, - mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, - add, add_stride_row, add_stride_channel, add_stride_sample, - add_ncols, add_nrows, add_nchannels, add_nsamples); + rms_norm_f32<1024, true, true><<>>(x, + dst, + ncols, + stride_row, + stride_channel, + stride_sample, + eps, + mul, + mul_stride_row, + mul_stride_channel, + mul_stride_sample, + mul_ncols, + mul_nrows, + mul_nchannels, + mul_nsamples, + mp_mul_cols, + L_mul_cols, + mp_mul_rows, + L_mul_rows, + mp_mul_channels, + L_mul_channels, + mp_mul_samples, + L_mul_samples, + add, + add_stride_row, + add_stride_channel, + add_stride_sample, + add_ncols, + add_nrows, + add_nchannels, + add_nsamples, + mp_add_cols, + L_add_cols, + mp_add_rows, + L_add_rows, + mp_add_channels, + L_add_channels, + mp_add_samples, + L_add_samples); } } } From bcc6c777ce4e3477084789bd656a0f52ae02aa75 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Fri, 29 Aug 2025 03:55:55 -0700 Subject: [PATCH 02/11] Support more `block_size` values in `rms_norm_f32` This makes us more flexible in selecting the optimal threads w.r.t paralellizing across a col vs. launch-overheads of threads and mio throttles --- ggml/src/ggml-cuda/norm.cu | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index a1e2ecaa636af..2243affbd996c 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -181,15 +181,18 @@ static __global__ void rms_norm_f32(const float * x, // sum up partial sums tmp = warp_reduce_sum(tmp); if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); + static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size"); __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } __syncthreads(); - tmp = s_sum[lane_id]; + tmp = 0.0f; + if (lane_id < (block_size / WARP_SIZE)) { + tmp = s_sum[lane_id]; + } tmp = warp_reduce_sum(tmp); } @@ -370,8 +373,8 @@ static void rms_norm_f32_cuda( const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); rms_norm_f32<1024, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); @@ -420,8 +423,8 @@ static void rms_norm_mul_f32_cuda(const float * x, uint32_t mp_mul_samples, L_mul_samples; init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples); if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true><<>>(x, dst, ncols, stride_row, @@ -489,8 +492,8 @@ static void rms_norm_mul_f32_cuda(const float * x, uint32_t mp_add_samples, L_add_samples; init_fastdiv_values(add_nsamples, mp_add_samples, L_add_samples); if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true, true><<>>(x, dst, ncols, stride_row, From 30ab9ae4670ca3c526ecba5824361b74d1a8983d Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Tue, 2 Sep 2025 15:50:38 +0200 Subject: [PATCH 03/11] Update ggml/src/ggml-cuda/common.cuh MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index aa5e1f67ca0c4..b21f408298df2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -586,7 +586,7 @@ static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, uint32_t mp, uint return (hi + n) >> L; } -static __device__ __forceinline__ uint32_t modulo(uint32_t n, uint32_t divisor, int mp, uint32_t L) { +static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, uint32_t divisor, int mp, uint32_t L) { return n - fastdiv(n, mp, L) * divisor; } From 18242c3d723d9bceabb43207ef877c4e9e8b2115 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Tue, 2 Sep 2025 06:53:05 -0700 Subject: [PATCH 04/11] Replace modulo with fastmodulo in `rms_norm_f32` --- ggml/src/ggml-cuda/norm.cu | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 2243affbd996c..9ed4b60e3ff5c 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -158,16 +158,16 @@ static __global__ void rms_norm_f32(const float * x, dst += ((sample*nchannels + channel)*nrows + row)*ncols; if constexpr (do_multiply) { - const uint32_t mul_row = modulo(row, mul_nrows, mp_mul_rows, L_mul_rows); - const uint32_t mul_channel = modulo(channel, mul_nchannels, mp_mul_channels, L_mul_channels); - const uint32_t mul_sample = modulo(sample, mul_nsamples, mp_mul_samples, L_mul_samples); + const uint32_t mul_row = fastmodulo(row, mul_nrows, mp_mul_rows, L_mul_rows); + const uint32_t mul_channel = fastmodulo(channel, mul_nchannels, mp_mul_channels, L_mul_channels); + const uint32_t mul_sample = fastmodulo(sample, mul_nsamples, mp_mul_samples, L_mul_samples); mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row; } if constexpr (do_add) { - const int add_row = modulo(row, add_nrows, mp_add_rows, L_add_rows); - const int add_channel = modulo(channel, add_nchannels, mp_add_channels, L_add_channels); - const int add_sample = modulo(sample, add_nsamples, mp_add_samples, L_add_samples); + const int add_row = fastmodulo(row, add_nrows, mp_add_rows, L_add_rows); + const int add_channel = fastmodulo(channel, add_nchannels, mp_add_channels, L_add_channels); + const int add_sample = fastmodulo(sample, add_nsamples, mp_add_samples, L_add_samples); add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row; } @@ -201,11 +201,11 @@ static __global__ void rms_norm_f32(const float * x, for (int col = tid; col < ncols; col += block_size) { if constexpr (do_multiply && do_add) { - const int mul_col = modulo(col, mul_ncols, mp_mul_cols, L_mul_cols); - const int add_col = modulo(col, add_ncols, mp_add_cols, L_add_cols); + const int mul_col = fastmodulo(col, mul_ncols, mp_mul_cols, L_mul_cols); + const int add_col = fastmodulo(col, add_ncols, mp_add_cols, L_add_cols); dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; } else if constexpr (do_multiply) { - const int mul_col = modulo(col, mul_ncols, mp_mul_cols, L_mul_cols); + const int mul_col = fastmodulo(col, mul_ncols, mp_mul_cols, L_mul_cols); dst[col] = scale * x[col] * mul[mul_col]; } else { dst[col] = scale * x[col]; From 0129866abecd5d7f83839ff7c404370386db7709 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Tue, 2 Sep 2025 06:49:03 -0700 Subject: [PATCH 05/11] Use `BinPackArguments=true` for formating function calls Will file a separate PR to adjust .clang-format file --- ggml/src/ggml-cuda/norm.cu | 146 ++++++------------------------------- 1 file changed, 22 insertions(+), 124 deletions(-) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 9ed4b60e3ff5c..9d0168bee464b 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -424,54 +424,16 @@ static void rms_norm_mul_f32_cuda(const float * x, init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true><<>>(x, - dst, - ncols, - stride_row, - stride_channel, - stride_sample, - eps, - mul, - mul_stride_row, - mul_stride_channel, - mul_stride_sample, - mul_ncols, - mul_nrows, - mul_nchannels, - mul_nsamples, - mp_mul_cols, - L_mul_cols, - mp_mul_rows, - L_mul_rows, - mp_mul_channels, - L_mul_channels, - mp_mul_samples, - L_mul_samples); + rms_norm_f32<256, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols, + mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<>>(x, - dst, - ncols, - stride_row, - stride_channel, - stride_sample, - eps, - mul, - mul_stride_row, - mul_stride_channel, - mul_stride_sample, - mul_ncols, - mul_nrows, - mul_nchannels, - mul_nsamples, - mp_mul_cols, - L_mul_cols, - mp_mul_rows, - L_mul_rows, - mp_mul_channels, - L_mul_channels, - mp_mul_samples, - L_mul_samples); + rms_norm_f32<1024, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols, + mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples); } } else { uint32_t mp_mul_cols, L_mul_cols; @@ -493,86 +455,22 @@ static void rms_norm_mul_f32_cuda(const float * x, init_fastdiv_values(add_nsamples, mp_add_samples, L_add_samples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true, true><<>>(x, - dst, - ncols, - stride_row, - stride_channel, - stride_sample, - eps, - mul, - mul_stride_row, - mul_stride_channel, - mul_stride_sample, - mul_ncols, - mul_nrows, - mul_nchannels, - mul_nsamples, - mp_mul_cols, - L_mul_cols, - mp_mul_rows, - L_mul_rows, - mp_mul_channels, - L_mul_channels, - mp_mul_samples, - L_mul_samples, - add, - add_stride_row, - add_stride_channel, - add_stride_sample, - add_ncols, - add_nrows, - add_nchannels, - add_nsamples, - mp_add_cols, - L_add_cols, - mp_add_rows, - L_add_rows, - mp_add_channels, - L_add_channels, - mp_add_samples, - L_add_samples); + rms_norm_f32<256, true, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols, + mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels, + add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels, + mp_add_samples, L_add_samples); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><<>>(x, - dst, - ncols, - stride_row, - stride_channel, - stride_sample, - eps, - mul, - mul_stride_row, - mul_stride_channel, - mul_stride_sample, - mul_ncols, - mul_nrows, - mul_nchannels, - mul_nsamples, - mp_mul_cols, - L_mul_cols, - mp_mul_rows, - L_mul_rows, - mp_mul_channels, - L_mul_channels, - mp_mul_samples, - L_mul_samples, - add, - add_stride_row, - add_stride_channel, - add_stride_sample, - add_ncols, - add_nrows, - add_nchannels, - add_nsamples, - mp_add_cols, - L_add_cols, - mp_add_rows, - L_add_rows, - mp_add_channels, - L_add_channels, - mp_add_samples, - L_add_samples); + rms_norm_f32<1024, true, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols, + mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels, + add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels, + mp_add_samples, L_add_samples); } } } From 48afab4ba2ab9a319905e04c95129e6b6d6710a6 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 3 Sep 2025 10:43:29 +0200 Subject: [PATCH 06/11] Update ggml/src/ggml-cuda/common.cuh MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index b21f408298df2..f5c1fe1bdfef7 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -581,7 +581,7 @@ static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) { static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) { // Compute high 32 bits of n * mp - uint32_t hi = __umulhi(n, mp); + const uint32_t hi = __umulhi(n, mp); // Apply the formula return (hi + n) >> L; } From 741465255859830b77d43b45c79c6b059d98d5a9 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 3 Sep 2025 11:51:42 +0200 Subject: [PATCH 07/11] Pack fastdiv/fastmodulo constants into uint2/uint3 objects By packing constants to be used together into a struct, we are less likely to make errors. --- ggml/src/ggml-cuda/common.cuh | 24 ++++-- ggml/src/ggml-cuda/norm.cu | 140 +++++++++++++--------------------- 2 files changed, 69 insertions(+), 95 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index f5c1fe1bdfef7..ffc9aa4bf2422 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -569,25 +569,33 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { // and a shift: // // n/d = (mulhi(n, mp) + n) >> L; -static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) { +static const uint2 init_fastdiv_values(uint32_t d) { // compute L = ceil(log2(d)); - L = 0; + uint32_t L = 0; while (L < 32 && (uint32_t{ 1 } << L) < d) { L++; } - mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + return make_uint2(mp, L); } -static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) { +static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint2 div_consts) { // Compute high 32 bits of n * mp - const uint32_t hi = __umulhi(n, mp); + const uint32_t hi = __umulhi(n, div_consts.x); // Apply the formula - return (hi + n) >> L; + return (hi + n) >> div_consts.y; } -static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, uint32_t divisor, int mp, uint32_t L) { - return n - fastdiv(n, mp, L) * divisor; +static const uint3 init_fastmodulo_values(uint32_t d) { + // uint3 contains in + const uint2 fastdiv = init_fastdiv_values(d); + return make_uint3(fastdiv.x, fastdiv.y, d); +} + +static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 div_consts_divisor) { + // expects div_consts_divisor to contain in + return n - fastdiv(n, make_uint2(div_consts_divisor.x, div_consts_divisor.y)) * div_consts_divisor.z; } typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v); diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 9d0168bee464b..a8ebcf0e2a029 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -105,45 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } template -static __global__ void rms_norm_f32(const float * x, - float * dst, - const int ncols, - const int64_t stride_row, - const int64_t stride_channel, - const int64_t stride_sample, - const float eps, - const float * mul = nullptr, - const int64_t mul_stride_row = 0, - const int64_t mul_stride_channel = 0, - const int64_t mul_stride_sample = 0, - const uint32_t mul_ncols = 0, - const uint32_t mul_nrows = 0, - const uint32_t mul_nchannels = 0, - const uint32_t mul_nsamples = 0, - const uint32_t mp_mul_cols = 0, - const uint32_t L_mul_cols = 0, - const uint32_t mp_mul_rows = 0, - const uint32_t L_mul_rows = 0, - const uint32_t mp_mul_channels = 0, - const uint32_t L_mul_channels = 0, - const uint32_t mp_mul_samples = 0, - const uint32_t L_mul_samples = 0, - const float * add = nullptr, - const int64_t add_stride_row = 0, - const int64_t add_stride_channel = 0, - const int64_t add_stride_sample = 0, - const uint32_t add_ncols = 0, - const uint32_t add_nrows = 0, - const uint32_t add_nchannels = 0, - const uint32_t add_nsamples = 0, - const uint32_t mp_add_cols = 0, - const uint32_t L_add_cols = 0, - const uint32_t mp_add_rows = 0, - const uint32_t L_add_rows = 0, - const uint32_t mp_add_channels = 0, - const uint32_t L_add_channels = 0, - const uint32_t mp_add_samples = 0, - const uint32_t L_add_samples = 0) { +static __global__ void rms_norm_f32(const float * x, + float * dst, + const int ncols, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const float eps, + const float * mul = nullptr, + const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, + const int64_t mul_stride_sample = 0, + const uint3 mul_ncols_packed = make_uint3(0, 0, 0), + const uint3 mul_nrows_packed = make_uint3(0, 0, 0), + const uint3 mul_nchannels_packed = make_uint3(0, 0, 0), + const uint3 mul_nsamples_packed = make_uint3(0, 0, 0), + const float * add = nullptr, + const int64_t add_stride_row = 0, + const int64_t add_stride_channel = 0, + const int64_t add_stride_sample = 0, + const uint3 add_ncols_packed = make_uint3(0, 0, 0), + const uint3 add_nrows_packed = make_uint3(0, 0, 0), + const uint3 add_nchannels_packed = make_uint3(0, 0, 0), + const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -158,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, dst += ((sample*nchannels + channel)*nrows + row)*ncols; if constexpr (do_multiply) { - const uint32_t mul_row = fastmodulo(row, mul_nrows, mp_mul_rows, L_mul_rows); - const uint32_t mul_channel = fastmodulo(channel, mul_nchannels, mp_mul_channels, L_mul_channels); - const uint32_t mul_sample = fastmodulo(sample, mul_nsamples, mp_mul_samples, L_mul_samples); + const uint32_t mul_row = fastmodulo(row, mul_nrows_packed); + const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed); + const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed); mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row; } if constexpr (do_add) { - const int add_row = fastmodulo(row, add_nrows, mp_add_rows, L_add_rows); - const int add_channel = fastmodulo(channel, add_nchannels, mp_add_channels, L_add_channels); - const int add_sample = fastmodulo(sample, add_nsamples, mp_add_samples, L_add_samples); + const int add_row = fastmodulo(row, add_nrows_packed); + const int add_channel = fastmodulo(channel, add_nchannels_packed); + const int add_sample = fastmodulo(sample, add_nsamples_packed); add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row; } @@ -201,11 +185,11 @@ static __global__ void rms_norm_f32(const float * x, for (int col = tid; col < ncols; col += block_size) { if constexpr (do_multiply && do_add) { - const int mul_col = fastmodulo(col, mul_ncols, mp_mul_cols, L_mul_cols); - const int add_col = fastmodulo(col, add_ncols, mp_add_cols, L_add_cols); + const int mul_col = fastmodulo(col, mul_ncols_packed); + const int add_col = fastmodulo(col, add_ncols_packed); dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; } else if constexpr (do_multiply) { - const int mul_col = fastmodulo(col, mul_ncols, mp_mul_cols, L_mul_cols); + const int mul_col = fastmodulo(col, mul_ncols_packed); dst[col] = scale * x[col] * mul[mul_col]; } else { dst[col] = scale * x[col]; @@ -414,63 +398,45 @@ static void rms_norm_mul_f32_cuda(const float * x, return; } if (add == nullptr) { - uint32_t mp_mul_cols, L_mul_cols; - init_fastdiv_values(mul_ncols, mp_mul_cols, L_mul_cols); - uint32_t mp_mul_rows, L_mul_rows; - init_fastdiv_values(mul_nrows, mp_mul_rows, L_mul_rows); - uint32_t mp_mul_channels, L_mul_channels; - init_fastdiv_values(mul_nchannels, mp_mul_channels, L_mul_channels); - uint32_t mp_mul_samples, L_mul_samples; - init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples); + uint3 mul_ncols_packed = init_fastmodulo_values(mul_ncols); + uint3 mul_nrows_packed = init_fastmodulo_values(mul_nrows); + uint3 mul_nchannels_packed = init_fastmodulo_values(mul_nchannels); + uint3 mul_nsamples_packed = init_fastmodulo_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); rms_norm_f32<256, true><<>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols, - mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); rms_norm_f32<1024, true><<>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols, - mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } } else { - uint32_t mp_mul_cols, L_mul_cols; - init_fastdiv_values(mul_ncols, mp_mul_cols, L_mul_cols); - uint32_t mp_mul_rows, L_mul_rows; - init_fastdiv_values(mul_nrows, mp_mul_rows, L_mul_rows); - uint32_t mp_mul_channels, L_mul_channels; - init_fastdiv_values(mul_nchannels, mp_mul_channels, L_mul_channels); - uint32_t mp_mul_samples, L_mul_samples; - init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples); - - uint32_t mp_add_cols, L_add_cols; - init_fastdiv_values(add_ncols, mp_add_cols, L_add_cols); - uint32_t mp_add_rows, L_add_rows; - init_fastdiv_values(add_nrows, mp_add_rows, L_add_rows); - uint32_t mp_add_channels, L_add_channels; - init_fastdiv_values(add_nchannels, mp_add_channels, L_add_channels); - uint32_t mp_add_samples, L_add_samples; - init_fastdiv_values(add_nsamples, mp_add_samples, L_add_samples); + uint3 mul_ncols_packed = init_fastmodulo_values(mul_ncols); + uint3 mul_nrows_packed = init_fastmodulo_values(mul_nrows); + uint3 mul_nchannels_packed = init_fastmodulo_values(mul_nchannels); + uint3 mul_nsamples_packed = init_fastmodulo_values(mul_nsamples); + + uint3 add_ncols_packed = init_fastmodulo_values(add_ncols); + uint3 add_nrows_packed = init_fastmodulo_values(add_nrows); + uint3 add_nchannels_packed = init_fastmodulo_values(add_nchannels); + uint3 add_nsamples_packed = init_fastmodulo_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); rms_norm_f32<256, true, true><<>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols, - mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add, - add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels, - add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels, - mp_add_samples, L_add_samples); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); rms_norm_f32<1024, true, true><<>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols, - mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add, - add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels, - add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels, - mp_add_samples, L_add_samples); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); } } } From 0a76b118fe177e54976085d75441fe8f558973e9 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 3 Sep 2025 15:22:50 +0200 Subject: [PATCH 08/11] Rename function parameter of fastmodulo `modulo_consts` is more fitting/descriptive --- ggml/src/ggml-cuda/common.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ffc9aa4bf2422..10557df9d81a7 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -593,9 +593,9 @@ static const uint3 init_fastmodulo_values(uint32_t d) { return make_uint3(fastdiv.x, fastdiv.y, d); } -static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 div_consts_divisor) { - // expects div_consts_divisor to contain in - return n - fastdiv(n, make_uint2(div_consts_divisor.x, div_consts_divisor.y)) * div_consts_divisor.z; +static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 modulo_consts) { + // expects modulo_consts to contain in (see init_fastmodulo_values function) + return n - fastdiv(n, make_uint2(modulo_consts.x, modulo_consts.y)) * modulo_consts.z; } typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v); From 8b1e9370c1ad83594b442455a921b51e89baca30 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 3 Sep 2025 16:56:06 +0200 Subject: [PATCH 09/11] Use uint3 for both `fastdiv` and `fastmodulo` The compiler seems to reliably optimize away the unused .z component in the fastdiv use-case, see https://godbolt.org/z/rx8KPrKr3 --- ggml/src/ggml-cuda/common.cuh | 21 +++++++++------------ ggml/src/ggml-cuda/norm.cu | 26 +++++++++++++------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 10557df9d81a7..760d48e146d43 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -569,7 +569,7 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { // and a shift: // // n/d = (mulhi(n, mp) + n) >> L; -static const uint2 init_fastdiv_values(uint32_t d) { +static const uint3 init_fastdiv_values(uint32_t d) { // compute L = ceil(log2(d)); uint32_t L = 0; while (L < 32 && (uint32_t{ 1 } << L) < d) { @@ -577,25 +577,22 @@ static const uint2 init_fastdiv_values(uint32_t d) { } uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); - return make_uint2(mp, L); + // pack divisor as well to reduce error surface + return make_uint3(mp, L, d); } -static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint2 div_consts) { +static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 div_consts) { + // expects div_consts to contain in + // div_consts.z is unused and optimized away by the compiler. // Compute high 32 bits of n * mp const uint32_t hi = __umulhi(n, div_consts.x); - // Apply the formula + // add n, apply bit shift return (hi + n) >> div_consts.y; } -static const uint3 init_fastmodulo_values(uint32_t d) { - // uint3 contains in - const uint2 fastdiv = init_fastdiv_values(d); - return make_uint3(fastdiv.x, fastdiv.y, d); -} - static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 modulo_consts) { - // expects modulo_consts to contain in (see init_fastmodulo_values function) - return n - fastdiv(n, make_uint2(modulo_consts.x, modulo_consts.y)) * modulo_consts.z; + // expects modulo_consts to contain in (see init_fastdiv_values) + return n - fastdiv(n, modulo_consts) * modulo_consts.z; } typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v); diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index a8ebcf0e2a029..ec63cbed708a0 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -398,10 +398,10 @@ static void rms_norm_mul_f32_cuda(const float * x, return; } if (add == nullptr) { - uint3 mul_ncols_packed = init_fastmodulo_values(mul_ncols); - uint3 mul_nrows_packed = init_fastmodulo_values(mul_nrows); - uint3 mul_nchannels_packed = init_fastmodulo_values(mul_nchannels); - uint3 mul_nsamples_packed = init_fastmodulo_values(mul_nsamples); + uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); rms_norm_f32<256, true><<>>( @@ -414,15 +414,15 @@ static void rms_norm_mul_f32_cuda(const float * x, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } } else { - uint3 mul_ncols_packed = init_fastmodulo_values(mul_ncols); - uint3 mul_nrows_packed = init_fastmodulo_values(mul_nrows); - uint3 mul_nchannels_packed = init_fastmodulo_values(mul_nchannels); - uint3 mul_nsamples_packed = init_fastmodulo_values(mul_nsamples); - - uint3 add_ncols_packed = init_fastmodulo_values(add_ncols); - uint3 add_nrows_packed = init_fastmodulo_values(add_nrows); - uint3 add_nchannels_packed = init_fastmodulo_values(add_nchannels); - uint3 add_nsamples_packed = init_fastmodulo_values(add_nsamples); + uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + + uint3 add_ncols_packed = init_fastdiv_values(add_ncols); + uint3 add_nrows_packed = init_fastdiv_values(add_nrows); + uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels); + uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); rms_norm_f32<256, true, true><<>>( From f0dabf2902d44c8761d02810b34e4c099f345977 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 3 Sep 2025 17:35:05 +0200 Subject: [PATCH 10/11] More constrained type declarations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/norm.cu | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index ec63cbed708a0..4f153c5718ead 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -398,10 +398,10 @@ static void rms_norm_mul_f32_cuda(const float * x, return; } if (add == nullptr) { - uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); - uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); - uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); - uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); rms_norm_f32<256, true><<>>( @@ -414,15 +414,15 @@ static void rms_norm_mul_f32_cuda(const float * x, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } } else { - uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); - uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); - uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); - uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); - - uint3 add_ncols_packed = init_fastdiv_values(add_ncols); - uint3 add_nrows_packed = init_fastdiv_values(add_nrows); - uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels); - uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + + const uint3 add_ncols_packed = init_fastdiv_values(add_ncols); + const uint3 add_nrows_packed = init_fastdiv_values(add_nrows); + const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels); + const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); rms_norm_f32<256, true, true><<>>( From 8bde72b515c000cec726aae3741266ea31602c8d Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 3 Sep 2025 17:38:26 +0200 Subject: [PATCH 11/11] Rename fastdiv and fastmodulo variables to shared variable name As suggest by @JohannesGaessler, this increases clarity of the intended use --- ggml/src/ggml-cuda/common.cuh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 760d48e146d43..a2dc26eab7e4c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -581,18 +581,18 @@ static const uint3 init_fastdiv_values(uint32_t d) { return make_uint3(mp, L, d); } -static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 div_consts) { - // expects div_consts to contain in - // div_consts.z is unused and optimized away by the compiler. +static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in + // fastdiv_values.z is unused and optimized away by the compiler. // Compute high 32 bits of n * mp - const uint32_t hi = __umulhi(n, div_consts.x); + const uint32_t hi = __umulhi(n, fastdiv_values.x); // add n, apply bit shift - return (hi + n) >> div_consts.y; + return (hi + n) >> fastdiv_values.y; } -static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 modulo_consts) { - // expects modulo_consts to contain in (see init_fastdiv_values) - return n - fastdiv(n, modulo_consts) * modulo_consts.z; +static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in (see init_fastdiv_values) + return n - fastdiv(n, fastdiv_values) * fastdiv_values.z; } typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);