Skip to content

Commit

Permalink
Fix: division by zero when the number of elements is 1 (#230)
Browse files: browse the repository at this point in the history
  • Loading branch information
jeremyfelder authored Sep 28, 2023
1 parent 9f67075 commit 97f0079
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions icicle/appUtils/ntt/ntt.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ void ntt_inplace_batch_template(
const int logn = int(log(n) / log(2));
bool is_shared_mem_enabled = sizeof(E) <= MAX_SHARED_MEM_ELEMENT_SIZE;
const int log2_shmem_elems = is_shared_mem_enabled ? int(log(int(MAX_SHARED_MEM / sizeof(E))) / log(2)) : logn;
int num_threads = min(min(n / 2, MAX_THREADS_BATCH), 1 << (log2_shmem_elems - 1));
int num_threads = max(min(min(n / 2, MAX_THREADS_BATCH), 1 << (log2_shmem_elems - 1)), 1);
const int chunks = max(int((n / 2) / num_threads), 1);
const int total_tasks = batch_size * chunks;
int num_blocks = total_tasks;
Expand All @@ -328,7 +328,7 @@ void ntt_inplace_batch_template(

if (is_coset) batch_vector_mult(coset, d_inout, n, batch_size, stream);

num_threads = min(n / 2, MAX_NUM_THREADS);
num_threads = max(min(n / 2, MAX_NUM_THREADS), 1);
num_blocks = (n * batch_size + num_threads - 1) / num_threads;
template_normalize_kernel<E, S>
<<<num_blocks, num_threads, 0, stream>>>(d_inout, n * batch_size, S::inv_log_size(logn));
Expand Down

0 comments on commit 97f0079

Please sign in to comment.