Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed the estimation logic for hybrid key switching #883

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/pke/include/schemerns/rns-cryptoparameters.h
Original file line number Diff line number Diff line change
@@ -166,11 +166,13 @@ class CryptoParametersRNS : public CryptoParametersRLWE<DCRTPoly> {
* @param extraModulusSize bit size for extra modulus in FLEXIBLEAUTOEXT (CKKS and BGV only)
* @param numPrimes number of moduli witout extraModulus
* @param auxBits size of auxiliar moduli used for hybrid key switching
* @param addOne should an extra bit be added (for CKKS and BGV)
*
* @return log2 of the modulus and number of RNS limbs.
*/
static std::pair<double, uint32_t> EstimateLogP(uint32_t numPartQ, double firstModulusSize, double dcrtBits,
double extraModulusSize, uint32_t numPrimes, uint32_t auxBits);
double extraModulusSize, uint32_t numPrimes, uint32_t auxBits,
bool addOne = false);

/*
* Estimates the extra modulus bitsize needed for threshold FHE noise flooding (only for BGV and BFV)
7 changes: 5 additions & 2 deletions src/pke/lib/scheme/bgvrns/bgvrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
@@ -444,10 +444,13 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
if (multipartyMode == NOISE_FLOODING_MULTIPARTY)
qBound += cryptoParamsBGVRNS->EstimateMultipartyFloodingLogQ();

// we add an extra bit to account for the special logic of selecting the RNS moduli in BGV
qBound++;

uint32_t auxTowers = 0;
if (ksTech == HYBRID) {
auto hybridKSInfo =
CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, dcrtBits, extraModSize, numPrimes, auxBits);
CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, dcrtBits, extraModSize, numPrimes, auxBits, true);
qBound += std::get<0>(hybridKSInfo);
auxTowers = std::get<1>(hybridKSInfo);
}
@@ -486,7 +489,7 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
numPartQ, std::log2(moduliQ[0].ConvertToDouble()),
(moduliQ.size() > 1) ? std::log2(moduliQ[1].ConvertToDouble()) : 0,
(scalTech == FLEXIBLEAUTOEXT) ? std::log2(moduliQ[moduliQ.size() - 1].ConvertToDouble()) : 0,
(scalTech == FLEXIBLEAUTOEXT) ? moduliQ.size() - 1 : moduliQ.size(), auxBits);
(scalTech == FLEXIBLEAUTOEXT) ? moduliQ.size() - 1 : moduliQ.size(), auxBits, true);
newQBound += std::get<0>(hybridKSInfo);
}
} while (qBound < newQBound);
7 changes: 5 additions & 2 deletions src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
@@ -77,10 +77,13 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
uint32_t n = cyclOrder / 2;
uint32_t qBound = firstModSize + (numPrimes - 1) * scalingModSize + extraModSize;

// we add an extra bit to account for the alternating logic of selecting the RNS moduli in CKKS
qBound++;

// Estimate ciphertext modulus Q*P bound (in case of HYBRID P*Q)
if (ksTech == HYBRID) {
auto hybridKSInfo =
CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, scalingModSize, extraModSize, numPrimes, auxBits);
auto hybridKSInfo = CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, scalingModSize, extraModSize,
numPrimes, auxBits, true);
qBound += std::get<0>(hybridKSInfo);
}

6 changes: 5 additions & 1 deletion src/pke/lib/schemerns/rns-cryptoparameters.cpp
Original file line number Diff line number Diff line change
@@ -387,7 +387,7 @@ uint64_t CryptoParametersRNS::FindAuxPrimeStep() const {

std::pair<double, uint32_t> CryptoParametersRNS::EstimateLogP(uint32_t numPartQ, double firstModulusSize,
double dcrtBits, double extraModulusSize,
uint32_t numPrimes, uint32_t auxBits) {
uint32_t numPrimes, uint32_t auxBits, bool addOne) {
// numPartQ can not be zero as there is a division by numPartQ
if (numPartQ == 0)
OPENFHE_THROW("numPartQ is zero");
@@ -426,6 +426,10 @@ std::pair<double, uint32_t> CryptoParametersRNS::EstimateLogP(uint32_t numPartQ,
maxBits = bits;
}

// we add an extra bit to account for for the special moduli selection logic in BGV and CKKS
if (addOne)
maxBits++;

// Select number of primes in auxiliary CRT basis
auto sizeP = static_cast<uint32_t>(std::ceil(maxBits / auxBits));