diff --git a/icicle/src/ntt/kernel_ntt.cu b/icicle/src/ntt/kernel_ntt.cu
index 3166b334c..46e97243f 100644
--- a/icicle/src/ntt/kernel_ntt.cu
+++ b/icicle/src/ntt/kernel_ntt.cu
@@ -73,14 +73,14 @@ namespace mxntt {
     // if its index is the smallest number in the group -> do the memory transformation
     // else --> do nothing

-    const uint32_t size = 1 << log_size;
-    const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
-    const uint32_t idx = columns_batch ? tid / batch_size : tid % size;
-    const uint32_t batch_idx = columns_batch ? tid % batch_size : tid / size;
-    if (tid >= size * batch_size) return;
-
-    uint32_t next_element = idx;
-    uint32_t group[MAX_GROUP_SIZE];
+    const uint64_t size = 1UL << log_size;
+    const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    const uint64_t idx = columns_batch ? tid / batch_size : tid % size;
+    const uint64_t batch_idx = columns_batch ? tid % batch_size : tid / size;
+    if (tid >= uint64_t(size) * batch_size) return;
+
+    uint64_t next_element = idx;
+    uint64_t group[MAX_GROUP_SIZE];
     group[0] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;

     uint32_t i = 1;
@@ -114,11 +114,13 @@ namespace mxntt {
     bool is_normalize,
     S inverse_N)
   {
-    uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
-    if (tid >= (1 << log_size) * batch_size) return;
-    uint32_t rd = tid;
-    uint32_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
-                  generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
+    const uint64_t size = 1UL << log_size;
+    const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    if (tid >= uint64_t(size) * batch_size) return;
+
+    uint64_t rd = tid;
+    uint64_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
+                  generalized_rev((tid / columns_batch_size) & (size - 1), log_size, dit, fast_tw, rev_type);
     arr_reordered[wr * columns_batch_size + (tid % columns_batch_size)] = is_normalize ? arr[rd] * inverse_N : arr[rd];
   }

@@ -131,14 +133,14 @@ namespace mxntt {
     uint32_t columns_batch_size,
     S* scalar_vec,
     int step,
-    int n_scalars,
+    uint32_t n_scalars,
     uint32_t log_size,
     eRevType rev_type,
     bool fast_tw,
     E* out_vec)
   {
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-    if (tid >= size * batch_size) return;
+    uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    if (tid >= uint64_t(size) * batch_size) return;
     int64_t scalar_id = (tid / columns_batch_size) % size;
     if (rev_type != eRevType::None) {
       // Note: when we multiply an in_vec that is mixed (by DIF (I)NTT), we want to shuffle the
@@ -148,8 +150,7 @@ namespace mxntt {
       // Therefore we use the DIF-digit-reverse to know which element moved to index tid and use it to access the
       // corresponding element in scalars vec.
       const bool dif = rev_type == eRevType::NaturalToMixedRev;
-      scalar_id =
-        generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, !dif, fast_tw, rev_type);
+      scalar_id = generalized_rev((tid / columns_batch_size) & (size - 1), log_size, !dif, fast_tw, rev_type);
     }
     out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
   }
@@ -523,9 +524,9 @@ namespace mxntt {
   }

   template <typename E, typename S>
-  __global__ void normalize_kernel(E* data, S norm_factor, uint32_t size)
+  __global__ void normalize_kernel(E* data, S norm_factor, uint64_t size)
   {
-    uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+    uint64_t tid = uint64_t(blockIdx.x) * blockDim.x + threadIdx.x;
     if (tid >= size) return;
     data[tid] = data[tid] * norm_factor;
   }
@@ -786,7 +787,7 @@ namespace mxntt {
         columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
     }
     if (normalize)
-      normalize_kernel<<>>(out, S::inv_log_size(4), (1 << log_size) * batch_size);
+      normalize_kernel<<>>(out, S::inv_log_size(4), (1UL << log_size) * batch_size);

     return CHK_LAST();
   }
@@ -804,7 +805,7 @@ namespace mxntt {
         columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
     }
     if (normalize)
-      normalize_kernel<<>>(out, S::inv_log_size(5), (1 << log_size) * batch_size);
+      normalize_kernel<<>>(out, S::inv_log_size(5), (1UL << log_size) * batch_size);

     return CHK_LAST();
   }
@@ -816,7 +817,7 @@ namespace mxntt {
       in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
       columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
     if (normalize)
-      normalize_kernel<<>>(out, S::inv_log_size(6), (1 << log_size) * batch_size);
+      normalize_kernel<<>>(out, S::inv_log_size(6), (1UL << log_size) * batch_size);

     return CHK_LAST();
   }
@@ -844,12 +845,12 @@ namespace mxntt {
         columns_batch, 0, inv, dit, fast_tw);
     }
     if (normalize)
-      normalize_kernel<<>>(out, S::inv_log_size(8), (1 << log_size) * batch_size);
+      normalize_kernel<<>>(out, S::inv_log_size(8), (1UL << log_size) * batch_size);

     return CHK_LAST();
   }

    // general case:
-    uint32_t nof_blocks = (1 << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
+    uint32_t nof_blocks = (1UL << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
     if (dit) {
       for (int i = 0; i < 5; i++) {
         uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
@@ -900,7 +901,7 @@ namespace mxntt {
     }
     if (normalize)
       normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(
-        out, S::inv_log_size(log_size), (1 << log_size) * batch_size);
+        out, S::inv_log_size(log_size), (1UL << log_size) * batch_size);

     return CHK_LAST();
   }
@@ -926,13 +927,19 @@ namespace mxntt {
   {
     CHK_INIT_IF_RETURN();

-    const int logn = int(log2(ntt_size));
-    const int NOF_BLOCKS = ((1 << logn) * batch_size + 64 - 1) / 64;
-    const int NOF_THREADS = min(64, (1 << logn) * batch_size);
+    const uint64_t total_nof_elements = uint64_t(ntt_size) * batch_size;
+    const uint64_t logn = uint64_t(log2(ntt_size));
+    const uint64_t NOF_BLOCKS_64b = (total_nof_elements + 64 - 1) / 64;
+    const uint32_t NOF_THREADS = total_nof_elements < 64 ? total_nof_elements : 64;
+    // CUDA grid dimensions are 32-bit fields. Check that a larger grid is not required.
+    const uint32_t NOF_BLOCKS = NOF_BLOCKS_64b;
+    if (NOF_BLOCKS != NOF_BLOCKS_64b) {
+      THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "NTT dimensions (ntt_size, batch) are too large. Unsupported!");
+    }

     bool is_normalize = is_inverse;
     const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset;
-    const int n_twiddles = 1 << max_logn;
+    const uint32_t n_twiddles = 1U << max_logn;
     // Note: for evaluation on coset, need to reorder the coset too to match the data for element-wise multiplication
     eRevType reverse_input = None, reverse_output = None, reverse_coset = None;
     bool dit = false;
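Every kernel_ntt.cu hunk above applies the same two rules: compute the global thread index in 64-bit arithmetic before comparing it against the element count, and verify on the host that the 64-bit block count still fits the 32-bit grid dimension before launching. The following is a minimal standalone sketch of that pattern, not code from this patch; the kernel name, launch configuration, and error handling are invented for the example.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// 64-bit global thread index: blockDim.x * blockIdx.x is evaluated in 32 bits
// unless one operand is widened first, so promote before multiplying.
__global__ void scale_kernel(int* data, int factor, uint64_t size)
{
  const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
  if (tid >= size) return;
  data[tid] = data[tid] * factor;
}

int main()
{
  const uint64_t size = 1UL << 20; // element count; kept small so the demo runs anywhere
  int* d_data;
  cudaMalloc(&d_data, size * sizeof(int));
  cudaMemset(d_data, 0, size * sizeof(int));

  const uint32_t threads = 256;
  const uint64_t blocks_64b = (size + threads - 1) / threads;
  // gridDim.x is a 32-bit field: reject launches whose block count would be truncated.
  if (blocks_64b > UINT32_MAX) {
    fprintf(stderr, "problem too large for a 1D grid\n");
    return 1;
  }
  scale_kernel<<<uint32_t(blocks_64b), threads>>>(d_data, 3, size);
  cudaDeviceSynchronize();
  cudaFree(d_data);
  return 0;
}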
diff --git a/icicle/src/ntt/thread_ntt.cu b/icicle/src/ntt/thread_ntt.cu
index 8b3e4a56d..8fd2764df 100644
--- a/icicle/src/ntt/thread_ntt.cu
+++ b/icicle/src/ntt/thread_ntt.cu
@@ -196,78 +196,83 @@ public:
   DEVICE_INLINE void
   loadGlobalData(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
+    const uint64_t data_stride_u64 = data_stride;
     if (strided) {
-      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
-              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
+      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
+              (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
     } else {
-      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
+      data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
     }

     UNROLL
     for (uint32_t i = 0; i < 8; i++) {
-      X[i] = data[s_meta.th_stride * i * data_stride];
+      X[i] = data[s_meta.th_stride * i * data_stride_u64];
     }
   }

   DEVICE_INLINE void loadGlobalDataColumnBatch(
     const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
-    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
-             (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
+    const uint64_t data_stride_u64 = data_stride;
+    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
+             (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
              batch_size +
            s_meta.batch_id;

     UNROLL
     for (uint32_t i = 0; i < 8; i++) {
-      X[i] = data[s_meta.th_stride * i * data_stride * batch_size];
+      X[i] = data[s_meta.th_stride * i * data_stride_u64 * batch_size];
     }
   }

   DEVICE_INLINE void
   storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
+    const uint64_t data_stride_u64 = data_stride;
     if (strided) {
-      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
-              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
+      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
+              (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
     } else {
-      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
+      data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
     }

     UNROLL
     for (uint32_t i = 0; i < 8; i++) {
-      data[s_meta.th_stride * i * data_stride] = X[i];
+      data[s_meta.th_stride * i * data_stride_u64] = X[i];
     }
   }

   DEVICE_INLINE void storeGlobalDataColumnBatch(
     E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
-    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
-             (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
+    const uint64_t data_stride_u64 = data_stride;
+    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
+             (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
              batch_size +
            s_meta.batch_id;

     UNROLL
     for (uint32_t i = 0; i < 8; i++) {
-      data[s_meta.th_stride * i * data_stride * batch_size] = X[i];
+      data[s_meta.th_stride * i * data_stride_u64 * batch_size] = X[i];
     }
   }

   DEVICE_INLINE void
   loadGlobalData32(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
+    const uint64_t data_stride_u64 = data_stride;
     if (strided) {
-      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
-              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
+      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 2 +
+              (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
     } else {
-      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
+      data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
     }

     UNROLL
     for (uint32_t j = 0; j < 2; j++) {
       UNROLL
       for (uint32_t i = 0; i < 4; i++) {
-        X[4 * j + i] = data[(8 * i + j) * data_stride];
+        X[4 * j + i] = data[(8 * i + j) * data_stride_u64];
       }
     }
   }
@@ -275,8 +280,9 @@ public:
   DEVICE_INLINE void loadGlobalData32ColumnBatch(
     const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
-    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
-             (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
+    const uint64_t data_stride_u64 = data_stride;
+    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 2 +
+             (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
              batch_size +
            s_meta.batch_id;

@@ -284,7 +290,7 @@ public:
     for (uint32_t j = 0; j < 2; j++) {
       UNROLL
       for (uint32_t i = 0; i < 4; i++) {
-        X[4 * j + i] = data[(8 * i + j) * data_stride * batch_size];
+        X[4 * j + i] = data[(8 * i + j) * data_stride_u64 * batch_size];
       }
     }
   }
@@ -292,18 +298,19 @@ public:
   DEVICE_INLINE void
   storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
+    const uint64_t data_stride_u64 = data_stride;
     if (strided) {
-      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
-              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
+      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 2 +
+              (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
     } else {
-      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
+      data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
     }

     UNROLL
     for (uint32_t j = 0; j < 2; j++) {
       UNROLL
       for (uint32_t i = 0; i < 4; i++) {
-        data[(8 * i + j) * data_stride] = X[4 * j + i];
+        data[(8 * i + j) * data_stride_u64] = X[4 * j + i];
       }
     }
   }
@@ -311,8 +318,9 @@ public:
   DEVICE_INLINE void storeGlobalData32ColumnBatch(
     E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
-    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
-             (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
+    const uint64_t data_stride_u64 = data_stride;
+    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 2 +
+             (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
              batch_size +
            s_meta.batch_id;

@@ -320,7 +328,7 @@ public:
     for (uint32_t j = 0; j < 2; j++) {
       UNROLL
       for (uint32_t i = 0; i < 4; i++) {
-        data[(8 * i + j) * data_stride * batch_size] = X[4 * j + i];
+        data[(8 * i + j) * data_stride_u64 * batch_size] = X[4 * j + i];
       }
     }
   }
@@ -328,18 +336,19 @@ public:
   DEVICE_INLINE void
   loadGlobalData16(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
+    const uint64_t data_stride_u64 = data_stride;
     if (strided) {
-      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
-              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
+      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 4 +
+              (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
     } else {
-      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
+      data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
     }

     UNROLL
     for (uint32_t j = 0; j < 4; j++) {
       UNROLL
       for (uint32_t i = 0; i < 2; i++) {
-        X[2 * j + i] = data[(8 * i + j) * data_stride];
+        X[2 * j + i] = data[(8 * i + j) * data_stride_u64];
       }
     }
   }
@@ -347,8 +356,9 @@ public:
   DEVICE_INLINE void loadGlobalData16ColumnBatch(
     const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
-    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
-             (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
+    const uint64_t data_stride_u64 = data_stride;
+    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 4 +
+             (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
              batch_size +
            s_meta.batch_id;

@@ -356,7 +366,7 @@ public:
     for (uint32_t j = 0; j < 4; j++) {
       UNROLL
       for (uint32_t i = 0; i < 2; i++) {
-        X[2 * j + i] = data[(8 * i + j) * data_stride * batch_size];
+        X[2 * j + i] = data[(8 * i + j) * data_stride_u64 * batch_size];
       }
     }
   }
@@ -364,18 +374,19 @@ public:
   DEVICE_INLINE void
   storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
+    const uint64_t data_stride_u64 = data_stride;
     if (strided) {
-      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
-              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
+      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 4 +
+              (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
     } else {
-      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
+      data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
     }

     UNROLL
     for (uint32_t j = 0; j < 4; j++) {
       UNROLL
       for (uint32_t i = 0; i < 2; i++) {
-        data[(8 * i + j) * data_stride] = X[2 * j + i];
+        data[(8 * i + j) * data_stride_u64] = X[2 * j + i];
       }
     }
   }
@@ -383,8 +394,9 @@ public:
   DEVICE_INLINE void storeGlobalData16ColumnBatch(
     E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
-    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
-             (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
+    const uint64_t data_stride_u64 = data_stride;
+    data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 4 +
+             (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
              batch_size +
            s_meta.batch_id;

@@ -392,7 +404,7 @@ public:
     for (uint32_t j = 0; j < 4; j++) {
       UNROLL
       for (uint32_t i = 0; i < 2; i++) {
-        data[(8 * i + j) * data_stride * batch_size] = X[2 * j + i];
+        data[(8 * i + j) * data_stride_u64 * batch_size] = X[2 * j + i];
       }
     }
   }
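The thread_ntt.cu hunks all make the same change: each load/store helper widens data_stride into a local data_stride_u64 once, so every pointer offset that multiplies the stride by a block id, input id, or batch size is evaluated in 64 bits instead of wrapping in uint32_t. The following standalone sketch illustrates why a single promotion of the leftmost factor is enough; it is not code from the repository, and the function names and values are invented for the example.

#include <cstdint>
#include <cstdio>

// The product of uint32_t operands is evaluated in 32 bits, so a large stride
// times a large block id silently wraps before it ever reaches the pointer
// arithmetic. Widening the stride first keeps the whole chain in 64 bits.
__host__ __device__ uint64_t element_offset_32(uint32_t stride, uint32_t block_id, uint32_t block_size)
{
  return stride * block_id * block_size; // wraps: 32-bit intermediate result
}

__host__ __device__ uint64_t element_offset_64(uint32_t stride, uint32_t block_id, uint32_t block_size)
{
  const uint64_t stride_u64 = stride; // widen one operand; the rest of the chain follows
  return stride_u64 * block_id * block_size;
}

int main()
{
  const uint32_t stride = 1u << 20, block_id = 1u << 13, block_size = 8;
  printf("32-bit offset: %llu\n", (unsigned long long)element_offset_32(stride, block_id, block_size));
  printf("64-bit offset: %llu\n", (unsigned long long)element_offset_64(stride, block_id, block_size));
  return 0;
}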