Skip to content

Commit

Permalink
fix(gpu): fix memory error in shift and rotate
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Feb 13, 2025
1 parent 987d5bf commit b20d511
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2781,9 +2781,6 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
tmp_rotated = pre_allocated_buffer;
reuse_memory = true;

uint32_t max_amount_of_pbs = num_radix_blocks;
uint32_t big_lwe_size = params.big_lwe_dimension + 1;
uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);
set_zero_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
tmp_rotated, 0,
tmp_rotated->num_radix_blocks);
Expand Down
13 changes: 12 additions & 1 deletion backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2161,10 +2161,21 @@ extract_n_bits(cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, CudaRadixCiphertextFFI *lwe_array_out,
const CudaRadixCiphertextFFI *lwe_array_in, void *const *bsks,
Torus *const *ksks, uint32_t effective_num_radix_blocks,
uint32_t num_radix_blocks,
int_bit_extract_luts_buffer<Torus> *bit_extract) {

copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
lwe_array_out, 0, num_radix_blocks,
lwe_array_in, 0, num_radix_blocks);
if (effective_num_radix_blocks / num_radix_blocks > 0) {
for (uint i = 1; i < effective_num_radix_blocks / num_radix_blocks; i++) {
copy_radix_ciphertext_slice_async<Torus>(
streams[0], gpu_indexes[0], lwe_array_out, i * num_radix_blocks,
(i + 1) * num_radix_blocks, lwe_array_in, 0, num_radix_blocks);
}
}
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_in, bsks, ksks,
streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_out, bsks, ksks,
bit_extract->lut, effective_num_radix_blocks);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace(
auto bits = mem->tmp_bits;
extract_n_bits<Torus>(streams, gpu_indexes, gpu_count, bits, lwe_array, bsks,
ksks, num_radix_blocks * bits_per_block,
mem->bit_extract_luts);
num_radix_blocks, mem->bit_extract_luts);

// Extract shift bits
auto shift_bits = mem->tmp_shift_bits;
Expand All @@ -78,7 +78,7 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace(
// and we reduce noise growth
extract_n_bits<Torus>(streams, gpu_indexes, gpu_count, shift_bits, lwe_shift,
bsks, ksks, max_num_bits_that_tell_shift,
mem->bit_extract_luts_with_offset_2);
num_radix_blocks, mem->bit_extract_luts_with_offset_2);

// If signed, do an "arithmetic shift" by padding with the sign bit
CudaRadixCiphertextFFI last_bit;
Expand Down

0 comments on commit b20d511

Please sign in to comment.