From fc5e6dd33f04671dd72c2963d4d2bedac5ee2246 Mon Sep 17 00:00:00 2001 From: ppppbamzy Date: Mon, 24 Oct 2022 15:29:28 +0800 Subject: [PATCH] [update] ckks example --- README.md | 17 ++++++++++------- examples/ckks_example.cpp | 13 +++++++++---- src/fhe/bgv/basics.cpp | 1 + src/fhe/primitives/keys.cpp | 1 + tests/bgv_t.cpp | 1 + tests/ckks_t.cpp | 1 + 6 files changed, 23 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 58b27f9..7a9e18f 100644 --- a/README.md +++ b/README.md @@ -43,20 +43,23 @@ int main() { int precision_bits = 30; auto params = ckks::create_params(4096, precision_bits); CkksSk sk(params); - auto relin_key = get_relin_key(sk); + auto relin_key = get_relin_key(sk, params.additional_mod); - auto pt = ckks::encode(1, params); - auto ct_sum = ckks::encrypt(pt, sk); - for (int i = 2; i <= 100000; i++) { + CkksCt ct_sum; + for (int i = 1; i <= 100000; i++) { auto pt = ckks::encode(1.0 / i, params); auto ct = ckks::encrypt(pt, sk); auto ct_squared = ckks::mult(ct, ct, relin_key); - ct_sum = ckks::add(ct_sum, ct_squared); + + if (i == 1) { + ct_sum = ct_squared; + } else { + ct_sum = ckks::add(ct_sum, ct_squared); + } } double sum = ckks::decode(ckks::decrypt(ct_sum, sk)); - std::cout << "(" << sum << ", " - << M_PI * M_PI / 6 << ")" << std::end; + std::cout << "(" << sum << ", " << M_PI * M_PI / 6 << ")" << std::endl; } ``` diff --git a/examples/ckks_example.cpp b/examples/ckks_example.cpp index ac76630..cb82e62 100644 --- a/examples/ckks_example.cpp +++ b/examples/ckks_example.cpp @@ -1,6 +1,7 @@ #include "ckks/ckks.h" #include #include +#include using namespace hehub; @@ -10,13 +11,17 @@ int main() { CkksSk sk(params); auto relin_key = get_relin_key(sk, params.additional_mod); - auto pt = ckks::encode(1, params); - auto ct_sum = ckks::encrypt(pt, sk); - for (int i = 2; i <= 100000; i++) { + CkksCt ct_sum; + for (int i = 1; i <= 100000; i++) { auto pt = ckks::encode(1.0 / i, params); auto ct = ckks::encrypt(pt, sk); auto ct_squared = ckks::mult(ct, ct, relin_key); - ct_sum = ckks::add(ct_sum, ct_squared); + + if (i == 1) { + ct_sum = ct_squared; + } else { + ct_sum = ckks::add(ct_sum, ct_squared); + } } double sum = ckks::decode(ckks::decrypt(ct_sum, sk)); diff --git a/src/fhe/bgv/basics.cpp b/src/fhe/bgv/basics.cpp index bd2873f..71b7218 100644 --- a/src/fhe/bgv/basics.cpp +++ b/src/fhe/bgv/basics.cpp @@ -1,5 +1,6 @@ #include "bgv.h" #include "common/mod_arith.h" +#include "common/ntt.h" #include "common/rns_transform.h" #include #include diff --git a/src/fhe/primitives/keys.cpp b/src/fhe/primitives/keys.cpp index 6ff4307..f2c752d 100644 --- a/src/fhe/primitives/keys.cpp +++ b/src/fhe/primitives/keys.cpp @@ -1,5 +1,6 @@ #include "keys.h" #include "common/mod_arith.h" +#include "common/ntt.h" #include "common/rns_transform.h" namespace hehub { diff --git a/tests/bgv_t.cpp b/tests/bgv_t.cpp index 2fb3af6..e0dbc32 100644 --- a/tests/bgv_t.cpp +++ b/tests/bgv_t.cpp @@ -1,5 +1,6 @@ #include "catch2/catch.hpp" #include "common/mod_arith.h" +#include "common/ntt.h" #include "common/sampling.h" #include "bgv/bgv.h" #include diff --git a/tests/ckks_t.cpp b/tests/ckks_t.cpp index 0b5caee..18ea5cf 100644 --- a/tests/ckks_t.cpp +++ b/tests/ckks_t.cpp @@ -2,6 +2,7 @@ #include "ckks/ckks.h" #include "common/bigint.h" #include "common/mod_arith.h" +#include "common/ntt.h" #include "common/permutation.h" #include "common/sampling.h" #include