Skip to content

Commit e54a469

Browse files
Code improvements
1 parent 28e7a1a commit e54a469

File tree

1 file changed

+53
-56
lines changed

1 file changed

+53
-56
lines changed

src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp

+53-56
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,19 @@
6060
#include <iostream>
6161
#endif
6262

63+
namespace {
64+
// GetBigModulus() calculates the big modulus as the product of
65+
// the "compositeDegree" number of parameter modulus
66+
double GetBigModulus(const std::shared_ptr<lbcrypto::CryptoParametersCKKSRNS> cryptoParams) {
67+
double qDouble = 1.0;
68+
uint32_t compositeDegree = cryptoParams->GetCompositeDegree();
69+
for (uint32_t j = 0; j < compositeDegree; ++j) {
70+
qDouble *= cryptoParams->GetElementParams()->GetParams()[j]->GetModulus().ConvertToDouble();
71+
}
72+
73+
return qDouble;
74+
}
75+
} // namespace
6376
namespace lbcrypto {
6477

6578
//------------------------------------------------------------------------------
@@ -167,17 +180,12 @@ void FHECKKSRNS::EvalBootstrapSetup(const CryptoContextImpl<DCRTPoly>& cc, std::
167180
uint32_t compositeDegree = cryptoParams->GetCompositeDegree();
168181

169182
// Extract the modulus prior to bootstrapping
170-
NativeInteger q = cryptoParams->GetElementParams()->GetParams()[0]->GetModulus().ConvertToInt();
171-
double qDouble = q.ConvertToDouble();
172-
for (uint32_t j = 1; j < compositeDegree; ++j) {
173-
NativeInteger qj = cryptoParams->GetElementParams()->GetParams()[j]->GetModulus().ConvertToInt();
174-
qDouble *= qj.ConvertToDouble();
175-
}
183+
double qDouble = GetBigModulus(cryptoParams);
176184

177-
uint128_t factor = ((uint128_t)1 << (static_cast<uint32_t>(std::round(std::log2(qDouble)))));
185+
uint128_t factor = (static_cast<uint128_t>(1) << (static_cast<uint32_t>(std::round(std::log2(qDouble)))));
178186
double pre = (compositeDegree > 1) ? 1.0 : qDouble / factor;
179187
double k = (cryptoParams->GetSecretKeyDist() == SPARSE_TERNARY) ? K_SPARSE : 1.0;
180-
double scaleEnc = (compositeDegree > 1) ? 1.0 / k : pre / k;
188+
double scaleEnc = pre / k;
181189
double scaleDec = (compositeDegree > 1) ? qDouble / cryptoParams->GetScalingFactorReal(0) : 1 / pre;
182190

183191
uint32_t approxModDepth = GetModDepthInternal(cryptoParams->GetSecretKeyDist());
@@ -299,12 +307,7 @@ void FHECKKSRNS::EvalBootstrapPrecompute(const CryptoContextImpl<DCRTPoly>& cc,
299307
uint32_t compositeDegree = cryptoParams->GetCompositeDegree();
300308

301309
// Extract the modulus prior to bootstrapping
302-
NativeInteger q = cryptoParams->GetElementParams()->GetParams()[0]->GetModulus().ConvertToInt();
303-
double qDouble = q.ConvertToDouble();
304-
for (size_t j = 1; j < compositeDegree; ++j) {
305-
NativeInteger qj = cryptoParams->GetElementParams()->GetParams()[j]->GetModulus().ConvertToInt();
306-
qDouble *= qj.ConvertToDouble();
307-
}
310+
double qDouble = GetBigModulus(cryptoParams);
308311

309312
uint128_t factor = (static_cast<uint128_t>(1) << (static_cast<uint32_t>(std::round(std::log2(qDouble)))));
310313
double pre = qDouble / factor;
@@ -466,12 +469,7 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
466469
}
467470
auto elementParamsRaisedPtr = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(M, moduli, roots);
468471

469-
NativeInteger q = elementParamsRaisedPtr->GetParams()[0]->GetModulus().ConvertToInt();
470-
double qDouble = q.ConvertToDouble();
471-
for (uint32_t j = 1; j < compositeDegree; ++j) {
472-
NativeInteger qj = elementParamsRaisedPtr->GetParams()[j]->GetModulus().ConvertToInt();
473-
qDouble *= qj.ConvertToDouble();
474-
}
472+
double qDouble = GetBigModulus(cryptoParams);
475473

476474
const auto p = cryptoParams->GetPlaintextModulus();
477475
double powP = pow(2, p);
@@ -536,24 +534,24 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
536534
qhat_inv_modqj[j] = qhat_modqj[j].ModInverse(qj[j]);
537535
}
538536

537+
NativeInteger qjProduct =
538+
std::accumulate(qj.begin(), qj.end(), NativeInteger{1}, std::multiplies<NativeInteger>());
539539
uint32_t init_element_index = compositeDegree;
540540
for (size_t i = 0; i < ctxtDCRT.size(); i++) {
541541
std::vector<DCRTPoly> temp(compositeDegree + 1, DCRTPoly(elementParamsRaisedPtr, COEFFICIENT));
542542
std::vector<DCRTPoly> ctxtDCRT_modq(compositeDegree, DCRTPoly(elementParamsRaisedPtr, COEFFICIENT));
543543

544544
ctxtDCRT[i].SetFormat(COEFFICIENT);
545-
546545
for (size_t j = 0; j < ctxtDCRT[i].GetNumOfElements(); j++) {
547546
for (size_t k = 0; k < compositeDegree; k++)
548547
ctxtDCRT_modq[k].SetElementAtIndex(j, ctxtDCRT[i].GetElementAtIndex(j) * qhat_inv_modqj[k]);
549548
}
550-
549+
//=========================================================================================================
551550
temp[0] = ctxtDCRT_modq[0].GetElementAtIndex(0);
552-
for (size_t j = 0; j < elementParamsRaisedPtr->GetParams().size(); j++) {
553-
for (size_t k = 1; k < compositeDegree; k++)
554-
temp[0].SetElementAtIndex(j, temp[0].GetElementAtIndex(j) * qj[k]);
551+
for (auto& el : temp[0].GetAllElements()) {
552+
el *= qjProduct;
555553
}
556-
554+
//=========================================================================================================
557555
for (size_t d = 1; d < compositeDegree; d++) {
558556
temp[init_element_index] = ctxtDCRT_modq[d].GetElementAtIndex(d);
559557

@@ -562,22 +560,23 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
562560
temp[d].SetElementAtIndex(k, temp[0].GetElementAtIndex(k) * qj[k]);
563561
}
564562
}
563+
//=========================================================================================================
564+
NativeInteger qjProductD{1};
565+
for (size_t k = 0; k < compositeDegree; k++) {
566+
if (k != d)
567+
qjProductD *= qj[k];
568+
}
565569

566570
for (size_t j = compositeDegree; j < elementParamsRaisedPtr->GetParams().size(); j++) {
567-
temp[d].SetElementAtIndex(j, temp[init_element_index].GetElementAtIndex(j) * qj[0]);
568-
for (size_t k = 1; k < compositeDegree; k++) {
569-
if (k != d) {
570-
temp[d].SetElementAtIndex(j, temp[d].GetElementAtIndex(j) * qj[k]);
571-
}
572-
}
571+
auto value = temp[init_element_index].GetElementAtIndex(j) * qjProductD;
572+
temp[d].SetElementAtIndex(j, value);
573573
}
574-
temp[d].SetElementAtIndex(d, temp[init_element_index].GetElementAtIndex(d) * qj[0]);
575-
for (size_t k = 1; k < compositeDegree; k++) {
576-
if (k != d) {
577-
temp[d].SetElementAtIndex(d, temp[d].GetElementAtIndex(d) * qj[k]);
578-
}
574+
//=========================================================================================================
575+
{
576+
auto value = temp[init_element_index].GetElementAtIndex(d) * qjProductD;
577+
temp[d].SetElementAtIndex(d, value);
579578
}
580-
579+
//=========================================================================================================
581580
temp[0] += temp[d];
582581
}
583582

@@ -2557,34 +2556,19 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
25572556
moduli[i] = nativeParams[i]->GetModulus();
25582557
}
25592558

2560-
DCRTPoly::Integer intPowP{static_cast<uint64_t>(std::llround(powP))};
2561-
std::vector<DCRTPoly::Integer> crtPowP(numTowers, intPowP);
2562-
2559+
std::vector<DCRTPoly::Integer> crtPowP;
25632560
if (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO ||
25642561
cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) {
25652562
// Duhyeong: Support the case powP > 2^64
25662563
// Later we might need to use the NATIVE_INT=128 version of FHECKKSRNS::MakeAuxPlaintext for higher precision
25672564
int32_t logPowP = static_cast<int32_t>(ceil(log2(fabs(powP))));
2568-
// DCRTPoly::Integer intPowP;
2569-
int32_t logApprox_PowP;
2565+
25702566
if (logPowP > 64) {
25712567
// Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer.
25722568
logValid = (logPowP <= LargeScalingFactorConstants::MAX_BITS_IN_WORD) ?
25732569
logPowP :
25742570
LargeScalingFactorConstants::MAX_BITS_IN_WORD;
2575-
logApprox_PowP = logPowP - logValid;
2576-
approxFactor = pow(2, logApprox_PowP);
2577-
// Multiply scFactor in two steps: powP / approxFactor and then approxFactor
2578-
intPowP = std::llround(powP / approxFactor);
2579-
}
2580-
else {
2581-
intPowP = std::llround(powP);
2582-
}
2583-
2584-
// std::vector<DCRTPoly::Integer> crtPowP(numTowers, intPowP);
2585-
crtPowP.resize(numTowers, intPowP);
2586-
2587-
if (logPowP > 64) {
2571+
int32_t logApprox_PowP = logPowP - logValid;
25882572
if (logApprox_PowP > 0) {
25892573
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
25902574
logApprox_PowP :
@@ -2603,7 +2587,20 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
26032587
}
26042588
crtPowP = CKKSPackedEncoding::CRTMult(crtPowP, crtApprox, moduli);
26052589
}
2590+
else {
2591+
double approxFactor = pow(2, logApprox_PowP);
2592+
DCRTPoly::Integer intPowP{static_cast<uint64_t>(std::llround(powP / approxFactor))};
2593+
crtPowP = std::vector<DCRTPoly::Integer>(numTowers, intPowP);
2594+
}
26062595
}
2596+
else {
2597+
DCRTPoly::Integer intPowP{static_cast<uint64_t>(std::llround(powP))};
2598+
crtPowP = std::vector<DCRTPoly::Integer>(numTowers, intPowP);
2599+
}
2600+
}
2601+
else {
2602+
DCRTPoly::Integer intPowP{static_cast<uint64_t>(std::llround(powP))};
2603+
crtPowP = std::vector<DCRTPoly::Integer>(numTowers, intPowP);
26072604
}
26082605

26092606
auto currPowP = crtPowP;

0 commit comments

Comments
 (0)