diff --git a/src/pke/examples/function-evaluation-composite-scaling.cpp b/src/pke/examples/function-evaluation-composite-scaling.cpp new file mode 100644 index 000000000..994adecf3 --- /dev/null +++ b/src/pke/examples/function-evaluation-composite-scaling.cpp @@ -0,0 +1,181 @@ +//================================================================================== +// BSD 2-Clause License +// +// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors +// +// All rights reserved. +// +// Author TPOC: contact@openfhe.org +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +//================================================================================== + +/* + Example of evaluating arbitrary smooth functions with the Chebyshev approximation using CKKS. + */ + +#include "openfhe.h" + +#include +#include + +using namespace lbcrypto; + +void EvalLogisticExample(); + +void EvalFunctionExample(); + +int main(int argc, char* argv[]) { + EvalLogisticExample(); + EvalFunctionExample(); + return 0; +} + +// In this example, we evaluate the logistic function 1 / (1 + exp(-x)) on an input of doubles +void EvalLogisticExample() { + std::cout << "--------------------------------- EVAL LOGISTIC FUNCTION ---------------------------------" + << std::endl; + CCParams parameters; + + // We set a smaller ring dimension to improve performance for this example. + // In production environments, the security level should be set to + // HEStd_128_classic, HEStd_192_classic, or HEStd_256_classic for 128-bit, 192-bit, + // or 256-bit security, respectively. + parameters.SetSecurityLevel(HEStd_NotSet); + parameters.SetRingDim(1 << 10); +#if NATIVEINT == 128 + usint scalingModSize = 78; + usint firstModSize = 89; +#else + usint scalingModSize = 50; + usint firstModSize = 60; +#endif + parameters.SetScalingModSize(scalingModSize); + parameters.SetFirstModSize(firstModSize); + parameters.SetScalingTechnique(COMPOSITESCALINGAUTO); + parameters.SetRegisterWordSize(32); + + // Choosing a higher degree yields better precision, but a longer runtime. + uint32_t polyDegree = 16; + + // The multiplicative depth depends on the polynomial degree. + // See the FUNCTION_EVALUATION.md file for a table mapping polynomial degrees to multiplicative depths. + uint32_t multDepth = 6; + + parameters.SetMultiplicativeDepth(multDepth); + CryptoContext cc = GenCryptoContext(parameters); + cc->Enable(PKE); + cc->Enable(KEYSWITCH); + cc->Enable(LEVELEDSHE); + // We need to enable Advanced SHE to use the Chebyshev approximation. + cc->Enable(ADVANCEDSHE); + + auto keyPair = cc->KeyGen(); + // We need to generate mult keys to run Chebyshev approximations. + cc->EvalMultKeyGen(keyPair.secretKey); + + std::vector> input{-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0}; + size_t encodedLength = input.size(); + Plaintext plaintext = cc->MakeCKKSPackedPlaintext(input); + auto ciphertext = cc->Encrypt(keyPair.publicKey, plaintext); + + double lowerBound = -5; + double upperBound = 5; + auto result = cc->EvalLogistic(ciphertext, lowerBound, upperBound, polyDegree); + + Plaintext plaintextDec; + cc->Decrypt(keyPair.secretKey, result, &plaintextDec); + plaintextDec->SetLength(encodedLength); + + std::vector> expectedOutput( + {0.0179885, 0.0474289, 0.119205, 0.268936, 0.5, 0.731064, 0.880795, 0.952571, 0.982011}); + std::cout << "Expected output\n\t" << expectedOutput << std::endl; + + std::vector> finalResult = plaintextDec->GetCKKSPackedValue(); + std::cout << "Actual output\n\t" << finalResult << std::endl << std::endl; +} + +void EvalFunctionExample() { + std::cout << "--------------------------------- EVAL SQUARE ROOT FUNCTION ---------------------------------" + << std::endl; + CCParams parameters; + + // We set a smaller ring dimension to improve performance for this example. + // In production environments, the security level should be set to + // HEStd_128_classic, HEStd_192_classic, or HEStd_256_classic for 128-bit, 192-bit, + // or 256-bit security, respectively. + parameters.SetSecurityLevel(HEStd_NotSet); + parameters.SetRingDim(1 << 10); +#if NATIVEINT == 128 + usint scalingModSize = 78; + usint firstModSize = 89; +#else + usint scalingModSize = 50; + usint firstModSize = 60; +#endif + parameters.SetScalingModSize(scalingModSize); + parameters.SetFirstModSize(firstModSize); + parameters.SetScalingTechnique(COMPOSITESCALINGAUTO); + parameters.SetRegisterWordSize(32); + + // Choosing a higher degree yields better precision, but a longer runtime. + uint32_t polyDegree = 50; + + // The multiplicative depth depends on the polynomial degree. + // See the FUNCTION_EVALUATION.md file for a table mapping polynomial degrees to multiplicative depths. + uint32_t multDepth = 7; + + parameters.SetMultiplicativeDepth(multDepth); + CryptoContext cc = GenCryptoContext(parameters); + cc->Enable(PKE); + cc->Enable(KEYSWITCH); + cc->Enable(LEVELEDSHE); + // We need to enable Advanced SHE to use the Chebyshev approximation. + cc->Enable(ADVANCEDSHE); + + auto keyPair = cc->KeyGen(); + // We need to generate mult keys to run Chebyshev approximations. + cc->EvalMultKeyGen(keyPair.secretKey); + + std::vector> input{1, 2, 3, 4, 5, 6, 7, 8, 9}; + size_t encodedLength = input.size(); + Plaintext plaintext = cc->MakeCKKSPackedPlaintext(input); + auto ciphertext = cc->Encrypt(keyPair.publicKey, plaintext); + + double lowerBound = 0; + double upperBound = 10; + + // We can input any lambda function, which inputs a double and returns a double. + auto result = cc->EvalChebyshevFunction([](double x) -> double { return std::sqrt(x); }, ciphertext, lowerBound, + upperBound, polyDegree); + + Plaintext plaintextDec; + cc->Decrypt(keyPair.secretKey, result, &plaintextDec); + plaintextDec->SetLength(encodedLength); + + std::vector> expectedOutput( + {1, 1.414213, 1.732050, 2, 2.236067, 2.449489, 2.645751, 2.828427, 3}); + std::cout << "Expected output\n\t" << expectedOutput << std::endl; + + std::vector> finalResult = plaintextDec->GetCKKSPackedValue(); + std::cout << "Actual output\n\t" << finalResult << std::endl << std::endl; +} diff --git a/src/pke/examples/linearwsum-evaluation-composite-scaling.cpp b/src/pke/examples/linearwsum-evaluation-composite-scaling.cpp new file mode 100644 index 000000000..e23b59848 --- /dev/null +++ b/src/pke/examples/linearwsum-evaluation-composite-scaling.cpp @@ -0,0 +1,125 @@ +//================================================================================== +// BSD 2-Clause License +// +// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors +// +// All rights reserved. +// +// Author TPOC: contact@openfhe.org +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +//================================================================================== + +/* + Example of polynomial evaluation using CKKS. + */ + +#define PROFILE // turns on the reporting of timing results + +#include "openfhe.h" + +#include +#include + +using namespace lbcrypto; + +int main(int argc, char* argv[]) { + TimeVar t; + + double timeEvalLinearWSum(0.0); + + std::cout << "\n======EXAMPLE FOR EVAL LINEAR WEIGHTED SUM========\n" << std::endl; + + CCParams parameters; + parameters.SetMultiplicativeDepth(1); + parameters.SetScalingModSize(50); + parameters.SetBatchSize(8); + parameters.SetSecurityLevel(HEStd_NotSet); + parameters.SetRingDim(2048); + parameters.SetScalingTechnique(COMPOSITESCALINGAUTO); + parameters.SetFirstModSize(60); + parameters.SetRegisterWordSize(32); + + CryptoContext cc = GenCryptoContext(parameters); + cc->Enable(PKE); + cc->Enable(KEYSWITCH); + cc->Enable(LEVELEDSHE); + cc->Enable(ADVANCEDSHE); + + std::vector>> input; + + input.push_back({0.5, 0.7, 0.9, 0.95, 0.93, 1.3}); + input.push_back({1.2, 1.7, -0.9, 0.85, -0.63, 2}); + input.push_back({0.5, 0, 1.9, 2.95, -3.93, 3.3}); + input.push_back({1.5, 0.7, 1.9, 2.95, -3.78, 3.3}); + input.push_back({0.5, 2.7, 1.9, 0.0, -3.43, 1.3}); + input.push_back({0.5, 0.7, -1.9, 2.95, 1.96, 0.0}); + input.push_back({0.0, 0.0, 1.0, 0.0, 0.0, 0.0}); + + size_t encodedLength = input.size(); + + std::vector coefficients({0.15, 0.75, 1.25, 1, 0, 0.5, 0.5}); + + auto keyPair = cc->KeyGen(); + + std::cout << "Generating evaluation key for homomorphic multiplication..."; + cc->EvalMultKeyGen(keyPair.secretKey); + std::cout << "Completed." << std::endl; + + std::vector> ciphertextVec; + for (usint i = 0; i < encodedLength; ++i) { + Plaintext plaintext = cc->MakeCKKSPackedPlaintext(input[i]); + ciphertextVec.push_back(cc->Encrypt(keyPair.publicKey, plaintext)); + } + + TIC(t); + + auto result = cc->EvalLinearWSum(ciphertextVec, coefficients); + + timeEvalLinearWSum = TOC(t); + + std::vector> unencIP; + for (usint i = 0; i < input[0].size(); ++i) { + std::complex x = 0; + for (usint j = 0; j < encodedLength; ++j) { + x += input[j][i] * coefficients[j]; + } + unencIP.push_back(x); + } + + Plaintext plaintextDec; + + cc->Decrypt(keyPair.secretKey, result, &plaintextDec); + + plaintextDec->SetLength(encodedLength); + + std::cout << std::setprecision(10) << std::endl; + + std::cout << "\n Result of evaluating a linear weighted sum with coefficients " << coefficients << " \n"; + std::cout << plaintextDec << std::endl; + + std::cout << "\n Expected result: " << unencIP << std::endl; + + std::cout << "\n Evaluation time: " << timeEvalLinearWSum << " ms" << std::endl; + + return 0; +} diff --git a/src/pke/examples/polynomial-evaluation-high-precision-composite-scaling.cpp b/src/pke/examples/polynomial-evaluation-high-precision-composite-scaling.cpp new file mode 100644 index 000000000..fed7902b5 --- /dev/null +++ b/src/pke/examples/polynomial-evaluation-high-precision-composite-scaling.cpp @@ -0,0 +1,250 @@ +//================================================================================== +// BSD 2-Clause License +// +// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors +// +// All rights reserved. +// +// Author TPOC: contact@openfhe.org +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +//================================================================================== + +/* + Example of polynomial evaluation using CKKS. + */ + +#define PROFILE // turns on the reporting of timing results + +#include "openfhe.h" + +#include +#include + +using namespace lbcrypto; + +void printPrimeModuliChain(const DCRTPoly& poly) { + int num_primes = poly.GetNumOfElements(); + double total_bit_len = 0.0; + for (int i = 0; i < num_primes; i++) { + auto qi = poly.GetParams()->GetParams()[i]->GetModulus(); + std::cout << "q_" << i << ": " << qi << ", log q_" << i << ": " << log(qi.ConvertToDouble()) / log(2) + << std::endl; + total_bit_len += log(qi.ConvertToDouble()) / log(2); + } + std::cout << "Total bit length: " << total_bit_len << std::endl; +} + +double getScaleApproxError(const DCRTPoly& poly, uint32_t numPrimes, uint32_t compositeDegree, uint32_t firstModSize, + uint32_t scalingModSize) { + double delta0 = std::pow(2.0, static_cast(firstModSize)); + double delta = std::pow(2.0, static_cast(scalingModSize)); + // uint32_t numPrimes = poly.GetNumOfElements(); + auto q = poly.GetParams()->GetParams(); + + std::cout << "numPrimes=" << numPrimes << " compositeDegree=" << compositeDegree << " firstModSize=" << firstModSize + << " scalingModSize=" << scalingModSize << std::endl; + + double prod = q[0]->GetModulus().ConvertToDouble(); + std::cout << "q0_0: " << prod; + for (uint32_t d = 1; d < compositeDegree; ++d) { + std::cout << " q0_" << d << ": " << q[d]->GetModulus().ConvertToDouble(); + prod *= q[d]->GetModulus().ConvertToDouble(); + } + std::cout << "\n"; + double cumApproxError = std::abs(delta0 - prod); + + std::cout << "q0: " << prod << " delta0: " << delta0 << " approxErr=" << std::abs(delta0 - prod) << std::endl; + + for (uint32_t i = compositeDegree; i < numPrimes; i += compositeDegree) { + prod = q[i]->GetModulus().ConvertToDouble(); + std::cout << "q" << i / compositeDegree << "_0: " << prod; + for (uint32_t d = 1; d < compositeDegree; ++d) { + std::cout << " q" << i / compositeDegree << "_" << d << ": " << q[i + d]->GetModulus().ConvertToDouble(); + prod *= q[i + d]->GetModulus().ConvertToDouble(); + } + std::cout << "\n"; + cumApproxError += std::abs(delta - prod); + std::cout << "q" << i / compositeDegree << ": " << prod << " delta: " << delta + << " approxErr=" << std::abs(delta - prod) << std::endl; + } + + std::cout << "Average distance to scaling factor: " + << cumApproxError / ((numPrimes - compositeDegree) / compositeDegree) << std::endl; + + return cumApproxError / ((numPrimes - compositeDegree) / compositeDegree); +} + +int main(int argc, char* argv[]) { + TimeVar t; + + double timeEvalPoly1(0.0), timeEvalPoly2(0.0); + // Parameters for d=4 + // uint32_t firstModSize = 106; + // uint32_t scalingModSize = 104; + // uint32_t registerWordSize = 32; + // Parameters for d=3 + uint32_t firstModSize = 96; + uint32_t scalingModSize = 80; + uint32_t registerWordSize = 32; + + std::cout << "\n======EXAMPLE FOR EVALPOLY========\n" << std::endl; + + uint32_t multDepth = 6; + int argcCount = 0; + if (argc > 1) { + while (argcCount < argc) { + uint32_t paramValue = atoi(argv[argcCount]); + switch (argcCount) { + case 1: + firstModSize = paramValue; + std::cout << "Setting First Mod Size: " << firstModSize << std::endl; + break; + case 2: + scalingModSize = paramValue; + std::cout << "Setting Scaling Mod Size: " << scalingModSize << std::endl; + break; + case 3: + registerWordSize = paramValue; + std::cout << "Setting Register Word Size: " << registerWordSize << std::endl; + break; + case 4: + multDepth = paramValue; + std::cout << "Setting Multiplicative Depth: " << multDepth << std::endl; + break; + default: + std::cout << "Invalid option" << std::endl; + break; + } + argcCount += 1; + std::cout << "argcCount: " << argcCount << std::endl; + } + std::cout << "Complete !" << std::endl; + } + else { + std::cout << "Using default parameters" << std::endl; + std::cout << "First Mod Size: " << firstModSize << std::endl; + std::cout << "Scaling Mod Size: " << scalingModSize << std::endl; + std::cout << "Register Word Size: " << registerWordSize << std::endl; + std::cout << "Multiplicative Depth: " << multDepth << std::endl; + std::cout << "Usage: " << argv[0] << " [firstModSize] [scalingModSize] [registerWordSize] [multDepth]" + << std::endl; + } + + CCParams parameters; + parameters.SetMultiplicativeDepth(multDepth); + parameters.SetFirstModSize(firstModSize); + parameters.SetScalingModSize(scalingModSize); + + parameters.SetRegisterWordSize(registerWordSize); + parameters.SetScalingTechnique(COMPOSITESCALINGAUTO); + + CryptoContext cc = GenCryptoContext(parameters); + cc->Enable(PKE); + cc->Enable(KEYSWITCH); + cc->Enable(LEVELEDSHE); + cc->Enable(ADVANCEDSHE); + + const auto cryptoParamsCKKSRNS = std::dynamic_pointer_cast(cc->GetCryptoParameters()); + uint32_t compositeDegree = cryptoParamsCKKSRNS->GetCompositeDegree(); + std::cout << "Composite Degree: " << compositeDegree << "\nPrime Moduli Bit Length: " + << static_cast(scalingModSize) / cryptoParamsCKKSRNS->GetCompositeDegree() + << "\nTarget HW Arch Word Size: " << registerWordSize << std::endl; + + std::vector> input({0.5, 0.7, 0.9, 0.95, 0.93}); + + size_t encodedLength = input.size(); + + std::vector coefficients1({0.15, 0.75, 0, 1.25, 0, 0, 1, 0, 1, 2, 0, 1, 0, 0, 0, 0, 1}); + std::vector coefficients2({1, 2, 3, 4, 5, -1, -2, -3, -4, -5, + 0.1, 0.2, 0.3, 0.4, 0.5, -0.1, -0.2, -0.3, -0.4, -0.5, + 0.1, 0.2, 0.3, 0.4, 0.5, -0.1, -0.2, -0.3, -0.4, -0.5}); + // std::vector coefficients2({0, 0, 0, 0, 0, -0, -0, -0, -0, -0, + // 0., 0., 0., 0., 0., -0., -0., -0., -0., -0., + // 0., 0., 0., 0., 0., -0., -0., -0., -0., -0.}); + Plaintext plaintext1 = cc->MakeCKKSPackedPlaintext(input); + + auto keyPair = cc->KeyGen(); + + std::cout << "Generating evaluation key for homomorphic multiplication..."; + cc->EvalMultKeyGen(keyPair.secretKey); + std::cout << "Completed." << std::endl; + + const std::vector& ckkspk = keyPair.publicKey->GetPublicElements(); + std::cout << "Moduli chain of pk: " << std::endl; + printPrimeModuliChain(ckkspk[0]); + + double avgScaleError = getScaleApproxError(ckkspk[0], (multDepth + 1) * compositeDegree, compositeDegree, + firstModSize, scalingModSize); + std::cout << "Average Scale Error: " << avgScaleError << std::endl; + + auto ciphertext1 = cc->Encrypt(keyPair.publicKey, plaintext1); + + TIC(t); + + auto result = cc->EvalPoly(ciphertext1, coefficients1); + + timeEvalPoly1 = TOC(t); + + TIC(t); + + auto result2 = cc->EvalPoly(ciphertext1, coefficients2); + + timeEvalPoly2 = TOC(t); + + Plaintext plaintextDec; + + cc->Decrypt(keyPair.secretKey, result, &plaintextDec); + + plaintextDec->SetLength(encodedLength); + + Plaintext plaintextDec2; + + cc->Decrypt(keyPair.secretKey, result2, &plaintextDec2); + + plaintextDec2->SetLength(encodedLength); + + std::cout << std::setprecision(15) << std::endl; + + std::cout << "\n Original Plaintext #1: \n"; + std::cout << plaintext1 << std::endl; + + std::cout << "\n Result of evaluating a polynomial with coefficients " << coefficients1 << " \n"; + std::cout << plaintextDec << std::endl; + + std::cout << "\n Expected result: (0.70519107, 1.38285078, 3.97211180, " + "5.60215665, 4.86357575) " + << std::endl; + + std::cout << "\n Evaluation time: " << timeEvalPoly1 << " ms" << std::endl; + + std::cout << "\n Result of evaluating a polynomial with coefficients " << coefficients2 << " \n"; + std::cout << plaintextDec2 << std::endl; + + std::cout << "\n Expected result: (3.4515092326, 5.3752765397, 4.8993108833, " + "3.2495023573, 4.0485229982) " + << std::endl; + + std::cout << "\n Evaluation time: " << timeEvalPoly2 << " ms" << std::endl; + + return 0; +} diff --git a/src/pke/examples/simple-composite-scaling.cpp b/src/pke/examples/simple-composite-scaling.cpp index 9dc4fce04..efe68a6ec 100644 --- a/src/pke/examples/simple-composite-scaling.cpp +++ b/src/pke/examples/simple-composite-scaling.cpp @@ -41,7 +41,7 @@ using namespace lbcrypto; -int main() { +int main(int argc, char* argv[]) { // Step 1: Setup CryptoContext // A. Specify main parameters @@ -107,6 +107,46 @@ int main() { */ uint32_t registerWordSize = 32; + int argcCount = 1; + if (argc > 1) { + while (argcCount < argc) { + uint32_t paramValue = atoi(argv[argcCount]); + switch (argcCount) { + case 1: + firstModSize = paramValue; + std::cout << "Setting First Mod Size: " << firstModSize << std::endl; + break; + case 2: + scaleModSize = paramValue; + std::cout << "Setting Scaling Mod Size: " << scaleModSize << std::endl; + break; + case 3: + registerWordSize = paramValue; + std::cout << "Setting Register Word Size: " << registerWordSize << std::endl; + break; + case 4: + multDepth = paramValue; + std::cout << "Setting Multiplicative Depth: " << multDepth << std::endl; + break; + default: + std::cout << "Invalid option" << std::endl; + break; + } + argcCount += 1; + std::cout << "argcCount: " << argcCount << std::endl; + } + std::cout << "Complete !" << std::endl; + } + else { + std::cout << "Using default parameters" << std::endl; + std::cout << "First Mod Size: " << firstModSize << std::endl; + std::cout << "Scaling Mod Size: " << scaleModSize << std::endl; + std::cout << "Register Word Size: " << registerWordSize << std::endl; + std::cout << "Multiplicative Depth: " << multDepth << std::endl; + std::cout << "Usage: " << argv[0] << " [firstModSize] [scalingModSize] [registerWordSize] [multDepth]" + << std::endl; + } + /* A4) Desired security level based on FHE standards. * This parameter can take four values. Three of the possible values * correspond to 128-bit, 192-bit, and 256-bit security, and the fourth value diff --git a/src/pke/include/cryptocontext.h b/src/pke/include/cryptocontext.h index 8c4413e21..bb088e67e 100644 --- a/src/pke/include/cryptocontext.h +++ b/src/pke/include/cryptocontext.h @@ -66,6 +66,9 @@ #include #include #include +#ifdef DEBUG_KEY + #include +#endif namespace lbcrypto { @@ -442,6 +445,20 @@ class CryptoContextImpl : public Serializable { return p; } + /** + * GetCompositeDegree: get composite degree of the current scheme crypto context. + * @return integer value corresponding to composite degree + */ + uint32_t GetCompositeDegreeFromCtxt() const { + const auto cryptoParams = std::dynamic_pointer_cast(params); + if (!cryptoParams) { + std::string errorMsg(std::string("std::dynamic_pointer_cast() failed")); + OPENFHE_THROW(errorMsg); + } + + return cryptoParams->GetCompositeDegree(); + } + PrivateKey privateKey; public: @@ -2339,7 +2356,7 @@ class CryptoContextImpl : public Serializable { Ciphertext Rescale(ConstCiphertext ciphertext) const { ValidateCiphertext(ciphertext); - return GetScheme()->ModReduce(ciphertext, BASE_NUM_LEVELS_TO_DROP); + return GetScheme()->ModReduce(ciphertext, GetCompositeDegreeFromCtxt()); } /** @@ -2351,7 +2368,7 @@ class CryptoContextImpl : public Serializable { void RescaleInPlace(Ciphertext& ciphertext) const { ValidateCiphertext(ciphertext); - GetScheme()->ModReduceInPlace(ciphertext, BASE_NUM_LEVELS_TO_DROP); + GetScheme()->ModReduceInPlace(ciphertext, GetCompositeDegreeFromCtxt()); } /** @@ -2362,7 +2379,7 @@ class CryptoContextImpl : public Serializable { Ciphertext ModReduce(ConstCiphertext ciphertext) const { ValidateCiphertext(ciphertext); - return GetScheme()->ModReduce(ciphertext, BASE_NUM_LEVELS_TO_DROP); + return GetScheme()->ModReduce(ciphertext, GetCompositeDegreeFromCtxt()); } /** @@ -2372,11 +2389,12 @@ class CryptoContextImpl : public Serializable { void ModReduceInPlace(Ciphertext& ciphertext) const { ValidateCiphertext(ciphertext); - GetScheme()->ModReduceInPlace(ciphertext, BASE_NUM_LEVELS_TO_DROP); + GetScheme()->ModReduceInPlace(ciphertext, GetCompositeDegreeFromCtxt()); } /** * LevelReduce - drops unnecessary RNS limbs (levels) from the ciphertext and evaluation key + * Note: the number of levels to drop is multiplied by the composite degree in CKKS when using COMPOSITESCALING* scaling techniques. * @param ciphertext input ciphertext. Supported only in BGV/CKKS. * @param evalKey input evaluation key (modified in place) * @returns the ciphertext with reduced number opf RNS limbs @@ -2385,11 +2403,12 @@ class CryptoContextImpl : public Serializable { size_t levels = 1) const { ValidateCiphertext(ciphertext); - return GetScheme()->LevelReduce(ciphertext, evalKey, levels); + return GetScheme()->LevelReduce(ciphertext, evalKey, levels * GetCompositeDegreeFromCtxt()); } /** * LevelReduceInPlace - drops unnecessary RNS limbs (levels) from the ciphertext and evaluation key. Supported only in BGV/CKKS. + * Note: the number of levels to drop is multiplied by the composite degree in CKKS when using COMPOSITESCALING* scaling techniques. * @param ciphertext input ciphertext (modified in place) * @param evalKey input evaluation key (modified in place) */ @@ -2398,7 +2417,8 @@ class CryptoContextImpl : public Serializable { if (levels <= 0) { return; } - GetScheme()->LevelReduceInPlace(ciphertext, evalKey, levels); + + GetScheme()->LevelReduceInPlace(ciphertext, evalKey, levels * GetCompositeDegreeFromCtxt()); } /** * Compress - Reduces the size of ciphertext modulus to minimize the diff --git a/src/pke/include/scheme/ckksrns/ckksrns-advancedshe.h b/src/pke/include/scheme/ckksrns/ckksrns-advancedshe.h index ad2a2c163..9f57b5b4a 100644 --- a/src/pke/include/scheme/ckksrns/ckksrns-advancedshe.h +++ b/src/pke/include/scheme/ckksrns/ckksrns-advancedshe.h @@ -47,6 +47,9 @@ class AdvancedSHECKKSRNS : public AdvancedSHERNS { public: virtual ~AdvancedSHECKKSRNS() {} + Ciphertext EvalMultMany(const std::vector>& ciphertextVec, + const std::vector>& evalKeyVec) const override; + //------------------------------------------------------------------------------ // LINEAR WEIGHTED SUM //------------------------------------------------------------------------------ diff --git a/src/pke/include/scheme/ckksrns/ckksrns-parametergeneration.h b/src/pke/include/scheme/ckksrns/ckksrns-parametergeneration.h index 7b1cdbb95..ec41f4a6a 100644 --- a/src/pke/include/scheme/ckksrns/ckksrns-parametergeneration.h +++ b/src/pke/include/scheme/ckksrns/ckksrns-parametergeneration.h @@ -34,6 +34,7 @@ #include "schemerns/rns-parametergeneration.h" +#include #include #include @@ -44,6 +45,15 @@ namespace lbcrypto { class ParameterGenerationCKKSRNS : public ParameterGenerationRNS { +protected: + void CompositePrimeModuliGen(std::vector& moduliQ, std::vector& rootsQ, + uint32_t compositeDegree, uint32_t numPrimes, uint32_t firstModSize, uint32_t dcrtBits, + uint32_t cyclOrder, uint32_t registerWordSize) const; + + void SinglePrimeModuliGen(std::vector& moduliQ, std::vector& rootsQ, + ScalingTechnique scalTech, uint32_t numPrimes, uint32_t firstModSize, uint32_t dcrtBits, + uint32_t cyclOrder, uint32_t extraModsize) const; + public: virtual ~ParameterGenerationCKKSRNS() {} diff --git a/src/pke/include/schemerns/rns-cryptoparameters.h b/src/pke/include/schemerns/rns-cryptoparameters.h index cbcbda348..d32a11c9d 100644 --- a/src/pke/include/schemerns/rns-cryptoparameters.h +++ b/src/pke/include/schemerns/rns-cryptoparameters.h @@ -169,12 +169,15 @@ class CryptoParametersRNS : public CryptoParametersRLWE { * @param extraModulusSize bit size for extra modulus in FLEXIBLEAUTOEXT (CKKS and BGV only) * @param numPrimes number of moduli witout extraModulus * @param auxBits size of auxiliar moduli used for hybrid key switching + * @param scalTech scaling technique + * @param compositeDegree number of moduli in each level (CKKS only) * @param addOne should an extra bit be added (for CKKS and BGV) * * @return log2 of the modulus and number of RNS limbs. */ static std::pair EstimateLogP(uint32_t numPartQ, double firstModulusSize, double dcrtBits, double extraModulusSize, uint32_t numPrimes, uint32_t auxBits, + ScalingTechnique scalTech, uint32_t compositeDegree = 1, bool addOne = false); /* @@ -607,7 +610,7 @@ class CryptoParametersRNS : public CryptoParametersRLWE { /** * Method to retrieve the scaling factor of level l. * For FIXEDMANUAL scaling technique method always returns 2^p, where p corresponds to plaintext modulus - * @param l For FLEXIBLEAUTO scaling technique the level whose scaling factor we want to learn. + * @param l For FLEXIBLEAUTO and COMPOSITESCALING scaling techniques the level whose scaling factor we want to learn. * Levels start from 0 (no scaling done - all towers) and go up to K-1, where K is the number of towers supported. * @return the scaling factor. */ @@ -642,7 +645,7 @@ class CryptoParametersRNS : public CryptoParametersRLWE { /** * Method to retrieve the modulus to be dropped of level l. * For FIXEDMANUAL rescaling technique method always returns 2^p, where p corresponds to plaintext modulus - * @param l index of modulus to be dropped for FLEXIBLEAUTO scaling technique + * @param l index of modulus to be dropped for FLEXIBLEAUTO and COMPOSITESCALING scaling techniques * @return the precomputed table */ double GetModReduceFactor(uint32_t l = 0) const { diff --git a/src/pke/include/schemerns/rns-leveledshe.h b/src/pke/include/schemerns/rns-leveledshe.h index 37b71f382..ac32388c1 100644 --- a/src/pke/include/schemerns/rns-leveledshe.h +++ b/src/pke/include/schemerns/rns-leveledshe.h @@ -291,6 +291,15 @@ class LeveledSHERNS : public LeveledSHEBase { Ciphertext Compress(ConstCiphertext ciphertext, size_t towersLeft) const override; + //////////////////////////////////////// + // SHE LEVELED ComposedEvalMult + //////////////////////////////////////// + + using LeveledSHEBase::ComposedEvalMult; + + Ciphertext ComposedEvalMult(ConstCiphertext ciphertext1, ConstCiphertext ciphertext2, + const EvalKey evalKey) const override; + protected: ///////////////////////////////////// // RNS Core diff --git a/src/pke/lib/encoding/ckkspackedencoding.cpp b/src/pke/lib/encoding/ckkspackedencoding.cpp index b8bd4b576..aa6701b4f 100644 --- a/src/pke/lib/encoding/ckkspackedencoding.cpp +++ b/src/pke/lib/encoding/ckkspackedencoding.cpp @@ -43,6 +43,9 @@ #include #include #include +#include +#include +#include namespace lbcrypto { @@ -186,7 +189,7 @@ bool CKKSPackedEncoding::Encode() { im = im64 >> (-pRemaining); } else { - int128_t pPowRemaining = ((int64_t)1) << pRemaining; + int128_t pPowRemaining = (static_cast(1)) << pRemaining; im = pPowRemaining * im64; } @@ -396,13 +399,13 @@ bool CKKSPackedEncoding::Encode() { int32_t MAX_LOG_STEP = 60; if (logApprox > 0) { int32_t logStep = (logApprox <= MAX_LOG_STEP) ? logApprox : MAX_LOG_STEP; - DCRTPoly::Integer intStep = uint64_t(1) << logStep; + DCRTPoly::Integer intStep = static_cast(1) << logStep; std::vector crtApprox(numTowers, intStep); logApprox -= logStep; while (logApprox > 0) { int32_t logStep = (logApprox <= MAX_LOG_STEP) ? logApprox : MAX_LOG_STEP; - DCRTPoly::Integer intStep = uint64_t(1) << logStep; + DCRTPoly::Integer intStep = static_cast(1) << logStep; std::vector crtSF(numTowers, intStep); crtApprox = CRTMult(crtApprox, crtSF, moduli); logApprox -= logStep; @@ -434,7 +437,8 @@ bool CKKSPackedEncoding::Decode(size_t noiseScaleDeg, double scalingFactor, Scal std::vector> curValues(slots); if (this->typeFlag == IsNativePoly) { - if (scalTech == FLEXIBLEAUTO || scalTech == FLEXIBLEAUTOEXT) + if (scalTech == FLEXIBLEAUTO || scalTech == FLEXIBLEAUTOEXT || scalTech == COMPOSITESCALINGAUTO || + scalTech == COMPOSITESCALINGMANUAL) powP = pow(scalingFactor, -1); else powP = pow(2, -p); @@ -466,7 +470,8 @@ bool CKKSPackedEncoding::Decode(size_t noiseScaleDeg, double scalingFactor, Scal // we will bring down the scaling factor to 2^p double scalingFactorPre = 0.0; - if (scalTech == FLEXIBLEAUTO || scalTech == FLEXIBLEAUTOEXT) + if (scalTech == FLEXIBLEAUTO || scalTech == FLEXIBLEAUTOEXT || scalTech == COMPOSITESCALINGAUTO || + scalTech == COMPOSITESCALINGMANUAL) scalingFactorPre = pow(scalingFactor, -1) * pow(2, p); else scalingFactorPre = pow(2, -p * (noiseScaleDeg - 1)); diff --git a/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp b/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp index ca20f32e4..eeebc2fb8 100644 --- a/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp +++ b/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp @@ -40,6 +40,10 @@ BFV implementation. See https://eprint.iacr.org/2021/204 for details. #include "scheme/bfvrns/bfvrns-parametergeneration.h" #include "scheme/scheme-utils.h" +#include +#include +#include + namespace lbcrypto { bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr> cryptoParams, @@ -125,7 +129,8 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr(std::ceil(std::ceil(logq) / dcrtBits)); // set the number of digits uint32_t numPartQ = ComputeNumLargeDigits(numDigits, k - 1); - auto hybridKSInfo = CryptoParametersRNS::EstimateLogP(numPartQ, dcrtBits, dcrtBits, 0, k, auxBits); + auto hybridKSInfo = + CryptoParametersRNS::EstimateLogP(numPartQ, dcrtBits, dcrtBits, 0, k, auxBits, scalTech); logq += std::get<0>(hybridKSInfo); } return static_cast( diff --git a/src/pke/lib/scheme/bgvrns/bgvrns-parametergeneration.cpp b/src/pke/lib/scheme/bgvrns/bgvrns-parametergeneration.cpp index c617d53ca..55397939d 100644 --- a/src/pke/lib/scheme/bgvrns/bgvrns-parametergeneration.cpp +++ b/src/pke/lib/scheme/bgvrns/bgvrns-parametergeneration.cpp @@ -39,6 +39,11 @@ BGV implementation. See https://eprint.iacr.org/2021/204 for details. #include "scheme/bgvrns/bgvrns-cryptoparameters.h" #include "scheme/bgvrns/bgvrns-parametergeneration.h" +#include +#include +#include +#include + namespace lbcrypto { uint32_t ParameterGenerationBGVRNS::computeRingDimension( @@ -151,7 +156,7 @@ uint64_t ParameterGenerationBGVRNS::getCyclicOrder(const uint32_t ringDimension, if (pow2ptm < cyclOrder) pow2ptm = cyclOrder; - lcmCyclOrderPtm = (uint64_t)pow2ptm * plaintextModulus; + lcmCyclOrderPtm = static_cast(pow2ptm) * plaintextModulus; } else { lcmCyclOrderPtm = cyclOrder; @@ -451,8 +456,8 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr(hybridKSInfo); auxTowers = std::get<1>(hybridKSInfo); } @@ -484,7 +489,7 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr 1) ? std::log2(moduliQ[1].ConvertToDouble()) : 0, (scalTech == FLEXIBLEAUTOEXT) ? std::log2(moduliQ[moduliQ.size() - 1].ConvertToDouble()) : 0, - (scalTech == FLEXIBLEAUTOEXT) ? moduliQ.size() - 1 : moduliQ.size(), auxBits, false); + (scalTech == FLEXIBLEAUTOEXT) ? moduliQ.size() - 1 : moduliQ.size(), auxBits, scalTech, 1, false); newQBound += std::get<0>(hybridKSInfo); } } while (qBound < newQBound); @@ -511,7 +516,7 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr(pow2ptm) * plaintextModulus; // Get the largest prime with size less or equal to firstModSize bits. moduliQ[0] = LastPrime(firstModSize, modulusOrder); @@ -592,10 +597,10 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr(ptm) % cyclOrder; b = 1; while (a != 1) { - a = ((uint64_t)(a * ptm)) % cyclOrder; + a = static_cast(a * ptm) % cyclOrder; b++; } diff --git a/src/pke/lib/scheme/ckksrns/ckksrns-advancedshe.cpp b/src/pke/lib/scheme/ckksrns/ckksrns-advancedshe.cpp index 9d0f2ff09..e628fb880 100644 --- a/src/pke/lib/scheme/ckksrns/ckksrns-advancedshe.cpp +++ b/src/pke/lib/scheme/ckksrns/ckksrns-advancedshe.cpp @@ -42,8 +42,49 @@ CKKS implementation. See https://eprint.iacr.org/2020/1118 for details. #include "schemebase/base-scheme.h" +#include + namespace lbcrypto { +Ciphertext AdvancedSHECKKSRNS::EvalMultMany(const std::vector>& ciphertextVec, + const std::vector>& evalKeys) const { + const size_t inSize = ciphertextVec.size(); + + if (inSize == 0) + OPENFHE_THROW("Input ciphertext vector is empty."); + + if (inSize == 1) + return ciphertextVec[0]; + + const size_t lim = inSize * 2 - 2; + std::vector> ciphertextMultVec; + ciphertextMultVec.resize(inSize - 1); + + auto algo = ciphertextVec[0]->GetCryptoContext()->GetScheme(); + const auto cryptoParams = std::dynamic_pointer_cast(ciphertextVec[0]->GetCryptoParameters()); + uint32_t levelsToDrop = cryptoParams->GetCompositeDegree(); + + size_t ctrIndex = 0; + size_t i = 0; + for (; i < (inSize - 1); i += 2) { + ciphertextMultVec[ctrIndex] = algo->EvalMultAndRelinearize(ciphertextVec[i], ciphertextVec[(i + 1)], evalKeys); + algo->ModReduceInPlace(ciphertextMultVec[ctrIndex++], levelsToDrop); + } + if (i < inSize) { + ciphertextMultVec[ctrIndex] = + algo->EvalMultAndRelinearize(ciphertextVec[i], ciphertextMultVec[i + 1 - inSize], evalKeys); + algo->ModReduceInPlace(ciphertextMultVec[ctrIndex++], levelsToDrop); + i += 2; + } + for (; i < lim; i += 2) { + ciphertextMultVec[ctrIndex] = + algo->EvalMultAndRelinearize(ciphertextMultVec[i - inSize], ciphertextMultVec[i + 1 - inSize], evalKeys); + algo->ModReduceInPlace(ciphertextMultVec[ctrIndex++], levelsToDrop); + } + + return ciphertextMultVec.back(); +} + //------------------------------------------------------------------------------ // LINEAR WEIGHTED SUM //------------------------------------------------------------------------------ @@ -63,6 +104,8 @@ Ciphertext AdvancedSHECKKSRNS::EvalLinearWSumMutable(std::vector& constants) const { const auto cryptoParams = std::dynamic_pointer_cast(ciphertexts[0]->GetCryptoParameters()); + uint32_t compositeDegree = cryptoParams->GetCompositeDegree(); + auto cc = ciphertexts[0]->GetCryptoContext(); auto algo = cc->GetScheme(); @@ -89,7 +132,7 @@ Ciphertext AdvancedSHECKKSRNS::EvalLinearWSumMutable(std::vectorGetNoiseScaleDeg() == 2) { for (uint32_t i = 0; i < ciphertexts.size(); i++) { - algo->ModReduceInternalInPlace(ciphertexts[i], BASE_NUM_LEVELS_TO_DROP); + algo->ModReduceInternalInPlace(ciphertexts[i], compositeDegree); } } } @@ -148,7 +191,7 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyLinear(ConstCiphertext(std::floor(std::log2(i))); int64_t rem = i % powerOf2; if (indices[rem - 1] == 0) indices[rem - 1] = 1; @@ -156,7 +199,7 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyLinear(ConstCiphertext(std::floor(std::log2(rem))); rem = rem % powerOf2; if (indices[rem - 1] == 0) indices[rem - 1] = 1; @@ -166,8 +209,10 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyLinear(ConstCiphertext> powers(k); - powers[0] = x->Clone(); - auto cc = x->GetCryptoContext(); + powers[0] = x->Clone(); + auto cc = x->GetCryptoContext(); + auto cryptoParams = std::dynamic_pointer_cast(x->GetCryptoParameters()); + uint32_t compositeDegree = cryptoParams->GetCompositeDegree(); // computes all powers up to k for x for (size_t i = 2; i <= k; i++) { @@ -179,10 +224,10 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyLinear(ConstCiphertext(std::floor(std::log2(i))); int64_t rem = i % powerOf2; usint levelDiff = powers[powerOf2 - 1]->GetLevel() - powers[rem - 1]->GetLevel(); - cc->LevelReduceInPlace(powers[rem - 1], nullptr, levelDiff); + cc->LevelReduceInPlace(powers[rem - 1], nullptr, levelDiff / compositeDegree); powers[i - 1] = cc->EvalMult(powers[powerOf2 - 1], powers[rem - 1]); cc->ModReduceInPlace(powers[i - 1]); @@ -194,7 +239,7 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyLinear(ConstCiphertextGetLevel() - powers[i - 1]->GetLevel(); - cc->LevelReduceInPlace(powers[i - 1], nullptr, levelDiff); + cc->LevelReduceInPlace(powers[i - 1], nullptr, levelDiff / compositeDegree); } } @@ -228,19 +273,19 @@ Ciphertext AdvancedSHECKKSRNS::InnerEvalPolyPS(ConstCiphertext xkm(int32_t(k2m2k + k) + 1, 0.0); + std::vector xkm(static_cast(k2m2k + k) + 1, 0.0); xkm.back() = 1; auto divqr = LongDivisionPoly(coefficients, xkm); // Subtract x^{k(2^{m-1} - 1)} from r std::vector r2 = divqr->r; - if (int32_t(k2m2k - Degree(divqr->r)) <= 0) { - r2[int32_t(k2m2k)] -= 1; + if (static_cast(k2m2k - Degree(divqr->r)) <= 0) { + r2[static_cast(k2m2k)] -= 1; r2.resize(Degree(r2) + 1); } else { - r2.resize(int32_t(k2m2k + 1), 0.0); + r2.resize(static_cast(k2m2k + 1), 0.0); r2.back() = -1; } @@ -249,7 +294,7 @@ Ciphertext AdvancedSHECKKSRNS::InnerEvalPolyPS(ConstCiphertext s2 = divcs->r; - s2.resize(int32_t(k2m2k + 1), 0.0); + s2.resize(static_cast(k2m2k + 1), 0.0); s2.back() = 1; Ciphertext cu; @@ -393,7 +438,7 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyPS(ConstCiphertext x, else { // non-power of 2 indices[i - 1] = 1; - int64_t powerOf2 = 1 << (int64_t)std::floor(std::log2(i)); + int64_t powerOf2 = 1 << static_cast(std::floor(std::log2(i))); int64_t rem = i % powerOf2; if (indices[rem - 1] == 0) indices[rem - 1] = 1; @@ -401,7 +446,7 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyPS(ConstCiphertext x, // while rem is not a power of 2 // set indices required to compute rem to 1 while ((rem & (rem - 1))) { - powerOf2 = 1 << (int64_t)std::floor(std::log2(rem)); + powerOf2 = 1 << static_cast(std::floor(std::log2(rem))); rem = rem % powerOf2; if (indices[rem - 1] == 0) indices[rem - 1] = 1; @@ -412,6 +457,8 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyPS(ConstCiphertext x, std::vector> powers(k); powers[0] = x->Clone(); auto cc = x->GetCryptoContext(); + uint32_t compositeDegree = + std::dynamic_pointer_cast(x->GetCryptoParameters())->GetCompositeDegree(); // computes all powers up to k for x for (size_t i = 2; i <= k; i++) { @@ -423,10 +470,10 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyPS(ConstCiphertext x, else { if (indices[i - 1] == 1) { // non-power of 2 - int64_t powerOf2 = 1 << (int64_t)std::floor(std::log2(i)); + int64_t powerOf2 = 1 << static_cast(std::floor(std::log2(i))); int64_t rem = i % powerOf2; usint levelDiff = powers[powerOf2 - 1]->GetLevel() - powers[rem - 1]->GetLevel(); - cc->LevelReduceInPlace(powers[rem - 1], nullptr, levelDiff); + cc->LevelReduceInPlace(powers[rem - 1], nullptr, levelDiff / compositeDegree); powers[i - 1] = cc->EvalMult(powers[powerOf2 - 1], powers[rem - 1]); cc->ModReduceInPlace(powers[i - 1]); } @@ -479,18 +526,18 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyPS(ConstCiphertext x, f2.back() = 1; // Divide f2 by x^{k*2^{m-1}} - std::vector xkm(int32_t(k2m2k + k) + 1, 0.0); + std::vector xkm(static_cast(k2m2k + k) + 1, 0.0); xkm.back() = 1; auto divqr = LongDivisionPoly(f2, xkm); // Subtract x^{k(2^{m-1} - 1)} from r std::vector r2 = divqr->r; - if (int32_t(k2m2k - Degree(divqr->r)) <= 0) { - r2[int32_t(k2m2k)] -= 1; + if (static_cast(k2m2k - Degree(divqr->r)) <= 0) { + r2[static_cast(k2m2k)] -= 1; r2.resize(Degree(r2) + 1); } else { - r2.resize(int32_t(k2m2k + 1), 0.0); + r2.resize(static_cast(k2m2k + 1), 0.0); r2.back() = -1; } @@ -499,7 +546,7 @@ Ciphertext AdvancedSHECKKSRNS::EvalPolyPS(ConstCiphertext x, // Add x^{k(2^{m-1} - 1)} to s std::vector s2 = divcs->r; - s2.resize(int32_t(k2m2k + 1), 0.0); + s2.resize(static_cast(k2m2k + 1), 0.0); s2.back() = 1; // Evaluate c at u @@ -658,6 +705,8 @@ Ciphertext AdvancedSHECKKSRNS::EvalChebyshevSeriesLinear(ConstCipherte } Ciphertext yReduced = T[0]->Clone(); + uint32_t compositeDegree = + std::dynamic_pointer_cast(x->GetCryptoParameters())->GetCompositeDegree(); // Computes Chebyshev polynomials up to degree k // for y: T_1(y) = y, T_2(y), ... , T_k(y) @@ -706,7 +755,7 @@ Ciphertext AdvancedSHECKKSRNS::EvalChebyshevSeriesLinear(ConstCipherte } for (size_t i = 1; i < k; i++) { usint levelDiff = T[k - 1]->GetLevel() - T[i - 1]->GetLevel(); - cc->LevelReduceInPlace(T[i - 1], nullptr, levelDiff); + cc->LevelReduceInPlace(T[i - 1], nullptr, levelDiff / compositeDegree); } // perform scalar multiplication for the highest-order term @@ -734,23 +783,25 @@ Ciphertext AdvancedSHECKKSRNS::InnerEvalChebyshevPS(ConstCiphertext>& T, std::vector>& T2) const { auto cc = x->GetCryptoContext(); + uint32_t compositeDegree = + std::dynamic_pointer_cast(x->GetCryptoParameters())->GetCompositeDegree(); // Compute k*2^{m-1}-k because we use it a lot uint32_t k2m2k = k * (1 << (m - 1)) - k; // Divide coefficients by T^{k*2^{m-1}} - std::vector Tkm(int32_t(k2m2k + k) + 1, 0.0); + std::vector Tkm(static_cast(k2m2k + k) + 1, 0.0); Tkm.back() = 1; auto divqr = LongDivisionChebyshev(coefficients, Tkm); // Subtract x^{k(2^{m-1} - 1)} from r std::vector r2 = divqr->r; - if (int32_t(k2m2k - Degree(divqr->r)) <= 0) { - r2[int32_t(k2m2k)] -= 1; + if (static_cast(k2m2k - Degree(divqr->r)) <= 0) { + r2[static_cast(k2m2k)] -= 1; r2.resize(Degree(r2) + 1); } else { - r2.resize(int32_t(k2m2k + 1), 0.0); + r2.resize(static_cast(k2m2k + 1), 0.0); r2.back() = -1; } @@ -759,7 +810,7 @@ Ciphertext AdvancedSHECKKSRNS::InnerEvalChebyshevPS(ConstCiphertext s2 = divcs->r; - s2.resize(int32_t(k2m2k + 1), 0.0); + s2.resize(static_cast(k2m2k + 1), 0.0); s2.back() = 1; // Evaluate c at u @@ -792,7 +843,7 @@ Ciphertext AdvancedSHECKKSRNS::InnerEvalChebyshevPS(ConstCiphertextEvalAddInPlace(cu, divcs->q.front() / 2); // Need to reduce levels up to the level of T2[m-1]. usint levelDiff = T2[m - 1]->GetLevel() - cu->GetLevel(); - cc->LevelReduceInPlace(cu, nullptr, levelDiff); + cc->LevelReduceInPlace(cu, nullptr, levelDiff / compositeDegree); flag_c = true; } @@ -1013,18 +1064,18 @@ Ciphertext AdvancedSHECKKSRNS::EvalChebyshevSeriesPS(ConstCiphertext Tkm(int32_t(k2m2k + k) + 1, 0.0); + std::vector Tkm(static_cast(k2m2k + k) + 1, 0.0); Tkm.back() = 1; auto divqr = LongDivisionChebyshev(f2, Tkm); // Subtract x^{k(2^{m-1} - 1)} from r std::vector r2 = divqr->r; - if (int32_t(k2m2k - Degree(divqr->r)) <= 0) { - r2[int32_t(k2m2k)] -= 1; + if (static_cast(k2m2k - Degree(divqr->r)) <= 0) { + r2[static_cast(k2m2k)] -= 1; r2.resize(Degree(r2) + 1); } else { - r2.resize(int32_t(k2m2k + 1), 0.0); + r2.resize(static_cast(k2m2k + 1), 0.0); r2.back() = -1; } @@ -1033,7 +1084,7 @@ Ciphertext AdvancedSHECKKSRNS::EvalChebyshevSeriesPS(ConstCiphertext s2 = divcs->r; - s2.resize(int32_t(k2m2k + 1), 0.0); + s2.resize(static_cast(k2m2k + 1), 0.0); s2.back() = 1; // Evaluate c at u diff --git a/src/pke/lib/scheme/ckksrns/ckksrns-cryptoparameters.cpp b/src/pke/lib/scheme/ckksrns/ckksrns-cryptoparameters.cpp index 5849af17c..7f60648e5 100644 --- a/src/pke/lib/scheme/ckksrns/ckksrns-cryptoparameters.cpp +++ b/src/pke/lib/scheme/ckksrns/ckksrns-cryptoparameters.cpp @@ -51,7 +51,6 @@ void CryptoParametersCKKSRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Sca size_t sizeQ = GetElementParams()->GetParams().size(); uint32_t compositeDegree = this->GetCompositeDegree(); - compositeDegree = (compositeDegree == 0) ? 1 : compositeDegree; std::vector moduliQ(sizeQ); std::vector rootsQ(sizeQ); @@ -90,13 +89,15 @@ void CryptoParametersCKKSRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Sca m_scalTechnique == COMPOSITESCALINGAUTO || m_scalTechnique == COMPOSITESCALINGMANUAL) { m_scalingFactorsReal.resize(sizeQ); - if ((sizeQ == 1) && (extraBits == 0)) { + if ((sizeQ == 1) && (extraBits == 0) && (m_scalTechnique != COMPOSITESCALINGAUTO) && + (m_scalTechnique != COMPOSITESCALINGMANUAL)) { // mult depth = 0 and FLEXIBLEAUTO // when multiplicative depth = 0, we use the scaling mod size instead of modulus size // Plaintext modulus is used in EncodingParamsImpl to store the exponent p of the scaling factor m_scalingFactorsReal[0] = pow(2, GetPlaintextModulus()); } - else if ((sizeQ == 2) && (extraBits > 0)) { + else if ((sizeQ == 2) && (extraBits > 0) && (m_scalTechnique != COMPOSITESCALINGAUTO) && + (m_scalTechnique != COMPOSITESCALINGMANUAL)) { // mult depth = 0 and FLEXIBLEAUTOEXT // when multiplicative depth = 0, we use the scaling mod size instead of modulus size // Plaintext modulus is used in EncodingParamsImpl to store the exponent p of the scaling factor @@ -105,19 +106,45 @@ void CryptoParametersCKKSRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Sca } else { m_scalingFactorsReal[0] = moduliQ[sizeQ - 1].ConvertToDouble(); - if (extraBits > 0) + if (m_scalTechnique == COMPOSITESCALINGAUTO || m_scalTechnique == COMPOSITESCALINGMANUAL) { + for (uint32_t j = 1; j < compositeDegree; j++) { + m_scalingFactorsReal[0] *= moduliQ[sizeQ - j - 1].ConvertToDouble(); + } + } + if (extraBits > 0 && m_scalTechnique != COMPOSITESCALINGAUTO && m_scalTechnique != COMPOSITESCALINGMANUAL) m_scalingFactorsReal[1] = moduliQ[sizeQ - 2].ConvertToDouble(); + const double lastPresetFactor = (extraBits == 0) ? m_scalingFactorsReal[0] : m_scalingFactorsReal[1]; // number of levels with pre-calculated factors - const size_t numPresetFactors = (extraBits == 0) ? 1 : 2; + const size_t numPresetFactors = (extraBits == 0 || (m_scalTechnique == COMPOSITESCALINGAUTO || + m_scalTechnique == COMPOSITESCALINGMANUAL)) ? + 1 : + 2; for (size_t k = numPresetFactors; k < sizeQ; k++) { - double prevSF = m_scalingFactorsReal[k - 1]; - m_scalingFactorsReal[k] = prevSF * prevSF / moduliQ[sizeQ - k].ConvertToDouble(); - double ratio = m_scalingFactorsReal[k] / lastPresetFactor; - if (ratio <= 0.5 || ratio >= 2.0) - OPENFHE_THROW( - "FLEXIBLEAUTO cannot support this number of levels in this parameter setting. Please use FIXEDMANUAL or FIXEDAUTO instead."); + if (m_scalTechnique == COMPOSITESCALINGAUTO || m_scalTechnique == COMPOSITESCALINGMANUAL) { + if (k % compositeDegree == 0) { + double prevSF = m_scalingFactorsReal[k - compositeDegree]; + m_scalingFactorsReal[k] = prevSF * prevSF; + for (uint32_t j = 0; j < compositeDegree; j++) { + m_scalingFactorsReal[k] /= moduliQ[sizeQ - k + j].ConvertToDouble(); + } + } + else { + m_scalingFactorsReal[k] = 1; + } + } + else { + double prevSF = m_scalingFactorsReal[k - 1]; + m_scalingFactorsReal[k] = prevSF * prevSF / moduliQ[sizeQ - k].ConvertToDouble(); + } + + if (m_scalTechnique == FLEXIBLEAUTO || m_scalTechnique == FLEXIBLEAUTOEXT) { + double ratio = m_scalingFactorsReal[k] / lastPresetFactor; + if (ratio <= 0.5 || ratio >= 2.0) + OPENFHE_THROW( + "FLEXIBLEAUTO cannot support this number of levels in this parameter setting. Please use FIXEDMANUAL or FIXEDAUTO instead."); + } } } diff --git a/src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp b/src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp index 91a7219f5..e7312790d 100644 --- a/src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp +++ b/src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp @@ -245,7 +245,10 @@ std::vector LeveledSHECKKSRNS::GetElementForEvalAddOrSub(Cons // -gsimple-template-names // -gsplit-dwarf int32_t logApprox = 0; - const double res = std::fabs(operand * scFactor); + const double res = (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO || + cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) ? + std::fabs(scFactor) : + std::fabs(operand * scFactor); if (res > 0) { int32_t logSF = static_cast(std::ceil(std::log2(res))); int32_t logValid = (logSF <= LargeScalingFactorConstants::MAX_BITS_IN_WORD) ? @@ -253,7 +256,8 @@ std::vector LeveledSHECKKSRNS::GetElementForEvalAddOrSub(Cons LargeScalingFactorConstants::MAX_BITS_IN_WORD; logApprox = logSF - logValid; } - double approxFactor = pow(2, logApprox); + int32_t logApprox_cp = logApprox; + double approxFactor = pow(2, logApprox); DCRTPoly::Integer scConstant = static_cast(operand * scFactor / approxFactor + 0.5); std::vector crtConstant(sizeQl, scConstant); @@ -285,11 +289,52 @@ std::vector LeveledSHECKKSRNS::GetElementForEvalAddOrSub(Cons return crtConstant; } - DCRTPoly::Integer intScFactor = static_cast(scFactor + 0.5); - std::vector crtScFactor(sizeQl, intScFactor); + // COMPOSITESCALING support to 128-bit scaling factor + if (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO || + cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) { + int32_t logSF_cp = static_cast(std::ceil(std::log2(res))); + if (logSF_cp < 64) { + DCRTPoly::Integer intScFactor = static_cast(scFactor + 0.5); + std::vector crtScFactor(sizeQl, intScFactor); + for (usint i = 1; i < ciphertext->GetNoiseScaleDeg(); i++) { + crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtScFactor, moduli); + } + } + else { + // Multiply scFactor in two steps: scFactor / approxFactor and then approxFactor + DCRTPoly::Integer intScFactor = static_cast(scFactor / approxFactor + 0.5); + std::vector crtScFactor(sizeQl, intScFactor); + for (usint i = 1; i < ciphertext->GetNoiseScaleDeg(); i++) { + crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtScFactor, moduli); + } + if (logApprox_cp > 0) { + int32_t logStep = (logApprox_cp <= LargeScalingFactorConstants::MAX_LOG_STEP) ? + logApprox_cp : + LargeScalingFactorConstants::MAX_LOG_STEP; + DCRTPoly::Integer intStep = static_cast(1) << logStep; + std::vector crtApprox(sizeQl, intStep); + logApprox_cp -= logStep; + + while (logApprox_cp > 0) { + int32_t logStep = (logApprox_cp <= LargeScalingFactorConstants::MAX_LOG_STEP) ? + logApprox_cp : + LargeScalingFactorConstants::MAX_LOG_STEP; + DCRTPoly::Integer intStep = static_cast(1) << logStep; + std::vector crtSF(sizeQl, intStep); + crtApprox = CKKSPackedEncoding::CRTMult(crtApprox, crtSF, moduli); + logApprox_cp -= logStep; + } + crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtApprox, moduli); + } + } + } + else { + DCRTPoly::Integer intScFactor = static_cast(scFactor + 0.5); + std::vector crtScFactor(sizeQl, intScFactor); - for (usint i = 1; i < ciphertext->GetNoiseScaleDeg(); i++) { - crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtScFactor, moduli); + for (usint i = 1; i < ciphertext->GetNoiseScaleDeg(); i++) { + crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtScFactor, moduli); + } } return crtConstant; diff --git a/src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp b/src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp index c45de167d..2e6684cea 100644 --- a/src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp +++ b/src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp @@ -42,6 +42,8 @@ CKKS implementation. See https://eprint.iacr.org/2020/1118 for details. #include #include #include +#include +#include namespace lbcrypto { @@ -69,6 +71,35 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptrConfigureCompositeDegree(firstModSize); + uint32_t compositeDegree = cryptoParamsCKKSRNS->GetCompositeDegree(); + uint32_t registerWordSize = cryptoParamsCKKSRNS->GetRegisterWordSize(); + + if (scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) { + if (compositeDegree > 2 && scalingModSize < 55) { + std::string errorMsg = "COMPOSITESCALING Warning:"; + errorMsg += + "There will not be enough prime moduli for composite degree > 2 and scaling factor < 55 -- prime moduli too small."; + errorMsg += + "Prime moduli size must generally be greater than 22, especially for larger multiplicative depth."; + errorMsg += "Try increasing the scaling factor (scalingModSize)."; + errorMsg += "Also, feel free to use COMPOSITESCALINGMANUAL at your own risk."; + OPENFHE_THROW(errorMsg); + } + else if (compositeDegree == 1 && registerWordSize < 64) { + OPENFHE_THROW( + "This COMPOSITESCALING* version does not support composite degree == 1 with register size < 64."); + } + else if (compositeDegree < 1) { + OPENFHE_THROW("Composite degree must be greater than or equal to 1."); + } + + if (registerWordSize < 24 && scalTech == COMPOSITESCALINGAUTO) { + std::string errorMsg = "Register word size must be greater than or equal to 24 for COMPOSITESCALINGAUTO."; + errorMsg += "Otherwise, try it with COMPOSITESCALINGMANUAL."; + OPENFHE_THROW(errorMsg); + } + } + if ((PREMode != INDCPA) && (PREMode != NOT_SET)) { std::stringstream s; s << "This PRE mode " << PREMode << " is not supported for CKKSRNS"; @@ -92,7 +123,8 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptrGetCompositeDegree() == 2) ? 2 : 4; qBound += ceil(ceil(static_cast(qBound) / numPartQ) / (tmpFactor * auxBits)) * tmpFactor * auxBits; @@ -141,10 +173,251 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr moduliQ(vecSize); std::vector rootsQ(vecSize); + if (scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) { + if (compositeDegree > 1) { + CompositePrimeModuliGen(moduliQ, rootsQ, compositeDegree, numPrimes, firstModSize, dcrtBits, cyclOrder, + registerWordSize); + } + else { + SinglePrimeModuliGen(moduliQ, rootsQ, scalTech, numPrimes, firstModSize, dcrtBits, cyclOrder, extraModSize); + } + } + else { + SinglePrimeModuliGen(moduliQ, rootsQ, scalTech, numPrimes, firstModSize, dcrtBits, cyclOrder, extraModSize); + } + + auto paramsDCRT = std::make_shared>(cyclOrder, moduliQ, rootsQ); + + cryptoParamsCKKSRNS->SetElementParams(paramsDCRT); + + // if no batch size was specified, we set batchSize = n/2 by default (for full packing) + if (encodingParams->GetBatchSize() == 0) { + uint32_t batchSize = n / 2; + EncodingParams encodingParamsNew( + std::make_shared(encodingParams->GetPlaintextModulus(), batchSize)); + cryptoParamsCKKSRNS->SetEncodingParams(encodingParamsNew); + } + + cryptoParamsCKKSRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, extraModSize); + + // Validate the ring dimension found using estimated logQ(P) against actual logQ(P) + if (stdLevel != HEStd_NotSet) { + uint32_t logActualQ = (ksTech == HYBRID) ? cryptoParamsCKKSRNS->GetParamsQP()->GetModulus().GetMSB() : + cryptoParamsCKKSRNS->GetElementParams()->GetModulus().GetMSB(); + + uint32_t nActual = StdLatticeParm::FindRingDim(distType, stdLevel, logActualQ); + if (n < nActual) { + std::string errMsg("The ring dimension found using estimated logQ(P) ["); + errMsg += std::to_string(n) + "] does does not meet security requirements. "; + errMsg += "Report this problem to OpenFHE developers and set the ring dimension manually to "; + errMsg += std::to_string(nActual) + "."; + + OPENFHE_THROW(errMsg); + } + } + + return true; +} + +void ParameterGenerationCKKSRNS::CompositePrimeModuliGen(std::vector& moduliQ, + std::vector& rootsQ, usint compositeDegree, + usint numPrimes, usint firstModSize, usint dcrtBits, + usint cyclOrder, usint registerWordSize) const { + std::unordered_set moduliQRecord; + + // Sample q0, the first primes in the modulus chain + uint32_t remBits = dcrtBits; + + for (uint32_t d = 1; d <= compositeDegree; ++d) { + uint32_t qBitSize = std::ceil(static_cast(remBits) / (compositeDegree - d + 1)); + NativeInteger q = FirstPrime(qBitSize, cyclOrder); + q = PreviousPrime(q, cyclOrder); + while (std::log2(q.ConvertToDouble()) > registerWordSize || std::log2(q.ConvertToDouble()) > qBitSize || + moduliQRecord.find(q.ConvertToInt()) != moduliQRecord.end()) { + q = PreviousPrime(q, cyclOrder); + } + moduliQ[numPrimes - d] = q; + rootsQ[numPrimes - d] = RootOfUnity(cyclOrder, moduliQ[numPrimes - d]); + moduliQRecord.emplace(q.ConvertToInt()); + remBits -= std::ceil(std::log2(q.ConvertToDouble())); + } + + if (numPrimes > 1) { + std::vector qPrev(std::ceil(static_cast(compositeDegree) / 2)); + std::vector qNext(compositeDegree - static_cast(qPrev.size())); + + // Prep to compute initial scaling factor + double sf = moduliQ[numPrimes - 1].ConvertToDouble(); + for (uint32_t d = 2; d <= compositeDegree; ++d) { + sf *= moduliQ[numPrimes - d].ConvertToDouble(); + } + + bool flag = true; + for (usint i = numPrimes - compositeDegree; i >= 2 * compositeDegree; i -= compositeDegree) { + // Compute initial scaling factor + sf = static_cast(std::pow(sf, 2)); + for (usint d = 0; d < compositeDegree; ++d) { + sf /= moduliQ[i + d].ConvertToDouble(); + } + + auto sf_sqrt = std::pow(sf, 1.0 / compositeDegree); + + NativeInteger sfInt = std::llround(sf_sqrt); + NativeInteger sfRem = sfInt.Mod(cyclOrder); + + double primeProduct = 1.0; + std::unordered_set qCurrentRecord; // current prime tracker + + for (size_t step = 0; step < qPrev.size(); ++step) { + qPrev[step] = sfInt - sfRem + NativeInteger(1) - NativeInteger(cyclOrder); + do { + try { + qPrev[step] = lbcrypto::PreviousPrime(qPrev[step], cyclOrder); + } + catch (const OpenFHEException& ex) { + OPENFHE_THROW( + "COMPOSITE SCALING previous prime sampling error. Try increasing scaling factor (scalingModSize)."); + } + } while (std::log2(qPrev[step].ConvertToDouble()) > registerWordSize || + moduliQRecord.find(qPrev[step].ConvertToInt()) != moduliQRecord.end() || + qCurrentRecord.find(qPrev[step].ConvertToInt()) != qCurrentRecord.end()); + qCurrentRecord.emplace(qPrev[step].ConvertToInt()); + primeProduct *= qPrev[step].ConvertToDouble(); + } + + for (size_t step = 0; step < qNext.size(); ++step) { + qNext[step] = sfInt - sfRem + NativeInteger(1) + NativeInteger(cyclOrder); + do { + try { + qNext[step] = lbcrypto::NextPrime(qNext[step], cyclOrder); + } + catch (const OpenFHEException& ex) { + OPENFHE_THROW( + "COMPOSITE SCALING next prime sampling error. Try increasing scaling factor (scalingModSize)."); + } + } while (std::log2(qNext[step].ConvertToDouble()) > registerWordSize || + moduliQRecord.find(qNext[step].ConvertToInt()) != moduliQRecord.end() || + qCurrentRecord.find(qNext[step].ConvertToInt()) != qCurrentRecord.end()); + qCurrentRecord.emplace(qNext[step].ConvertToInt()); + primeProduct *= qNext[step].ConvertToDouble(); + } + + if (flag == false) { + NativeInteger qPrevNext = NativeInteger(qNext[qNext.size() - 1].ConvertToInt()); + while (primeProduct > sf) { + do { + qCurrentRecord.erase(qPrevNext.ConvertToInt()); // constant time + try { + qPrevNext = lbcrypto::PreviousPrime(qPrevNext, cyclOrder); + } + catch (const OpenFHEException& ex) { + OPENFHE_THROW( + "COMPOSITE SCALING previous prime sampling error. Try increasing scaling factor (scalingModSize)."); + } + } while (std::log2(qPrevNext.ConvertToDouble()) > registerWordSize || + moduliQRecord.find(qPrevNext.ConvertToInt()) != moduliQRecord.end() || + qCurrentRecord.find(qPrevNext.ConvertToInt()) != qCurrentRecord.end()); + qCurrentRecord.emplace(qPrevNext.ConvertToInt()); + + primeProduct /= qNext[qNext.size() - 1].ConvertToDouble(); + qNext[qNext.size() - 1] = qPrevNext; + primeProduct *= qPrevNext.ConvertToDouble(); + } + + uint32_t m = qPrev.size(); + for (uint32_t d = 1; d <= m; ++d) { + moduliQ[i - d] = qPrev[d - 1]; + } + for (uint32_t d = m + 1; d <= compositeDegree; ++d) { + moduliQ[i - d] = qNext[d - (m + 1)]; + } + + for (uint32_t d = 1; d <= compositeDegree; ++d) { + rootsQ[i - d] = RootOfUnity(cyclOrder, moduliQ[i - d]); + moduliQRecord.emplace(moduliQ[i - d].ConvertToInt()); + } + + flag = true; + } + else { + NativeInteger qNextPrev = NativeInteger(qPrev[qPrev.size() - 1].ConvertToInt()); + + while (primeProduct < sf) { + do { + qCurrentRecord.erase(qNextPrev.ConvertToInt()); // constant time + try { + qNextPrev = lbcrypto::NextPrime(qNextPrev, cyclOrder); + } + catch (const OpenFHEException& ex) { + OPENFHE_THROW( + "COMPOSITE SCALING next prime sampling error. Try increasing scaling factor (scalingModSize)."); + } + } while (std::log2(qNextPrev.ConvertToDouble()) > registerWordSize || + moduliQRecord.find(qNextPrev.ConvertToInt()) != moduliQRecord.end() || + qCurrentRecord.find(qNextPrev.ConvertToInt()) != qCurrentRecord.end()); + qCurrentRecord.emplace(qNextPrev.ConvertToInt()); + + primeProduct /= qPrev[qPrev.size() - 1].ConvertToDouble(); + qPrev[qPrev.size() - 1] = qNextPrev; + primeProduct *= qNextPrev.ConvertToDouble(); + } + + uint32_t m = qPrev.size(); + for (uint32_t d = 1; d <= m; ++d) { + moduliQ[i - d] = qPrev[d - 1]; + } + for (uint32_t d = m + 1; d <= compositeDegree; ++d) { + moduliQ[i - d] = qNext[d - (m + 1)]; + } + + for (uint32_t d = 1; d <= compositeDegree; ++d) { + rootsQ[i - d] = RootOfUnity(cyclOrder, moduliQ[i - d]); + moduliQRecord.emplace(moduliQ[i - d].ConvertToInt()); + } + + flag = false; + } + } // for loop + } // if numPrimes > 1 + + if (firstModSize == dcrtBits) { // this requires dcrtBits < 60 + OPENFHE_THROW("firstModSize must be > scalingModSize."); + } + else { + remBits = static_cast(firstModSize); + for (uint32_t d = 1; d <= compositeDegree; ++d) { + uint32_t qBitSize = std::ceil(static_cast(remBits) / (compositeDegree - d + 1)); + // Find next prime + NativeInteger nextInteger = FirstPrime(qBitSize, cyclOrder); + nextInteger = PreviousPrime(nextInteger, cyclOrder); + + while (std::log2(nextInteger.ConvertToDouble()) > qBitSize || + std::log2(nextInteger.ConvertToDouble()) > registerWordSize || + moduliQRecord.find(nextInteger.ConvertToInt()) != moduliQRecord.end()) + nextInteger = PreviousPrime(nextInteger, cyclOrder); + + // Store prime + moduliQ[d - 1] = nextInteger; + rootsQ[d - 1] = RootOfUnity(cyclOrder, moduliQ[d - 1]); + // Keep track of existing primes + moduliQRecord.emplace(moduliQ[d - 1].ConvertToInt()); + remBits -= qBitSize; + } + } + + return; +} + +void ParameterGenerationCKKSRNS::SinglePrimeModuliGen(std::vector& moduliQ, + std::vector& rootsQ, ScalingTechnique scalTech, + uint32_t numPrimes, uint32_t firstModSize, uint32_t dcrtBits, + uint32_t cyclOrder, uint32_t extraModSize) const { NativeInteger q = FirstPrime(dcrtBits, cyclOrder); moduliQ[numPrimes - 1] = q; rootsQ[numPrimes - 1] = RootOfUnity(cyclOrder, moduliQ[numPrimes - 1]); @@ -257,43 +530,6 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr>(cyclOrder, moduliQ, rootsQ); - - cryptoParamsCKKSRNS->SetElementParams(paramsDCRT); - - // if no batch size was specified, we set batchSize = n/2 by default (for full packing) - if (encodingParams->GetBatchSize() == 0) { - uint32_t batchSize = n / 2; - EncodingParams encodingParamsNew( - std::make_shared(encodingParams->GetPlaintextModulus(), batchSize)); - cryptoParamsCKKSRNS->SetEncodingParams(encodingParamsNew); - } - - cryptoParamsCKKSRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, extraModSize); - - // Validate the ring dimension found using estimated logQ(P) against actual logQ(P) - if (stdLevel != HEStd_NotSet) { - uint32_t logActualQ = 0; - if (ksTech == HYBRID) { - logActualQ = cryptoParamsCKKSRNS->GetParamsQP()->GetModulus().GetMSB(); - } - else { - logActualQ = cryptoParamsCKKSRNS->GetElementParams()->GetModulus().GetMSB(); - } - - uint32_t nActual = StdLatticeParm::FindRingDim(distType, stdLevel, logActualQ); - if (n < nActual) { - std::string errMsg("The ring dimension found using estimated logQ(P) ["); - errMsg += std::to_string(n) + "] does does not meet security requirements. "; - errMsg += "Report this problem to OpenFHE developers and set the ring dimension manually to "; - errMsg += std::to_string(nActual) + "."; - - OPENFHE_THROW(errMsg); - } - } - - return true; } } // namespace lbcrypto diff --git a/src/pke/lib/schemebase/base-advancedshe.cpp b/src/pke/lib/schemebase/base-advancedshe.cpp index b14b5891a..78a7ca5ae 100644 --- a/src/pke/lib/schemebase/base-advancedshe.cpp +++ b/src/pke/lib/schemebase/base-advancedshe.cpp @@ -34,14 +34,23 @@ #include "cryptocontext.h" #include "schemebase/base-scheme.h" +#include +#include +#include +#include +#include + namespace lbcrypto { template Ciphertext AdvancedSHEBase::EvalAddMany(const std::vector>& ciphertextVec) const { const size_t inSize = ciphertextVec.size(); - if (ciphertextVec.size() < 1) - OPENFHE_THROW("Input ciphertext vector size should be 1 or more"); + if (inSize == 0) + OPENFHE_THROW("Input ciphertext vector is empty."); + + if (inSize == 1) + return ciphertextVec[0]; const size_t lim = inSize * 2 - 2; std::vector> ciphertextSumVec; @@ -103,7 +112,7 @@ Ciphertext AdvancedSHEBase::EvalMultMany(const std::vectorEvalMultAndRelinearize( i < inSize ? ciphertextVec[i] : ciphertextMultVec[i - inSize], i + 1 < inSize ? ciphertextVec[i + 1] : ciphertextMultVec[i + 1 - inSize], evalKeys); - algo->ModReduceInPlace(ciphertextMultVec[ctrIndex++], 1); + algo->ModReduceInPlace(ciphertextMultVec[ctrIndex++], BASE_NUM_LEVELS_TO_DROP); } return ciphertextMultVec.back(); @@ -395,7 +404,8 @@ Ciphertext AdvancedSHEBase::EvalMerge(const std::vectorEvalAdd( - ciphertextMerged, algo->EvalAtIndex(algo->EvalMult(ciphertextVec[i], plaintext), -(int32_t)i, evalKeyMap)); + ciphertextMerged, + algo->EvalAtIndex(algo->EvalMult(ciphertextVec[i], plaintext), -static_cast(i), evalKeyMap)); } return ciphertextMerged; diff --git a/src/pke/lib/schemebase/base-leveledshe.cpp b/src/pke/lib/schemebase/base-leveledshe.cpp index 29f0cf909..00cab2553 100644 --- a/src/pke/lib/schemebase/base-leveledshe.cpp +++ b/src/pke/lib/schemebase/base-leveledshe.cpp @@ -35,6 +35,13 @@ #include "cryptocontext.h" #include "schemebase/base-scheme.h" +#include +#include +#include +#include +#include +#include + namespace lbcrypto { ///////////////////////////////////////// diff --git a/src/pke/lib/schemerns/rns-cryptoparameters.cpp b/src/pke/lib/schemerns/rns-cryptoparameters.cpp index 19bcf35b1..99013366f 100644 --- a/src/pke/lib/schemerns/rns-cryptoparameters.cpp +++ b/src/pke/lib/schemerns/rns-cryptoparameters.cpp @@ -35,8 +35,32 @@ #include "cryptocontext.h" #include "schemerns/rns-cryptoparameters.h" +#include +#include +#include +#include + namespace lbcrypto { +uint32_t UpdateSizeP(uint32_t compositeDegree, uint32_t sizeP) { + switch (compositeDegree) { + case 0: // not allowed + OPENFHE_THROW(std::string("Composite degree d = 0 is not allowed.")); + case 1: // not composite + break; + case 2: // composite degree == 2 + sizeP += (sizeP % 2); + break; + case 3: // composite degree == 3 + sizeP += (sizeP % 3); + break; + default: // composite degree > 3 + sizeP += (sizeP % 4); + break; + } + return sizeP; +} + void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, ScalingTechnique scalTech, EncryptionTechnique encTech, MultiplicationTechnique multTech, uint32_t numPartQ, uint32_t auxBits, uint32_t extraBits) { @@ -128,7 +152,11 @@ void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Scaling maxBits = bits; } // Select number of primes in auxiliary CRT basis - uint32_t sizeP = static_cast(std::ceil(static_cast(maxBits) / auxBits)); + uint32_t sizeP = static_cast(std::ceil(static_cast(maxBits) / auxBits)); + if (GetScalingTechnique() == COMPOSITESCALINGAUTO || GetScalingTechnique() == COMPOSITESCALINGMANUAL) { + sizeP = UpdateSizeP(m_compositeDegree, sizeP); + } + uint64_t primeStep = FindAuxPrimeStep(); // Choose special primes in auxiliary basis and compute their roots @@ -387,7 +415,9 @@ uint64_t CryptoParametersRNS::FindAuxPrimeStep() const { std::pair CryptoParametersRNS::EstimateLogP(uint32_t numPartQ, double firstModulusSize, double dcrtBits, double extraModulusSize, - uint32_t numPrimes, uint32_t auxBits, bool addOne) { + uint32_t numPrimes, uint32_t auxBits, + ScalingTechnique scalTech, uint32_t compositeDegree, + bool addOne) { // numPartQ can not be zero as there is a division by numPartQ if (numPartQ == 0) OPENFHE_THROW("numPartQ is zero"); @@ -421,7 +451,8 @@ std::pair CryptoParametersRNS::EstimateLogP(uint32_t numPartQ, size_t endTower = ((j + 1) * numPerPartQ - 1 < sizeQ) ? (j + 1) * numPerPartQ - 1 : sizeQ - 1; // sum qi elements qi[startTower] + ... + qi[endTower] inclusive. the end element should be qi.begin()+(endTower+1) - uint32_t bits = static_cast(std::accumulate(qi.begin() + startTower, qi.begin() + (endTower + 1), 0.0)); + uint32_t bits = + static_cast(std::accumulate(qi.begin() + startTower, qi.begin() + (endTower + 1), 0.0)); if (bits > maxBits) maxBits = bits; } @@ -433,6 +464,9 @@ std::pair CryptoParametersRNS::EstimateLogP(uint32_t numPartQ, // Select number of primes in auxiliary CRT basis auto sizeP = static_cast(std::ceil(static_cast(maxBits) / auxBits)); + if (scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) { + sizeP = UpdateSizeP(compositeDegree, sizeP); + } return std::make_pair(sizeP * auxBits, sizeP); } diff --git a/src/pke/lib/schemerns/rns-leveledshe.cpp b/src/pke/lib/schemerns/rns-leveledshe.cpp index 85ed30e7b..9a78a83b9 100644 --- a/src/pke/lib/schemerns/rns-leveledshe.cpp +++ b/src/pke/lib/schemerns/rns-leveledshe.cpp @@ -34,6 +34,9 @@ #include "cryptocontext.h" #include "schemerns/rns-leveledshe.h" +#include +#include + namespace lbcrypto { ///////////////////////////////////////// @@ -212,7 +215,13 @@ Ciphertext LeveledSHERNS::EvalSquare(ConstCiphertext ciphert } auto c = ciphertext->Clone(); - ModReduceInternalInPlace(c, BASE_NUM_LEVELS_TO_DROP); + if (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO || + cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) { + ModReduceInternalInPlace(c, cryptoParams->GetCompositeDegree()); + } + else { + ModReduceInternalInPlace(c, BASE_NUM_LEVELS_TO_DROP); + } return EvalSquareCore(c); } @@ -222,7 +231,13 @@ Ciphertext LeveledSHERNS::EvalSquareMutable(Ciphertext& ciph if (cryptoParams->GetScalingTechnique() != NORESCALE && cryptoParams->GetScalingTechnique() != FIXEDMANUAL && ciphertext->GetNoiseScaleDeg() == 2) { - ModReduceInternalInPlace(ciphertext, BASE_NUM_LEVELS_TO_DROP); + if (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO || + cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) { + ModReduceInternalInPlace(ciphertext, cryptoParams->GetCompositeDegree()); + } + else { + ModReduceInternalInPlace(ciphertext, BASE_NUM_LEVELS_TO_DROP); + } } return EvalSquareCore(ciphertext); @@ -375,11 +390,26 @@ void LeveledSHERNS::LevelReduceInPlace(Ciphertext& ciphertext, const E // SHE LEVELED Compress ///////////////////////////////////////// +/* + * On COMPOSITESCALING technique, the number of towers to drop passed + * must be a multiple of composite degree. + */ Ciphertext LeveledSHERNS::Compress(ConstCiphertext ciphertext, size_t towersLeft) const { Ciphertext result = std::make_shared>(*ciphertext); + const auto cryptoParams = std::dynamic_pointer_cast(ciphertext->GetCryptoParameters()); + + usint levelsToDrop = BASE_NUM_LEVELS_TO_DROP; + if (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO || + cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) { + usint compositeDegree = cryptoParams->GetCompositeDegree(); + levelsToDrop = compositeDegree; + if (towersLeft % compositeDegree != 0) { + OPENFHE_THROW("Number of towers to drop must be a multiple of composite degree."); + } + } while (result->GetNoiseScaleDeg() > 1) { - ModReduceInternalInPlace(result, BASE_NUM_LEVELS_TO_DROP); + ModReduceInternalInPlace(result, levelsToDrop); } const std::vector& cv = result->GetElements(); usint sizeQl = cv[0].GetNumOfElements(); @@ -507,4 +537,20 @@ void LeveledSHERNS::AdjustForMultInPlace(Ciphertext& ciphertext1, Ciph } } +Ciphertext LeveledSHERNS::ComposedEvalMult(ConstCiphertext ciphertext1, + ConstCiphertext ciphertext2, + const EvalKey evalKey) const { + auto algo = ciphertext1->GetCryptoContext()->GetScheme(); + const auto cryptoParams = std::dynamic_pointer_cast(ciphertext1->GetCryptoParameters()); + Ciphertext ciphertext = EvalMult(ciphertext1, ciphertext2); + algo->KeySwitchInPlace(ciphertext, evalKey); + uint32_t levelsToDrop = BASE_NUM_LEVELS_TO_DROP; + if (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO || + cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) { + levelsToDrop = cryptoParams->GetCompositeDegree(); + } + ModReduceInPlace(ciphertext, levelsToDrop); + return ciphertext; +} + } // namespace lbcrypto