fix: ntt mixed-radix bug regarding large ntts (and/or batch) #523

Merged
merged 3 commits on May 20, 2024
66 changes: 36 additions & 30 deletions icicle/src/ntt/kernel_ntt.cu
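The whole change is one pattern applied throughout the mixed-radix NTT kernels: index arithmetic that was done in 32 bits overflows once `size * batch_size` reaches 2^32 elements (e.g. a 2^28-point NTT with a batch of 32), so thread indices, sizes, and bounds checks are widened to 64 bits. Below is a minimal sketch of the bug class, using a hypothetical kernel that is not part of this PR:

```cuda
// Broken pattern removed by this PR: every term is 32-bit, so the thread
// index and the bounds product both wrap modulo 2^32 for large size * batch,
// making high threads alias low elements while the guard passes spuriously.
__global__ void scale_buggy(int* data, uint32_t size, uint32_t batch_size)
{
  uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; // wraps at 2^32
  if (tid >= size * batch_size) return;                 // product wraps too
  data[tid] *= 2;
}

// Fixed pattern, mirroring the diff: widen the first operand so the whole
// expression is evaluated in 64 bits before any comparison or indexing.
__global__ void scale_fixed(int* data, uint64_t size, uint32_t batch_size)
{
  uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
  if (tid >= size * batch_size) return; // size is 64-bit, no wrap
  data[tid] *= 2;
}
```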
@@ -73,14 +73,14 @@ namespace mxntt {
// if its index is the smallest number in the group -> do the memory transformation
// else --> do nothing

-  const uint32_t size = 1 << log_size;
-  const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
-  const uint32_t idx = columns_batch ? tid / batch_size : tid % size;
-  const uint32_t batch_idx = columns_batch ? tid % batch_size : tid / size;
-  if (tid >= size * batch_size) return;
-
-  uint32_t next_element = idx;
-  uint32_t group[MAX_GROUP_SIZE];
+  const uint64_t size = 1UL << log_size;
+  const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+  const uint64_t idx = columns_batch ? tid / batch_size : tid % size;
+  const uint64_t batch_idx = columns_batch ? tid % batch_size : tid / size;
+  if (tid >= uint64_t(size) * batch_size) return;
+
+  uint64_t next_element = idx;
+  uint64_t group[MAX_GROUP_SIZE];
group[0] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;

uint32_t i = 1;
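Two distinct overflows are fixed in this hunk: the shift (`1 << log_size` is a 32-bit `int` shift, undefined for `log_size >= 31`) and the product (`size * batch_size` wraps for totals of 2^32 elements or more, which is exactly the large-NTT/large-batch case in the PR title). A worked example of the product wrap, as a small host-side check (illustrative, not from the repo):

```cpp
#include <cassert>
#include <cstdint>

int main()
{
  const uint32_t log_size = 28, batch_size = 32; // 2^33 total elements

  // Old pattern: 32-bit shift and 32-bit product; 2^33 mod 2^32 == 0,
  // so a guard like `tid >= size * batch_size` never fires.
  const uint64_t total_32 = uint32_t(1u << log_size) * batch_size;

  // New pattern: a 64-bit literal keeps the whole expression 64-bit.
  const uint64_t total_64 = (1ULL << log_size) * batch_size;

  assert(total_32 == 0);
  assert(total_64 == (1ULL << 33));
  return 0;
}
```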
@@ -114,11 +114,13 @@ namespace mxntt {
bool is_normalize,
S inverse_N)
{
-  uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
-  if (tid >= (1 << log_size) * batch_size) return;
-  uint32_t rd = tid;
-  uint32_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
-                generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
+  const uint64_t size = 1UL << log_size;
+  const uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+  if (tid >= uint64_t(size) * batch_size) return;
+
+  uint64_t rd = tid;
+  uint64_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
+                generalized_rev((tid / columns_batch_size) & (size - 1), log_size, dit, fast_tw, rev_type);
arr_reordered[wr * columns_batch_size + (tid % columns_batch_size)] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}

@@ -131,14 +133,14 @@
uint32_t columns_batch_size,
S* scalar_vec,
int step,
-  int n_scalars,
+  uint32_t n_scalars,
uint32_t log_size,
eRevType rev_type,
bool fast_tw,
E* out_vec)
{
-  int tid = blockDim.x * blockIdx.x + threadIdx.x;
-  if (tid >= size * batch_size) return;
+  uint64_t tid = uint64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+  if (tid >= uint64_t(size) * batch_size) return;
int64_t scalar_id = (tid / columns_batch_size) % size;
if (rev_type != eRevType::None) {
// Note: when we multiply an in_vec that is mixed (by DIF (I)NTT), we want to shuffle the
@@ -148,8 +150,7 @@
// Therefore we use the DIF-digit-reverse to know which element moved to index tid and use it to access the
// corresponding element in scalars vec.
const bool dif = rev_type == eRevType::NaturalToMixedRev;
-  scalar_id =
-    generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, !dif, fast_tw, rev_type);
+  scalar_id = generalized_rev((tid / columns_batch_size) & (size - 1), log_size, !dif, fast_tw, rev_type);
}
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
}
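The comment above captures the key subtlety: after a DIF pass the data sits in digit-reversed order, so to keep each element paired with its own coset scalar, the permutation is applied to the scalar index rather than re-shuffling the data. A minimal host-side sketch of that idea, using plain radix-2 bit-reversal as a stand-in for `generalized_rev` (which additionally handles mixed-radix digits and the fast-twiddle layout):

```cpp
#include <cstdint>

// Reverse the low log_n bits of x (radix-2 analogue of generalized_rev).
uint32_t bit_rev(uint32_t x, uint32_t log_n)
{
  uint32_t r = 0;
  for (uint32_t i = 0; i < log_n; i++)
    r |= ((x >> i) & 1u) << (log_n - 1 - i);
  return r;
}

// If a DIF pass left element j at position bit_rev(j), the scalar that
// belongs at position tid is scalars[bit_rev(tid)]: shuffle the scalar
// index, not the data, and every element meets its own scalar.
void mul_elementwise_mixed(const uint64_t* in, const uint64_t* scalars, uint64_t* out, uint32_t log_n)
{
  const uint32_t n = 1u << log_n;
  for (uint32_t tid = 0; tid < n; tid++)
    out[tid] = scalars[bit_rev(tid, log_n)] * in[tid];
}
```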
@@ -523,9 +524,9 @@ namespace mxntt {
}

template <typename E, typename S>
-  __global__ void normalize_kernel(E* data, S norm_factor, uint32_t size)
+  __global__ void normalize_kernel(E* data, S norm_factor, uint64_t size)
{
-  uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  uint64_t tid = uint64_t(blockIdx.x) * blockDim.x + threadIdx.x;
if (tid >= size) return;
data[tid] = data[tid] * norm_factor;
}
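With `size` now 64-bit, a caller must also build its launch configuration from the full 64-bit count. A hypothetical call site (variable names reused from the surrounding file purely for illustration; this snippet is not part of the diff):

```cuda
// Derive the grid from the 64-bit element count; narrowing to the 32-bit
// grid dimension is safe only while `blocks` fits in 31 bits.
const uint64_t total = (1ULL << log_size) * batch_size;
const uint32_t threads = 256;
const uint32_t blocks = uint32_t((total + threads - 1) / threads);
normalize_kernel<<<blocks, threads, 0, cuda_stream>>>(out, S::inv_log_size(log_size), total);
```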
@@ -786,7 +787,7 @@ namespace mxntt {
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
}
if (normalize)
-    normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4), (1 << log_size) * batch_size);
+    normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4), (1UL << log_size) * batch_size);
return CHK_LAST();
}

@@ -804,7 +805,7 @@
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
}
if (normalize)
-    normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5), (1 << log_size) * batch_size);
+    normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5), (1UL << log_size) * batch_size);
return CHK_LAST();
}

@@ -816,7 +817,7 @@
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
if (normalize)
-    normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6), (1 << log_size) * batch_size);
+    normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6), (1UL << log_size) * batch_size);
return CHK_LAST();
}

@@ -844,12 +845,12 @@ namespace mxntt {
columns_batch, 0, inv, dit, fast_tw);
}
if (normalize)
-    normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8), (1 << log_size) * batch_size);
+    normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8), (1UL << log_size) * batch_size);
return CHK_LAST();
}

// general case:
-  uint32_t nof_blocks = (1 << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
+  uint32_t nof_blocks = (1UL << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
if (dit) {
for (int i = 0; i < 5; i++) {
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
@@ -900,7 +901,7 @@ namespace mxntt {
}
if (normalize)
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(
-      out, S::inv_log_size(log_size), (1 << log_size) * batch_size);
+      out, S::inv_log_size(log_size), (1UL << log_size) * batch_size);

return CHK_LAST();
}
@@ -926,13 +927,18 @@
{
CHK_INIT_IF_RETURN();

-  const int logn = int(log2(ntt_size));
-  const int NOF_BLOCKS = ((1 << logn) * batch_size + 64 - 1) / 64;
-  const int NOF_THREADS = min(64, (1 << logn) * batch_size);
+  const uint64_t logn = uint64_t(log2(ntt_size));
+  const uint64_t NOF_BLOCKS_64b = (uint64_t(ntt_size) * batch_size + 64 - 1) / 64;
+  const uint32_t NOF_THREADS = min(64UL, uint64_t(ntt_size) * batch_size);
+  // CUDA grid is 32b fields. Assert that I don't need a larger grid.
+  const uint32_t NOF_BLOCKS = NOF_BLOCKS_64b;
+  if (NOF_BLOCKS != NOF_BLOCKS_64b) {
+    THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "NTT dimensions (ntt_size, batch) are too large. Unsupported!");
+  }

bool is_normalize = is_inverse;
const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset;
-  const int n_twiddles = 1 << max_logn;
+  const uint32_t n_twiddles = 1U << max_logn;
// Note: for evaluation on coset, need to reorder the coset too to match the data for element-wise multiplication
eRevType reverse_input = None, reverse_output = None, reverse_coset = None;
bool dit = false;
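The `NOF_BLOCKS` guard added in this hunk is a general pattern worth naming: compute in 64 bits, narrow to the 32-bit type the API requires, and fail loudly if the narrowing lost information. The same idea as a standalone helper (hypothetical, not part of the codebase):

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Narrow a 64-bit count to 32 bits, throwing if the value does not fit,
// mirroring the NOF_BLOCKS_64b -> NOF_BLOCKS check in the diff above.
inline uint32_t checked_narrow_u32(uint64_t v, const std::string& what)
{
  const uint32_t narrowed = static_cast<uint32_t>(v);
  if (narrowed != v) throw std::invalid_argument(what + ": value exceeds 32 bits");
  return narrowed;
}
```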