From 5d6b098ad8f060a9b0fbf549b598ae714e16e382 Mon Sep 17 00:00:00 2001 From: Yuval Shekel Date: Sun, 19 May 2024 14:47:43 +0000 Subject: [PATCH] fix: ntt mixed-radix bug regarding large ntts (and/or batch) in some cases 32b values would wrap around and cause invalid accesses to wrong elements and memory addresses --- icicle/src/ntt/kernel_ntt.cu | 66 ++++++++++++++++++++---------------- icicle/src/ntt/thread_ntt.cu | 26 +++++++------- 2 files changed, 49 insertions(+), 43 deletions(-) diff --git a/icicle/src/ntt/kernel_ntt.cu b/icicle/src/ntt/kernel_ntt.cu index 3166b334cf..8dc23a00e1 100644 --- a/icicle/src/ntt/kernel_ntt.cu +++ b/icicle/src/ntt/kernel_ntt.cu @@ -73,14 +73,14 @@ namespace mxntt { // if its index is the smallest number in the group -> do the memory transformation // else --> do nothing - const uint32_t size = 1 << log_size; - const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; - const uint32_t idx = columns_batch ? tid / batch_size : tid % size; - const uint32_t batch_idx = columns_batch ? tid % batch_size : tid / size; - if (tid >= size * batch_size) return; - - uint32_t next_element = idx; - uint32_t group[MAX_GROUP_SIZE]; + const uint64_t size = 1UL << log_size; + const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x; + const uint64_t idx = columns_batch ? tid / batch_size : tid % size; + const uint64_t batch_idx = columns_batch ? tid % batch_size : tid / size; + if (tid >= uint64_t(size) * batch_size) return; + + uint64_t next_element = idx; + uint64_t group[MAX_GROUP_SIZE]; group[0] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx; uint32_t i = 1; @@ -114,11 +114,13 @@ namespace mxntt { bool is_normalize, S inverse_N) { - uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid >= (1 << log_size) * batch_size) return; - uint32_t rd = tid; - uint32_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) + - generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type); + const uint64_t size = 1UL << log_size; + const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x; + if (tid >= uint64_t(size) * batch_size) return; + + uint64_t rd = tid; + uint64_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) + + generalized_rev((tid / columns_batch_size) & (size - 1), log_size, dit, fast_tw, rev_type); arr_reordered[wr * columns_batch_size + (tid % columns_batch_size)] = is_normalize ? arr[rd] * inverse_N : arr[rd]; } @@ -131,14 +133,14 @@ namespace mxntt { uint32_t columns_batch_size, S* scalar_vec, int step, - int n_scalars, + uint32_t n_scalars, uint32_t log_size, eRevType rev_type, bool fast_tw, E* out_vec) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid >= size * batch_size) return; + uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x; + if (tid >= uint64_t(size) * batch_size) return; int64_t scalar_id = (tid / columns_batch_size) % size; if (rev_type != eRevType::None) { // Note: when we multiply an in_vec that is mixed (by DIF (I)NTT), we want to shuffle the @@ -148,8 +150,7 @@ namespace mxntt { // Therefore we use the DIF-digit-reverse to know which element moved to index tid and use it to access the // corresponding element in scalars vec. 
const bool dif = rev_type == eRevType::NaturalToMixedRev; - scalar_id = - generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, !dif, fast_tw, rev_type); + scalar_id = generalized_rev((tid / columns_batch_size) & (size - 1), log_size, !dif, fast_tw, rev_type); } out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid]; } @@ -523,9 +524,9 @@ namespace mxntt { } template - __global__ void normalize_kernel(E* data, S norm_factor, uint32_t size) + __global__ void normalize_kernel(E* data, S norm_factor, uint64_t size) { - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + uint64_t tid = uint64_t(blockIdx.x) * blockDim.x + threadIdx.x; if (tid >= size) return; data[tid] = data[tid] * norm_factor; } @@ -786,7 +787,7 @@ namespace mxntt { columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw); } if (normalize) - normalize_kernel<<>>(out, S::inv_log_size(4), (1 << log_size) * batch_size); + normalize_kernel<<>>(out, S::inv_log_size(4), (1UL << log_size) * batch_size); return CHK_LAST(); } @@ -804,7 +805,7 @@ namespace mxntt { columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw); } if (normalize) - normalize_kernel<<>>(out, S::inv_log_size(5), (1 << log_size) * batch_size); + normalize_kernel<<>>(out, S::inv_log_size(5), (1UL << log_size) * batch_size); return CHK_LAST(); } @@ -816,7 +817,7 @@ namespace mxntt { in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw); if (normalize) - normalize_kernel<<>>(out, S::inv_log_size(6), (1 << log_size) * batch_size); + normalize_kernel<<>>(out, S::inv_log_size(6), (1UL << log_size) * batch_size); return CHK_LAST(); } @@ -844,12 +845,12 @@ namespace mxntt { columns_batch, 0, inv, dit, fast_tw); } if (normalize) - normalize_kernel<<>>(out, S::inv_log_size(8), (1 << log_size) * batch_size); + normalize_kernel<<>>(out, S::inv_log_size(8), (1UL << log_size) * batch_size); return CHK_LAST(); } // general case: - uint32_t nof_blocks = (1 << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size); + uint32_t nof_blocks = (1UL << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size); if (dit) { for (int i = 0; i < 5; i++) { uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i]; @@ -900,7 +901,7 @@ namespace mxntt { } if (normalize) normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>( - out, S::inv_log_size(log_size), (1 << log_size) * batch_size); + out, S::inv_log_size(log_size), (1UL << log_size) * batch_size); return CHK_LAST(); } @@ -926,13 +927,18 @@ namespace mxntt { { CHK_INIT_IF_RETURN(); - const int logn = int(log2(ntt_size)); - const int NOF_BLOCKS = ((1 << logn) * batch_size + 64 - 1) / 64; - const int NOF_THREADS = min(64, (1 << logn) * batch_size); + const uint64_t logn = uint64_t(log2(ntt_size)); + const uint64_t NOF_BLOCKS_64b = (uint64_t(ntt_size) * batch_size + 64 - 1) / 64; + const uint32_t NOF_THREADS = min(64UL, uint64_t(ntt_size) * batch_size); + // CUDA grid is 32b fields. Assert that I don't need a larger grid. + const uint32_t NOF_BLOCKS = NOF_BLOCKS_64b; + if (NOF_BLOCKS != NOF_BLOCKS_64b) { + THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "NTT dimensions (ntt_size, batch) are too large. 
Unsupported!"); + } bool is_normalize = is_inverse; const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset; - const int n_twiddles = 1 << max_logn; + const uint32_t n_twiddles = 1U << max_logn; // Note: for evaluation on coset, need to reorder the coset too to match the data for element-wise multiplication eRevType reverse_input = None, reverse_output = None, reverse_coset = None; bool dit = false; diff --git a/icicle/src/ntt/thread_ntt.cu b/icicle/src/ntt/thread_ntt.cu index 8b3e4a56d2..6882d2f654 100644 --- a/icicle/src/ntt/thread_ntt.cu +++ b/icicle/src/ntt/thread_ntt.cu @@ -10,8 +10,8 @@ struct stage_metadata { uint32_t th_stride; uint32_t ntt_block_size; uint32_t batch_id; - uint32_t ntt_block_id; uint32_t ntt_inp_id; + uint64_t ntt_block_id; }; #define STAGE_SIZES_DATA \ @@ -194,7 +194,7 @@ public: } DEVICE_INLINE void - loadGlobalData(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) + loadGlobalData(const E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id + @@ -210,7 +210,7 @@ public: } DEVICE_INLINE void loadGlobalDataColumnBatch( - const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + const E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) { data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * @@ -224,7 +224,7 @@ public: } DEVICE_INLINE void - storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) + storeGlobalData(E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id + @@ -240,7 +240,7 @@ public: } DEVICE_INLINE void storeGlobalDataColumnBatch( - E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) { data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * @@ -254,7 +254,7 @@ public: } DEVICE_INLINE void - loadGlobalData32(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) + loadGlobalData32(const E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + @@ -273,7 +273,7 @@ public: } DEVICE_INLINE void loadGlobalData32ColumnBatch( - const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + const E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) { data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * @@ -290,7 +290,7 @@ public: } DEVICE_INLINE void - storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) + storeGlobalData32(E* data, uint64_t 
data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + @@ -309,7 +309,7 @@ public: } DEVICE_INLINE void storeGlobalData32ColumnBatch( - E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) { data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * @@ -326,7 +326,7 @@ public: } DEVICE_INLINE void - loadGlobalData16(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) + loadGlobalData16(const E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + @@ -345,7 +345,7 @@ public: } DEVICE_INLINE void loadGlobalData16ColumnBatch( - const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + const E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) { data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * @@ -362,7 +362,7 @@ public: } DEVICE_INLINE void - storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) + storeGlobalData16(E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + @@ -381,7 +381,7 @@ public: } DEVICE_INLINE void storeGlobalData16ColumnBatch( - E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) { data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
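
For context, the sketch below uses hypothetical names (scale_kernel_32b, scale_kernel_64b, launch_scale -- none of these are icicle APIs) to illustrate the overflow class this patch addresses and the two remedies it applies: promoting index arithmetic to uint64_t so that size * batch_size and the per-thread offset no longer wrap at 2^32, and verifying on the host that the 64-bit block count still fits the 32-bit CUDA grid dimension before launching.

// Illustrative sketch only; names and numbers are assumptions, not icicle code.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// BUGGY pattern: every operand is 32-bit, so both the bounds check and the
// element offset can wrap. With log_size = 27 and batch_size = 64,
// size * batch_size = 2^33, which truncates to 0 in 32 bits: the guard passes
// for every thread and data[tid] indexes the wrong element.
__global__ void scale_kernel_32b(float* data, float factor, uint32_t log_size, uint32_t batch_size)
{
  const uint32_t size = 1 << log_size;                        // 32-bit shift; overflows for log_size >= 31
  const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; // wraps once the linear index exceeds 2^32 - 1
  if (tid >= size * batch_size) return;                       // 32-bit product wraps -> wrong bound
  data[tid] = data[tid] * factor;                             // truncated offset -> wrong address
}

// FIXED pattern: promote the leading operand to 64 bits so the usual arithmetic
// conversions carry the whole expression out in uint64_t, mirroring
// `uint64_t(blockDim.x) * blockIdx.x + threadIdx.x` and `1UL << log_size` above.
__global__ void scale_kernel_64b(float* data, float factor, uint32_t log_size, uint32_t batch_size)
{
  const uint64_t size = 1UL << log_size;
  const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
  if (tid >= size * batch_size) return; // size is 64-bit, so the product is too
  data[tid] = data[tid] * factor;       // 64-bit offset addresses the intended element
}

// Host-side companion to the NOF_BLOCKS guard added in the patch: the x
// dimension of a CUDA grid is a 32-bit field, so compute the block count in
// 64 bits and reject configurations that do not fit instead of truncating.
inline bool launch_scale(float* data, float factor, uint32_t log_size, uint32_t batch_size, cudaStream_t stream)
{
  const uint64_t total = (1UL << log_size) * batch_size;
  const uint64_t nof_blocks_64b = (total + 255) / 256;
  const uint32_t nof_blocks = static_cast<uint32_t>(nof_blocks_64b);
  if (nof_blocks != nof_blocks_64b) {
    fprintf(stderr, "NTT dimensions (ntt_size, batch) too large for a 32-bit grid\n");
    return false;
  }
  scale_kernel_64b<<<nof_blocks, 256, 0, stream>>>(data, factor, log_size, batch_size);
  return true;
}

Widening only the leading operand of each expression is sufficient because integer promotion then evaluates the rest of the expression in 64 bits; this is the idiom the patch applies consistently to tid, size, data_stride and ntt_block_id in the changed kernels.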