fix: ntt mixed-radix bug regarding large ntts (and/or batch)
In some cases, 32-bit index values would wrap around, causing accesses to the wrong elements and to invalid memory addresses.
yshekel committed May 19, 2024
1 parent 1e343f1 commit 5d6b098
Showing 2 changed files with 49 additions and 43 deletions.
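
The wraparound described in the commit message is easy to reproduce in isolation. The following standalone sketch is illustrative only (it is not part of the repository, and the kernel and variable names are invented): it shows how computing the global thread index and the element count in 32-bit arithmetic wraps once the log of the NTT size plus the batch bits reaches 32, and how widening to 64 bits, as the diff below does, avoids it.

#include <cstdint>
#include <cstdio>

// 32-bit indexing: both the global thread id and the bound can wrap around.
__global__ void touch_32b(int* data, uint32_t log_size, uint32_t batch_size)
{
  uint32_t size = 1 << log_size;
  uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; // wraps past 2^32 threads
  if (tid >= size * batch_size) return;                 // the bound wraps too -> out-of-bounds accesses
  data[tid] = 0;
}

// 64-bit indexing, mirroring the fix: widen before multiplying and adding.
__global__ void touch_64b(int* data, uint32_t log_size, uint32_t batch_size)
{
  uint64_t size = uint64_t(1) << log_size;
  uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
  if (tid >= size * batch_size) return;
  data[tid] = 0;
}

int main()
{
  // Host-side view of the same overflow: log_size = 27, batch_size = 64 -> 2^33 elements.
  uint32_t wrapped = (1u << 27) * 64u;   // 2^33 mod 2^32 == 0
  uint64_t widened = (1ULL << 27) * 64u; // 8589934592
  printf("32-bit count: %u, 64-bit count: %llu\n", wrapped, (unsigned long long)widened);
  return 0;
}
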
66 changes: 36 additions & 30 deletions icicle/src/ntt/kernel_ntt.cu
@@ -73,14 +73,14 @@ namespace mxntt {
// if its index is the smallest number in the group -> do the memory transformation
// else --> do nothing

- const uint32_t size = 1 << log_size;
- const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
- const uint32_t idx = columns_batch ? tid / batch_size : tid % size;
- const uint32_t batch_idx = columns_batch ? tid % batch_size : tid / size;
- if (tid >= size * batch_size) return;
-
- uint32_t next_element = idx;
- uint32_t group[MAX_GROUP_SIZE];
+ const uint64_t size = 1UL << log_size;
+ const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+ const uint64_t idx = columns_batch ? tid / batch_size : tid % size;
+ const uint64_t batch_idx = columns_batch ? tid % batch_size : tid / size;
+ if (tid >= uint64_t(size) * batch_size) return;
+
+ uint64_t next_element = idx;
+ uint64_t group[MAX_GROUP_SIZE];
group[0] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;

uint32_t i = 1;
@@ -114,11 +114,13 @@ namespace mxntt {
bool is_normalize,
S inverse_N)
{
- uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
- if (tid >= (1 << log_size) * batch_size) return;
- uint32_t rd = tid;
- uint32_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
- generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
+ const uint64_t size = 1UL << log_size;
+ const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+ if (tid >= uint64_t(size) * batch_size) return;
+
+ uint64_t rd = tid;
+ uint64_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
+ generalized_rev((tid / columns_batch_size) & (size - 1), log_size, dit, fast_tw, rev_type);
arr_reordered[wr * columns_batch_size + (tid % columns_batch_size)] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}

@@ -131,14 +133,14 @@
uint32_t columns_batch_size,
S* scalar_vec,
int step,
- int n_scalars,
+ uint32_t n_scalars,
uint32_t log_size,
eRevType rev_type,
bool fast_tw,
E* out_vec)
{
- int tid = blockDim.x * blockIdx.x + threadIdx.x;
- if (tid >= size * batch_size) return;
+ uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+ if (tid >= uint64_t(size) * batch_size) return;
int64_t scalar_id = (tid / columns_batch_size) % size;
if (rev_type != eRevType::None) {
// Note: when we multiply an in_vec that is mixed (by DIF (I)NTT), we want to shuffle the
@@ -148,8 +150,7 @@
// Therefore we use the DIF-digit-reverse to know which element moved to index tid and use it to access the
// corresponding element in scalars vec.
const bool dif = rev_type == eRevType::NaturalToMixedRev;
- scalar_id =
- generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, !dif, fast_tw, rev_type);
+ scalar_id = generalized_rev((tid / columns_batch_size) & (size - 1), log_size, !dif, fast_tw, rev_type);
}
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
}
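
The note in the hunk above explains that when the input vector has already been mixed by a DIF pass, the scalar index must be digit-reversed so that each scalar is paired with the element that was moved to position tid. As a simplified analogue only (plain bit reversal, not the library's generalized_rev, which also handles mixed-radix digits and the fast-twiddle layout), the lookup pattern looks like this:

// Illustrative only: reverse the low log_size bits of x.
__device__ __host__ inline uint32_t bit_rev(uint32_t x, uint32_t log_size)
{
  uint32_t r = 0;
  for (uint32_t i = 0; i < log_size; ++i)
    r = (r << 1) | ((x >> i) & 1);
  return r;
}

// If a DIF pass left element j at position bit_rev(j, log_size), then the scalar
// belonging to the element now sitting at position tid is scalar_vec[bit_rev(tid, log_size)].
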
@@ -523,9 +524,9 @@ namespace mxntt {
}

template <typename E, typename S>
- __global__ void normalize_kernel(E* data, S norm_factor, uint32_t size)
+ __global__ void normalize_kernel(E* data, S norm_factor, uint64_t size)
{
- uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+ uint64_t tid = uint64_t(blockIdx.x) * blockDim.x + threadIdx.x;
if (tid >= size) return;
data[tid] = data[tid] * norm_factor;
}
@@ -786,7 +787,7 @@ namespace mxntt {
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
}
if (normalize)
- normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4), (1 << log_size) * batch_size);
+ normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4), (1UL << log_size) * batch_size);
return CHK_LAST();
}

@@ -804,7 +805,7 @@
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
}
if (normalize)
- normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5), (1 << log_size) * batch_size);
+ normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5), (1UL << log_size) * batch_size);
return CHK_LAST();
}

@@ -816,7 +817,7 @@
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
if (normalize)
- normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6), (1 << log_size) * batch_size);
+ normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6), (1UL << log_size) * batch_size);
return CHK_LAST();
}

@@ -844,12 +845,12 @@ namespace mxntt {
columns_batch, 0, inv, dit, fast_tw);
}
if (normalize)
- normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8), (1 << log_size) * batch_size);
+ normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8), (1UL << log_size) * batch_size);
return CHK_LAST();
}

// general case:
- uint32_t nof_blocks = (1 << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
+ uint32_t nof_blocks = (1UL << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
if (dit) {
for (int i = 0; i < 5; i++) {
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
@@ -900,7 +901,7 @@ namespace mxntt {
}
if (normalize)
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(
- out, S::inv_log_size(log_size), (1 << log_size) * batch_size);
+ out, S::inv_log_size(log_size), (1UL << log_size) * batch_size);

return CHK_LAST();
}
@@ -926,13 +927,18 @@
{
CHK_INIT_IF_RETURN();

- const int logn = int(log2(ntt_size));
- const int NOF_BLOCKS = ((1 << logn) * batch_size + 64 - 1) / 64;
- const int NOF_THREADS = min(64, (1 << logn) * batch_size);
+ const uint64_t logn = uint64_t(log2(ntt_size));
+ const uint64_t NOF_BLOCKS_64b = (uint64_t(ntt_size) * batch_size + 64 - 1) / 64;
+ const uint32_t NOF_THREADS = min(64UL, uint64_t(ntt_size) * batch_size);
+ // CUDA grid is 32b fields. Assert that I don't need a larger grid.
+ const uint32_t NOF_BLOCKS = NOF_BLOCKS_64b;
+ if (NOF_BLOCKS != NOF_BLOCKS_64b) {
+ THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "NTT dimensions (ntt_size, batch) are too large. Unsupported!");
+ }

bool is_normalize = is_inverse;
const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset;
- const int n_twiddles = 1 << max_logn;
+ const uint32_t n_twiddles = 1U << max_logn;
// Note: for evaluation on coset, need to reorder the coset too to match the data for element-wise multiplication
eRevType reverse_input = None, reverse_output = None, reverse_coset = None;
bool dit = false;
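The last hunk of this file also guards the launch geometry: the block count is computed in 64-bit arithmetic and the launch is rejected if that count no longer fits the 32-bit CUDA grid field. A simplified standalone sketch of that guard pattern (the function and variable names here are illustrative, not the library's API):

#include <cstdint>
#include <stdexcept>

// Compute the block count in 64 bits, then check that narrowing it to the
// 32-bit grid field does not change the value; otherwise refuse to launch.
inline uint32_t blocks_for(uint64_t ntt_size, uint64_t batch_size, uint32_t threads_per_block)
{
  const uint64_t nof_blocks_64b = (ntt_size * batch_size + threads_per_block - 1) / threads_per_block;
  const uint32_t nof_blocks = static_cast<uint32_t>(nof_blocks_64b);
  if (nof_blocks != nof_blocks_64b)
    throw std::invalid_argument("NTT dimensions (ntt_size, batch) are too large. Unsupported!");
  return nof_blocks;
}

A launch would then look like some_kernel<<<blocks_for(ntt_size, batch_size, 64), 64, 0, stream>>>(...), matching the 64-thread blocks used in the hunk above.
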
26 changes: 13 additions & 13 deletions icicle/src/ntt/thread_ntt.cu
@@ -10,8 +10,8 @@ struct stage_metadata {
uint32_t th_stride;
uint32_t ntt_block_size;
uint32_t batch_id;
- uint32_t ntt_block_id;
uint32_t ntt_inp_id;
+ uint64_t ntt_block_id;
};

#define STAGE_SIZES_DATA \
@@ -194,7 +194,7 @@ public:
}

DEVICE_INLINE void
- loadGlobalData(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
+ loadGlobalData(const E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
@@ -210,7 +210,7 @@ public:
}

DEVICE_INLINE void loadGlobalDataColumnBatch(
- const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
+ const E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
@@ -224,7 +224,7 @@
}

DEVICE_INLINE void
- storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
+ storeGlobalData(E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
@@ -240,7 +240,7 @@ public:
}

DEVICE_INLINE void storeGlobalDataColumnBatch(
- E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
+ E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
@@ -254,7 +254,7 @@
}

DEVICE_INLINE void
- loadGlobalData32(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
+ loadGlobalData32(const E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
@@ -273,7 +273,7 @@ public:
}

DEVICE_INLINE void loadGlobalData32ColumnBatch(
- const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
+ const E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
@@ -290,7 +290,7 @@
}

DEVICE_INLINE void
- storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
+ storeGlobalData32(E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
@@ -309,7 +309,7 @@ public:
}

DEVICE_INLINE void storeGlobalData32ColumnBatch(
- E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
+ E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
@@ -326,7 +326,7 @@
}

DEVICE_INLINE void
- loadGlobalData16(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
+ loadGlobalData16(const E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
@@ -345,7 +345,7 @@ public:
}

DEVICE_INLINE void loadGlobalData16ColumnBatch(
- const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
+ const E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
@@ -362,7 +362,7 @@
}

DEVICE_INLINE void
- storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
+ storeGlobalData16(E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
@@ -381,7 +381,7 @@ public:
}

DEVICE_INLINE void storeGlobalData16ColumnBatch(
- E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
+ E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
