60
60
#include < iostream>
61
61
#endif
62
62
63
+ namespace {
64
+ // GetBigModulus() calculates the big modulus as the product of
65
+ // the "compositeDegree" number of parameter modulus
66
+ double GetBigModulus (const std::shared_ptr<lbcrypto::CryptoParametersCKKSRNS> cryptoParams) {
67
+ double qDouble = 1.0 ;
68
+ uint32_t compositeDegree = cryptoParams->GetCompositeDegree ();
69
+ for (uint32_t j = 0 ; j < compositeDegree; ++j) {
70
+ qDouble *= cryptoParams->GetElementParams ()->GetParams ()[j]->GetModulus ().ConvertToDouble ();
71
+ }
72
+
73
+ return qDouble;
74
+ }
75
+ } // namespace
63
76
namespace lbcrypto {
64
77
65
78
// ------------------------------------------------------------------------------
@@ -167,17 +180,12 @@ void FHECKKSRNS::EvalBootstrapSetup(const CryptoContextImpl<DCRTPoly>& cc, std::
167
180
uint32_t compositeDegree = cryptoParams->GetCompositeDegree ();
168
181
169
182
// Extract the modulus prior to bootstrapping
170
- NativeInteger q = cryptoParams->GetElementParams ()->GetParams ()[0 ]->GetModulus ().ConvertToInt ();
171
- double qDouble = q.ConvertToDouble ();
172
- for (uint32_t j = 1 ; j < compositeDegree; ++j) {
173
- NativeInteger qj = cryptoParams->GetElementParams ()->GetParams ()[j]->GetModulus ().ConvertToInt ();
174
- qDouble *= qj.ConvertToDouble ();
175
- }
183
+ double qDouble = GetBigModulus (cryptoParams);
176
184
177
- uint128_t factor = (( uint128_t ) 1 << (static_cast <uint32_t >(std::round (std::log2 (qDouble)))));
185
+ uint128_t factor = (static_cast < uint128_t >( 1 ) << (static_cast <uint32_t >(std::round (std::log2 (qDouble)))));
178
186
double pre = (compositeDegree > 1 ) ? 1.0 : qDouble / factor;
179
187
double k = (cryptoParams->GetSecretKeyDist () == SPARSE_TERNARY) ? K_SPARSE : 1.0 ;
180
- double scaleEnc = (compositeDegree > 1 ) ? 1.0 / k : pre / k;
188
+ double scaleEnc = pre / k;
181
189
double scaleDec = (compositeDegree > 1 ) ? qDouble / cryptoParams->GetScalingFactorReal (0 ) : 1 / pre;
182
190
183
191
uint32_t approxModDepth = GetModDepthInternal (cryptoParams->GetSecretKeyDist ());
@@ -299,12 +307,7 @@ void FHECKKSRNS::EvalBootstrapPrecompute(const CryptoContextImpl<DCRTPoly>& cc,
299
307
uint32_t compositeDegree = cryptoParams->GetCompositeDegree ();
300
308
301
309
// Extract the modulus prior to bootstrapping
302
- NativeInteger q = cryptoParams->GetElementParams ()->GetParams ()[0 ]->GetModulus ().ConvertToInt ();
303
- double qDouble = q.ConvertToDouble ();
304
- for (size_t j = 1 ; j < compositeDegree; ++j) {
305
- NativeInteger qj = cryptoParams->GetElementParams ()->GetParams ()[j]->GetModulus ().ConvertToInt ();
306
- qDouble *= qj.ConvertToDouble ();
307
- }
310
+ double qDouble = GetBigModulus (cryptoParams);
308
311
309
312
uint128_t factor = (static_cast <uint128_t >(1 ) << (static_cast <uint32_t >(std::round (std::log2 (qDouble)))));
310
313
double pre = qDouble / factor;
@@ -466,12 +469,7 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
466
469
}
467
470
auto elementParamsRaisedPtr = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(M, moduli, roots);
468
471
469
- NativeInteger q = elementParamsRaisedPtr->GetParams ()[0 ]->GetModulus ().ConvertToInt ();
470
- double qDouble = q.ConvertToDouble ();
471
- for (uint32_t j = 1 ; j < compositeDegree; ++j) {
472
- NativeInteger qj = elementParamsRaisedPtr->GetParams ()[j]->GetModulus ().ConvertToInt ();
473
- qDouble *= qj.ConvertToDouble ();
474
- }
472
+ double qDouble = GetBigModulus (cryptoParams);
475
473
476
474
const auto p = cryptoParams->GetPlaintextModulus ();
477
475
double powP = pow (2 , p);
@@ -536,24 +534,24 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
536
534
qhat_inv_modqj[j] = qhat_modqj[j].ModInverse (qj[j]);
537
535
}
538
536
537
+ NativeInteger qjProduct =
538
+ std::accumulate (qj.begin (), qj.end (), NativeInteger{1 }, std::multiplies<NativeInteger>());
539
539
uint32_t init_element_index = compositeDegree;
540
540
for (size_t i = 0 ; i < ctxtDCRT.size (); i++) {
541
541
std::vector<DCRTPoly> temp (compositeDegree + 1 , DCRTPoly (elementParamsRaisedPtr, COEFFICIENT));
542
542
std::vector<DCRTPoly> ctxtDCRT_modq (compositeDegree, DCRTPoly (elementParamsRaisedPtr, COEFFICIENT));
543
543
544
544
ctxtDCRT[i].SetFormat (COEFFICIENT);
545
-
546
545
for (size_t j = 0 ; j < ctxtDCRT[i].GetNumOfElements (); j++) {
547
546
for (size_t k = 0 ; k < compositeDegree; k++)
548
547
ctxtDCRT_modq[k].SetElementAtIndex (j, ctxtDCRT[i].GetElementAtIndex (j) * qhat_inv_modqj[k]);
549
548
}
550
-
549
+ // =========================================================================================================
551
550
temp[0 ] = ctxtDCRT_modq[0 ].GetElementAtIndex (0 );
552
- for (size_t j = 0 ; j < elementParamsRaisedPtr->GetParams ().size (); j++) {
553
- for (size_t k = 1 ; k < compositeDegree; k++)
554
- temp[0 ].SetElementAtIndex (j, temp[0 ].GetElementAtIndex (j) * qj[k]);
551
+ for (auto & el : temp[0 ].GetAllElements ()) {
552
+ el *= qjProduct;
555
553
}
556
-
554
+ // =========================================================================================================
557
555
for (size_t d = 1 ; d < compositeDegree; d++) {
558
556
temp[init_element_index] = ctxtDCRT_modq[d].GetElementAtIndex (d);
559
557
@@ -562,22 +560,23 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
562
560
temp[d].SetElementAtIndex (k, temp[0 ].GetElementAtIndex (k) * qj[k]);
563
561
}
564
562
}
563
+ // =========================================================================================================
564
+ NativeInteger qjProductD{1 };
565
+ for (size_t k = 0 ; k < compositeDegree; k++) {
566
+ if (k != d)
567
+ qjProductD *= qj[k];
568
+ }
565
569
566
570
for (size_t j = compositeDegree; j < elementParamsRaisedPtr->GetParams ().size (); j++) {
567
- temp[d].SetElementAtIndex (j, temp[init_element_index].GetElementAtIndex (j) * qj[0 ]);
568
- for (size_t k = 1 ; k < compositeDegree; k++) {
569
- if (k != d) {
570
- temp[d].SetElementAtIndex (j, temp[d].GetElementAtIndex (j) * qj[k]);
571
- }
572
- }
571
+ auto value = temp[init_element_index].GetElementAtIndex (j) * qjProductD;
572
+ temp[d].SetElementAtIndex (j, value);
573
573
}
574
- temp[d].SetElementAtIndex (d, temp[init_element_index].GetElementAtIndex (d) * qj[0 ]);
575
- for (size_t k = 1 ; k < compositeDegree; k++) {
576
- if (k != d) {
577
- temp[d].SetElementAtIndex (d, temp[d].GetElementAtIndex (d) * qj[k]);
578
- }
574
+ // =========================================================================================================
575
+ {
576
+ auto value = temp[init_element_index].GetElementAtIndex (d) * qjProductD;
577
+ temp[d].SetElementAtIndex (d, value);
579
578
}
580
-
579
+ // =========================================================================================================
581
580
temp[0 ] += temp[d];
582
581
}
583
582
@@ -2557,34 +2556,19 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
2557
2556
moduli[i] = nativeParams[i]->GetModulus ();
2558
2557
}
2559
2558
2560
- DCRTPoly::Integer intPowP{static_cast <uint64_t >(std::llround (powP))};
2561
- std::vector<DCRTPoly::Integer> crtPowP (numTowers, intPowP);
2562
-
2559
+ std::vector<DCRTPoly::Integer> crtPowP;
2563
2560
if (cryptoParams->GetScalingTechnique () == COMPOSITESCALINGAUTO ||
2564
2561
cryptoParams->GetScalingTechnique () == COMPOSITESCALINGMANUAL) {
2565
2562
// Duhyeong: Support the case powP > 2^64
2566
2563
// Later we might need to use the NATIVE_INT=128 version of FHECKKSRNS::MakeAuxPlaintext for higher precision
2567
2564
int32_t logPowP = static_cast <int32_t >(ceil (log2 (fabs (powP))));
2568
- // DCRTPoly::Integer intPowP;
2569
- int32_t logApprox_PowP;
2565
+
2570
2566
if (logPowP > 64 ) {
2571
2567
// Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer.
2572
2568
logValid = (logPowP <= LargeScalingFactorConstants::MAX_BITS_IN_WORD) ?
2573
2569
logPowP :
2574
2570
LargeScalingFactorConstants::MAX_BITS_IN_WORD;
2575
- logApprox_PowP = logPowP - logValid;
2576
- approxFactor = pow (2 , logApprox_PowP);
2577
- // Multiply scFactor in two steps: powP / approxFactor and then approxFactor
2578
- intPowP = std::llround (powP / approxFactor);
2579
- }
2580
- else {
2581
- intPowP = std::llround (powP);
2582
- }
2583
-
2584
- // std::vector<DCRTPoly::Integer> crtPowP(numTowers, intPowP);
2585
- crtPowP.resize (numTowers, intPowP);
2586
-
2587
- if (logPowP > 64 ) {
2571
+ int32_t logApprox_PowP = logPowP - logValid;
2588
2572
if (logApprox_PowP > 0 ) {
2589
2573
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
2590
2574
logApprox_PowP :
@@ -2603,7 +2587,20 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
2603
2587
}
2604
2588
crtPowP = CKKSPackedEncoding::CRTMult (crtPowP, crtApprox, moduli);
2605
2589
}
2590
+ else {
2591
+ double approxFactor = pow (2 , logApprox_PowP);
2592
+ DCRTPoly::Integer intPowP{static_cast <uint64_t >(std::llround (powP / approxFactor))};
2593
+ crtPowP = std::vector<DCRTPoly::Integer>(numTowers, intPowP);
2594
+ }
2606
2595
}
2596
+ else {
2597
+ DCRTPoly::Integer intPowP{static_cast <uint64_t >(std::llround (powP))};
2598
+ crtPowP = std::vector<DCRTPoly::Integer>(numTowers, intPowP);
2599
+ }
2600
+ }
2601
+ else {
2602
+ DCRTPoly::Integer intPowP{static_cast <uint64_t >(std::llround (powP))};
2603
+ crtPowP = std::vector<DCRTPoly::Integer>(numTowers, intPowP);
2607
2604
}
2608
2605
2609
2606
auto currPowP = crtPowP;
0 commit comments