From 364286eee13dcdb1e084a3e15c08cba72073bc22 Mon Sep 17 00:00:00 2001 From: Petar Ivanov <29689712+dartdart26@users.noreply.github.com> Date: Fri, 30 Jun 2023 14:53:19 +0200 Subject: [PATCH] Handle errors from tfhe-rs (#135) Overall strategy is: * we assume operations on well-formed ciphertexts don't fail * for operations we assume shouldn't fail, we use asserts * for operations that we assume can fail, we use return codes Essentially, allowing malformed ciphertexts from txns means these ops could fail at some point in the lifetime of a cihpertext and/or on ciphertexts that are produced from the original one: * deserialization or ciphertexts * decrypt * any FHE operation * serialization We assume the following ops don't fail if all inputs are well-formed: * deser of FHE keys * encryption We also assume that tfhe-rs failures are always deterministic. That allows us to not stop the node on such a failure and assume that all nodes have the same behaviour, leaving nodes in sync. Also, do not use Go finalizers anymore. Instead, ser/deser ciphertexts across the C/Go boundary. That avoids complications with finalizers and memory management. However, it has a performance overhead and we need to be extra careful that we free all C memory manually. --- core/vm/contracts.go | 270 ++---- core/vm/contracts_test.go | 152 ++-- core/vm/instructions.go | 15 +- core/vm/interpreter.go | 2 +- core/vm/tfhe.go | 1708 +++++++++++++++++++------------------ core/vm/tfhe_test.go | 170 ++-- 6 files changed, 1157 insertions(+), 1160 deletions(-) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 9eccabe1edaa..a941f25fc0a0 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -1217,11 +1217,6 @@ type tomlConfigOptions struct { OracleDBAddress string RequireRetryCount uint8 } - - Tfhe struct { - CiphertextsToGarbageCollect uint64 - CiphertextsGarbageCollectIntervalSecs uint64 - } } var tomlConfig tomlConfigOptions @@ -1270,10 +1265,12 @@ func init() { f, err := os.Open(home + "/.evmosd/zama/config/zama_config.toml") if err != nil { + fmt.Println("failed to open zama_config.toml file") return } defer f.Close() if err := toml.NewDecoder(f).Decode(&tomlConfig); err != nil { + fmt.Println("failed to parse zama_config.toml file: " + err.Error()) return } @@ -1281,16 +1278,19 @@ func init() { case "oracle": priv, err := os.ReadFile(home + "/.evmosd/zama/keys/signature-keys/private.ed25519") if err != nil { + fmt.Println("failed to read private.ed25519 file: " + err.Error()) return } privateSignatureKey = priv case "node": pub, err := os.ReadFile(home + "/.evmosd/zama/keys/signature-keys/public.ed25519") if err != nil { + fmt.Println("failed to read public.ed25519 file: " + err.Error()) return } publicSignatureKey = pub default: + fmt.Println("invalid oracle mode: " + mode) panic(fmt.Sprintf("invalid oracle mode: %s", mode)) } } @@ -1456,6 +1456,10 @@ var fheTrivialEncryptGasCosts = map[fheUintType]uint64{ FheUint32: params.FheUint32TrivialEncryptGas, } +func writeResult(ct *tfheCiphertext, fileName string, logger Logger) { + os.WriteFile("/tmp/"+fileName, ct.serialize(), 0644) +} + type fheAdd struct{} func (e *fheAdd) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { @@ -1520,11 +1524,7 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/add_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheAdd failed to write /tmp/add_result", "err", err) - return nil, err - } + writeResult(result, "add_result", logger) resultHash := result.getHash() logger.Info("fheAdd success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -1550,11 +1550,7 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/add_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheAdd scalar failed to write /tmp/add_result", "err", err) - return nil, err - } + writeResult(result, "add_scalar_result", logger) resultHash := result.getHash() logger.Info("fheAdd scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -1577,10 +1573,8 @@ func encryptToUserKey(value *big.Int, pubKey []byte) ([]byte, error) { } // TODO: for testing - err = os.WriteFile("/tmp/public_encrypt_result", ct, 0644) - if err != nil { - return nil, err - } + // Ignore file writing errors. + os.WriteFile("/tmp/public_encrypt_result", ct, 0644) return ct, nil } @@ -1611,11 +1605,21 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller } ctBytes := input[:len(input)-1] - ctType := fheUintType(input[len(input)-1]) - if !ctType.isValid() { - logger.Error("invalid type to cast to") - return nil, errors.New("invalid type provided") + ctTypeByte := input[len(input)-1] + if !isValidType(ctTypeByte) { + msg := "verifyCiphertext Run() ciphertext type is invalid" + logger.Error(msg, "type", ctTypeByte) + return nil, errors.New(msg) + } + ctType := fheUintType(ctTypeByte) + + expectedSize, found := compactFheCiphertextSize[ctType] + if !found || expectedSize != uint(len(ctBytes)) { + msg := "verifyCiphertext Run() compact ciphertext size is invalid" + logger.Error(msg, "type", ctTypeByte, "size", len(ctBytes), "expectedSize", expectedSize) + return nil, errors.New(msg) } + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { return importRandomCiphertext(accessibleState, ctType), nil @@ -1682,7 +1686,11 @@ func (e *reencrypt) Run(accessibleState PrecompileAccessibleState, caller common } ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32])) if ct != nil { - decryptedValue := ct.ciphertext.decrypt() + decryptedValue, err := ct.ciphertext.decrypt() + if err != nil { + logger.Error("reencrypt decryption failed", "err", err) + return nil, err + } pubKey := input[32:64] reencryptedValue, err := encryptToUserKey(&decryptedValue, pubKey) if err != nil { @@ -1729,16 +1737,20 @@ func requireURL(key *string) string { // Puts the given ciphertext as a require to the oracle DB or exits the process on errors. // Returns the require value. -func putRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) bool { +func putRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) (bool, error) { logger := interpreter.evm.Logger ciphertext := ct.serialize() - plaintext := ct.decrypt() + plaintext, err := ct.decrypt() + if err != nil { + logger.Error("putRequire decryption failed", "err", err) + return false, err + } value := (plaintext.BitLen() != 0) key := requireKey(ciphertext) j, err := json.Marshal(requireMessage{value, signRequire(ciphertext, value)}) if err != nil { logger.Error("putRequire JSON Marshal() failed, exiting process", "err", err, "key", key) - exitProcess() + return false, err } for try := uint8(1); try <= tomlConfig.Oracle.RequireRetryCount+1; try++ { req, err := http.NewRequest(http.MethodPut, requireURL(&key), bytes.NewReader(j)) @@ -1758,17 +1770,17 @@ func putRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) bool { continue } logger.Info("putRequire sucess", "value", value, "key", key) - return value + return value, nil } logger.Error("putRequire reached maximum retries, exiting process", "retries", tomlConfig.Oracle.RequireRetryCount, "key", key) exitProcess() - return value + return value, nil } // Gets the given require from the oracle DB and returns its value. // Exits the process on errors or signature verification failure. -func getRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) bool { +func getRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) (bool, error) { logger := interpreter.evm.Logger ciphertext := ct.serialize() key := requireKey(ciphertext) @@ -1776,7 +1788,7 @@ func getRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) bool { req, err := http.NewRequest(http.MethodGet, requireURL(&key), http.NoBody) if err != nil { logger.Error("getRequire NewRequest() failed, retrying", "err", err) - continue + return false, err } resp, err := requireHttpClient.Do(req) if err != nil { @@ -1805,14 +1817,14 @@ func getRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) bool { continue } logger.Info("getRequire success", "value", msg.Value, "key", key) - return msg.Value + return msg.Value, nil } logger.Error("getRequire reached maximum retries, exiting process", "retries", tomlConfig.Oracle.RequireRetryCount) exitProcess() - return false + return false, nil } -func evaluateRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) bool { +func evaluateRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) (bool, error) { mode := strings.ToLower(tomlConfig.Oracle.Mode) switch mode { case "oracle": @@ -1822,7 +1834,7 @@ func evaluateRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) bool { } interpreter.evm.Logger.Error("evaluateRequire invalid mode", "mode", mode) exitProcess() - return false + return false, nil } func (e *require) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { @@ -1847,8 +1859,12 @@ func (e *require) Run(accessibleState PrecompileAccessibleState, caller common.A if !accessibleState.Interpreter().evm.Commit { return nil, nil } - if !evaluateRequire(ct.ciphertext, accessibleState.Interpreter()) { - accessibleState.Interpreter().evm.Logger.Error("require failed to evaluate, reverting") + value, err := evaluateRequire(ct.ciphertext, accessibleState.Interpreter()) + if err != nil { + accessibleState.Interpreter().evm.Logger.Error("require failed to evaluate, reverting", "err", err) + return nil, ErrExecutionReverted + } else if !value { + accessibleState.Interpreter().evm.Logger.Error("require value is false, reverting") return nil, ErrExecutionReverted } return nil, nil @@ -1988,11 +2004,7 @@ func (e *fheLe) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/le_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheLe failed to write /tmp/le_result", "err", err) - return nil, err - } + writeResult(result, "le_result", logger) resultHash := result.getHash() logger.Info("fheLe success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2018,11 +2030,7 @@ func (e *fheLe) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/le_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheLe scalar failed to write /tmp/le_result", "err", err) - return nil, err - } + writeResult(result, "le_scalar_result", logger) resultHash := result.getHash() logger.Info("fheLe scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -2072,11 +2080,7 @@ func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/sub_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheSub failed to write /tmp/sub_result", "err", err) - return nil, err - } + writeResult(result, "sub_result", logger) resultHash := result.getHash() logger.Info("fheSub success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2102,11 +2106,7 @@ func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/sub_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheSub scalar failed to write /tmp/sub_result", "err", err) - return nil, err - } + writeResult(result, "sub_scalar_result", logger) resultHash := result.getHash() logger.Info("fheSub scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -2178,11 +2178,7 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/mul_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheMul failed to write /tmp/mul_result", "err", err) - return nil, err - } + writeResult(result, "mul_result", logger) resultHash := result.getHash() logger.Info("fheMul success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2208,11 +2204,7 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/mul_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheMul scalar failed to write /tmp/mul_result", "err", err) - return nil, err - } + writeResult(result, "mul_scalar_result", logger) resultHash := result.getHash() logger.Info("fheMul scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -2289,12 +2281,7 @@ func (e *fheBitAnd) Run(accessibleState PrecompileAccessibleState, caller common importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/bitand_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheBitAnd failed to write /tmp/bitand_result", "err", err) - return nil, err - } - + writeResult(result, "bitand_result", logger) resultHash := result.getHash() logger.Info("fheBitAnd success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) return resultHash[:], nil @@ -2348,11 +2335,7 @@ func (e *fheBitOr) Run(accessibleState PrecompileAccessibleState, caller common. importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/bitor_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheBitOr failed to write /tmp/bitor_result", "err", err) - return nil, err - } + writeResult(result, "bitor_result", logger) resultHash := result.getHash() logger.Info("fheBitOr success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2407,11 +2390,7 @@ func (e *fheBitXor) Run(accessibleState PrecompileAccessibleState, caller common importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/bitxor_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheBitXor failed to write /tmp/bitxor_result", "err", err) - return nil, err - } + writeResult(result, "bitxor_result", logger) resultHash := result.getHash() logger.Info("fheBitXor success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2482,11 +2461,7 @@ func (e *fheShl) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/shl_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheShl failed to write /tmp/shl_result", "err", err) - return nil, err - } + writeResult(result, "shl_result", logger) resultHash := result.getHash() logger.Info("fheShl success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2512,11 +2487,7 @@ func (e *fheShl) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/shl_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheShl scalar failed to write /tmp/shl_result", "err", err) - return nil, err - } + writeResult(result, "shl_scalar_result", logger) resultHash := result.getHash() logger.Info("fheShl scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -2566,11 +2537,7 @@ func (e *fheShr) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/shr_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheShr failed to write /tmp/shr_result", "err", err) - return nil, err - } + writeResult(result, "shr_result", logger) resultHash := result.getHash() logger.Info("fheShr success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2596,11 +2563,7 @@ func (e *fheShr) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/shr_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheShr scalar failed to write /tmp/shr_result", "err", err) - return nil, err - } + writeResult(result, "shr_scalar_result", logger) resultHash := result.getHash() logger.Info("fheShr scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -2650,11 +2613,7 @@ func (e *fheEq) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/eq_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheEq failed to write /tmp/eq_result", "err", err) - return nil, err - } + writeResult(result, "eq_result", logger) resultHash := result.getHash() logger.Info("fheEq success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2680,11 +2639,7 @@ func (e *fheEq) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/eq_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheEq scalar failed to write /tmp/eq_result", "err", err) - return nil, err - } + writeResult(result, "eq_scalar_result", logger) resultHash := result.getHash() logger.Info("fheEq scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -2734,11 +2689,7 @@ func (e *fheNe) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/ne_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheNe failed to write /tmp/ne_result", "err", err) - return nil, err - } + writeResult(result, "ne_result", logger) resultHash := result.getHash() logger.Info("fheNe success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2764,11 +2715,7 @@ func (e *fheNe) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/ne_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheNe scalar failed to write /tmp/ne_result", "err", err) - return nil, err - } + writeResult(result, "ne_scalar_result", logger) resultHash := result.getHash() logger.Info("fheNe scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -2818,11 +2765,7 @@ func (e *fheGe) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/ge_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheGe failed to write /tmp/ge_result", "err", err) - return nil, err - } + writeResult(result, "ge_result", logger) resultHash := result.getHash() logger.Info("fheGe success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2848,11 +2791,7 @@ func (e *fheGe) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/ge_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheGe scalar failed to write /tmp/ge_result", "err", err) - return nil, err - } + writeResult(result, "ge_scalar_result", logger) resultHash := result.getHash() logger.Info("fheGe scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -2902,11 +2841,7 @@ func (e *fheGt) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/gt_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheGt failed to write /tmp/gt_result", "err", err) - return nil, err - } + writeResult(result, "gt_result", logger) resultHash := result.getHash() logger.Info("fheGt success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -2932,11 +2867,7 @@ func (e *fheGt) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/gt_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheGt scalar failed to write /tmp/gt_result", "err", err) - return nil, err - } + writeResult(result, "gt_scalar_result", logger) resultHash := result.getHash() logger.Info("fheGt scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -2986,11 +2917,7 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/lt_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheLt failed to write /tmp/lt_result", "err", err) - return nil, err - } + writeResult(result, "lt_result", logger) resultHash := result.getHash() logger.Info("fheLt success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -3016,11 +2943,7 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/lt_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheLt scalar failed to write /tmp/lt_result", "err", err) - return nil, err - } + writeResult(result, "lt_scalar_result", logger) resultHash := result.getHash() logger.Info("fheLt scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -3092,11 +3015,7 @@ func (e *fheMin) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/min_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheMin failed to write /tmp/min_result", "err", err) - return nil, err - } + writeResult(result, "min_result", logger) resultHash := result.getHash() logger.Info("fheMin success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -3122,11 +3041,7 @@ func (e *fheMin) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/min_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheMin scalar failed to write /tmp/min_result", "err", err) - return nil, err - } + writeResult(result, "min_scalar_result", logger) resultHash := result.getHash() logger.Info("fheMin scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -3176,11 +3091,7 @@ func (e *fheMax) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/max_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheMax failed to write /tmp/max_result", "err", err) - return nil, err - } + writeResult(result, "max_result", logger) resultHash := result.getHash() logger.Info("fheMax success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -3206,11 +3117,7 @@ func (e *fheMax) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/max_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheMax scalar failed to write /tmp/max_result", "err", err) - return nil, err - } + writeResult(result, "max_scalar_result", logger) resultHash := result.getHash() logger.Info("fheMax scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) @@ -3264,12 +3171,7 @@ func (e *fheNeg) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/neg_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheNeg failed to write /tmp/neg_result", "err", err) - return nil, err - } - + writeResult(result, "neg_result", logger) resultHash := result.getHash() logger.Info("fheNeg success", "ct", ct.ciphertext.getHash().Hex(), "result", resultHash.Hex()) return resultHash[:], nil @@ -3313,11 +3215,7 @@ func (e *fheNot) Run(accessibleState PrecompileAccessibleState, caller common.Ad importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/not_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheNot failed to write /tmp/not_result", "err", err) - return nil, err - } + writeResult(result, "not_result", logger) resultHash := result.getHash() logger.Info("fheNot success", "ct", ct.ciphertext.getHash().Hex(), "result", resultHash.Hex()) @@ -3429,11 +3327,11 @@ func (e *cast) Run(accessibleState PrecompileAccessibleState, caller common.Addr return nil, errors.New("unverified ciphertext handle") } - castToType := fheUintType(input[32]) - if !castToType.isValid() { + if !isValidType(input[32]) { logger.Error("invalid type to cast to") return nil, errors.New("invalid type provided") } + castToType := fheUintType(input[32]) res, err := ct.ciphertext.castTo(castToType) if err != nil { diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index 83793577bd59..3c8ebad2a51d 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -597,8 +597,8 @@ func FheAdd(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -639,8 +639,8 @@ func FheSub(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -681,8 +681,8 @@ func FheMul(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -728,8 +728,8 @@ func FheBitAnd(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -776,8 +776,8 @@ func FheBitOr(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -824,8 +824,8 @@ func FheBitXor(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -867,8 +867,8 @@ func FheShl(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -909,8 +909,8 @@ func FheShr(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -951,8 +951,8 @@ func FheEq(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != 0 { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) } } @@ -993,8 +993,8 @@ func FheNe(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != 1 { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) } } @@ -1035,8 +1035,8 @@ func FheGe(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != 1 { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) } // Inverting operands is only possible in the non scalar case as scalar @@ -1052,8 +1052,8 @@ func FheGe(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() - if decrypted.Uint64() != 0 { + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) } } @@ -1096,8 +1096,8 @@ func FheGt(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != 1 { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) } @@ -1114,8 +1114,8 @@ func FheGt(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() - if decrypted.Uint64() != 0 { + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) } } @@ -1158,8 +1158,8 @@ func FheLe(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != 0 { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) } @@ -1176,8 +1176,8 @@ func FheLe(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() - if decrypted.Uint64() != 1 { + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) } } @@ -1221,8 +1221,8 @@ func FheLt(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != 0 { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) } @@ -1239,8 +1239,8 @@ func FheLt(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() - if decrypted.Uint64() != 1 { + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) } } @@ -1283,8 +1283,8 @@ func FheMin(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != rhs { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != rhs { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), rhs) } @@ -1300,8 +1300,8 @@ func FheMin(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() - if decrypted.Uint64() != rhs { + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != rhs { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), rhs) } } @@ -1344,8 +1344,8 @@ func FheMax(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != lhs { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != lhs { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), lhs) } @@ -1361,8 +1361,8 @@ func FheMax(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() - if decrypted.Uint64() != lhs { + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != lhs { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), lhs) } } @@ -1400,8 +1400,8 @@ func FheNeg(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -1438,12 +1438,44 @@ func FheNot(t *testing.T, fheUintType fheUintType, scalar bool) { if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } +func TestVerifyCiphertextInvalidType(t *testing.T) { + c := &verifyCiphertext{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + invalidType := fheUintType(255) + compact := encryptAndSerializeCompact(0, FheUint32) + input := append(compact, byte(invalidType)) + _, err := c.Run(state, addr, addr, input, readOnly) + if err == nil { + t.Fatalf("verifyCiphertext must have failed on invalid ciphertext type") + } +} + +func TestVerifyCiphertextInvalidSize(t *testing.T) { + c := &verifyCiphertext{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + ctType := FheUint32 + compact := encryptAndSerializeCompact(0, ctType) + input := append(compact[:len(compact)-1], byte(ctType)) + _, err := c.Run(state, addr, addr, input, readOnly) + if err == nil { + t.Fatalf("verifyCiphertext must have failed on invalid ciphertext size") + } +} + func TestVerifyCiphertext8(t *testing.T) { VerifyCiphertext(t, FheUint8) } @@ -1468,20 +1500,20 @@ func TestTrivialEncrypt32(t *testing.T) { TrivialEncrypt(t, FheUint32) } -// func TestVerifyCiphertext8BadType(t *testing.T) { -// VerifyCiphertextBadType(t, FheUint8, FheUint16) -// VerifyCiphertextBadType(t, FheUint8, FheUint32) -// } +func TestVerifyCiphertext8BadType(t *testing.T) { + VerifyCiphertextBadType(t, FheUint8, FheUint16) + VerifyCiphertextBadType(t, FheUint8, FheUint32) +} -// func TestVerifyCiphertext16BadType(t *testing.T) { -// VerifyCiphertextBadType(t, FheUint16, FheUint8) -// VerifyCiphertextBadType(t, FheUint16, FheUint32) -// } +func TestVerifyCiphertext16BadType(t *testing.T) { + VerifyCiphertextBadType(t, FheUint16, FheUint8) + VerifyCiphertextBadType(t, FheUint16, FheUint32) +} -// func TestVerifyCiphertext32BadType(t *testing.T) { -// VerifyCiphertextBadType(t, FheUint32, FheUint8) -// VerifyCiphertextBadType(t, FheUint32, FheUint16) -// } +func TestVerifyCiphertext32BadType(t *testing.T) { + VerifyCiphertextBadType(t, FheUint32, FheUint8) + VerifyCiphertextBadType(t, FheUint32, FheUint16) +} func TestVerifyCiphertextBadCiphertext(t *testing.T) { c := &verifyCiphertext{} diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 1516b9aa792e..fa9bb3b6acc3 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -18,6 +18,7 @@ package vm import ( "encoding/hex" + "errors" "sync/atomic" "github.com/ethereum/go-ethereum/common" @@ -559,7 +560,7 @@ func newInt(buf []byte) *uint256.Int { var zero = uint256.NewInt(0).Bytes32() -func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, contractAddress common.Address) { +func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, contractAddress common.Address) error { ct, ok := interpreter.verifiedCiphertexts[val] if ok { // If already existing in memory, skip storage and import the same ciphertext at the current depth. @@ -568,7 +569,7 @@ func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, cont // However, ciphertexts remain in memory for the duration of the call, allowing for this lookup to find it. // Note that even if a ciphertext has an empty verification depth set, it still remains in memory. importCiphertextToEVM(interpreter, ct.ciphertext) - return + return nil } protectedStorage := crypto.CreateProtectedStorageContractAddress(contractAddress) @@ -593,18 +594,22 @@ func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, cont ct := new(tfheCiphertext) err := ct.deserialize(ctBytes, metadata.fheUintType) if err != nil { - interpreter.evm.Logger.Error("opSload failed to deserialize a ciphertext, exiting process", "err", err) - exitProcess() + msg := "opSload failed to deserialize a ciphertext" + interpreter.evm.Logger.Error(msg, "err", err) + return errors.New(msg) } importCiphertextToEVM(interpreter, ct) } + return nil } func opSload(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { loc := scope.Stack.peek() hash := common.Hash(loc.Bytes32()) val := interpreter.evm.StateDB.GetState(scope.Contract.Address(), hash) - verifyIfCiphertextHandle(val, interpreter, scope.Contract.Address()) + if err := verifyIfCiphertextHandle(val, interpreter, scope.Contract.Address()); err != nil { + return nil, err + } loc.SetBytes(val.Bytes()) return nil, nil } diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index aef5e57867e8..e2acda0f1aa9 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -295,7 +295,7 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( // is an optimistic require, check its decrypted value. If false, return as if // execution is to be reverted. if in.evm.depth == 1 && in.optimisticRequire != nil { - if !evaluateRequire(in.optimisticRequire, in) { + if value, evalError := evaluateRequire(in.optimisticRequire, in); evalError != nil || !value { err = ErrExecutionReverted } } diff --git a/core/vm/tfhe.go b/core/vm/tfhe.go index 888106fb650f..3ca8e7225b68 100644 --- a/core/vm/tfhe.go +++ b/core/vm/tfhe.go @@ -51,9 +51,8 @@ void checked_set_server_key(void *sks) { assert(r == 0); } -void serialize_fhe_uint8(void *ct, Buffer* out) { - const int r = fhe_uint8_serialize(ct, out); - assert(r == 0); +int serialize_fhe_uint8(void *ct, Buffer* out) { + return fhe_uint8_serialize(ct, out); } void* deserialize_fhe_uint8(BufferView in) { @@ -90,9 +89,8 @@ void* deserialize_compact_fhe_uint8(BufferView in) { return ct; } -void serialize_fhe_uint16(void *ct, Buffer* out) { - const int r = fhe_uint16_serialize(ct, out); - assert(r == 0); +int serialize_fhe_uint16(void *ct, Buffer* out) { + return fhe_uint16_serialize(ct, out); } void* deserialize_fhe_uint16(BufferView in) { @@ -129,9 +127,8 @@ void* deserialize_compact_fhe_uint16(BufferView in) { return ct; } -void serialize_fhe_uint32(void *ct, Buffer* out) { - const int r = fhe_uint32_serialize(ct, out); - assert(r == 0); +int serialize_fhe_uint32(void *ct, Buffer* out) { + return fhe_uint32_serialize(ct, out); } void* deserialize_fhe_uint32(BufferView in) { @@ -169,15 +166,18 @@ void* deserialize_compact_fhe_uint32(BufferView in) { } void destroy_fhe_uint8(void* ct) { - fhe_uint8_destroy(ct); + const int r = fhe_uint8_destroy(ct); + assert(r == 0); } void destroy_fhe_uint16(void* ct) { - fhe_uint16_destroy(ct); + const int r = fhe_uint16_destroy(ct); + assert(r == 0); } void destroy_fhe_uint32(void* ct) { - fhe_uint32_destroy(ct); + const int r = fhe_uint32_destroy(ct); + assert(r == 0); } void* add_fhe_uint8(void* ct1, void* ct2, void* sks) @@ -187,7 +187,7 @@ void* add_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_add(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -198,7 +198,7 @@ void* add_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_add(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -209,7 +209,7 @@ void* add_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_add(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -220,7 +220,7 @@ void* scalar_add_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_add(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -231,7 +231,7 @@ void* scalar_add_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_add(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -242,7 +242,7 @@ void* scalar_add_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_add(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -253,7 +253,7 @@ void* sub_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_sub(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -264,7 +264,7 @@ void* sub_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_sub(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -275,7 +275,7 @@ void* sub_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_sub(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -286,7 +286,7 @@ void* scalar_sub_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_sub(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -297,7 +297,7 @@ void* scalar_sub_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_sub(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -308,7 +308,7 @@ void* scalar_sub_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_sub(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -319,7 +319,7 @@ void* mul_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_mul(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -330,7 +330,7 @@ void* mul_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_mul(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -341,7 +341,7 @@ void* mul_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_mul(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -352,7 +352,7 @@ void* scalar_mul_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_mul(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -363,7 +363,7 @@ void* scalar_mul_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_mul(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -374,7 +374,7 @@ void* scalar_mul_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_mul(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -385,7 +385,7 @@ void* bitand_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_bitand(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -396,7 +396,7 @@ void* bitand_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_bitand(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -407,7 +407,7 @@ void* bitand_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_bitand(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -418,7 +418,7 @@ void* bitor_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_bitor(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -429,7 +429,7 @@ void* bitor_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_bitor(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -440,7 +440,7 @@ void* bitor_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_bitor(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -451,7 +451,7 @@ void* bitxor_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_bitxor(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -462,7 +462,7 @@ void* bitxor_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_bitxor(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -473,7 +473,7 @@ void* bitxor_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_bitxor(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -484,7 +484,7 @@ void* shl_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_shl(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -495,7 +495,7 @@ void* shl_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_shl(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -506,7 +506,7 @@ void* shl_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_shl(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -517,7 +517,7 @@ void* scalar_shl_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_shl(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -528,7 +528,7 @@ void* scalar_shl_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_shl(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -539,7 +539,7 @@ void* scalar_shl_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_shl(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -550,7 +550,7 @@ void* shr_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_shr(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -561,7 +561,7 @@ void* shr_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_shr(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -572,7 +572,7 @@ void* shr_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_shr(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -583,7 +583,7 @@ void* scalar_shr_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_shr(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -594,7 +594,7 @@ void* scalar_shr_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_shr(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -605,7 +605,7 @@ void* scalar_shr_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_shr(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -616,7 +616,7 @@ void* eq_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_eq(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -627,7 +627,7 @@ void* eq_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_eq(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -638,7 +638,7 @@ void* eq_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_eq(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -649,7 +649,7 @@ void* scalar_eq_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_eq(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -660,7 +660,7 @@ void* scalar_eq_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_eq(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -671,7 +671,7 @@ void* scalar_eq_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_eq(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -682,7 +682,7 @@ void* ne_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_ne(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -693,7 +693,7 @@ void* ne_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_ne(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -704,7 +704,7 @@ void* ne_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_ne(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -715,7 +715,7 @@ void* scalar_ne_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_ne(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -726,7 +726,7 @@ void* scalar_ne_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_ne(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -737,7 +737,7 @@ void* scalar_ne_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_ne(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -748,7 +748,7 @@ void* ge_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_ge(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -759,7 +759,7 @@ void* ge_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_ge(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -770,7 +770,7 @@ void* ge_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_ge(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -781,7 +781,7 @@ void* scalar_ge_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_ge(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -792,7 +792,7 @@ void* scalar_ge_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_ge(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -803,7 +803,7 @@ void* scalar_ge_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_ge(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -814,7 +814,7 @@ void* gt_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_gt(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -825,7 +825,7 @@ void* gt_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_gt(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -836,7 +836,7 @@ void* gt_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_gt(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -847,7 +847,7 @@ void* scalar_gt_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_gt(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -858,7 +858,7 @@ void* scalar_gt_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_gt(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -869,7 +869,7 @@ void* scalar_gt_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_gt(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -880,7 +880,7 @@ void* le_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_le(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -891,7 +891,7 @@ void* le_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_le(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -902,7 +902,7 @@ void* le_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_le(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -913,7 +913,7 @@ void* lt_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_lt(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -924,7 +924,7 @@ void* scalar_le_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_le(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -935,7 +935,7 @@ void* scalar_le_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_le(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -946,7 +946,7 @@ void* scalar_le_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_le(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -957,7 +957,7 @@ void* lt_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_lt(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -968,7 +968,7 @@ void* lt_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_lt(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -979,7 +979,7 @@ void* scalar_lt_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_lt(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -990,7 +990,7 @@ void* scalar_lt_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_lt(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1001,7 +1001,7 @@ void* scalar_lt_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_lt(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1012,7 +1012,7 @@ void* min_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_min(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1023,7 +1023,7 @@ void* min_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_min(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1034,7 +1034,7 @@ void* min_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_min(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1045,7 +1045,7 @@ void* scalar_min_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_min(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1056,7 +1056,7 @@ void* scalar_min_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_min(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1067,7 +1067,7 @@ void* scalar_min_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_min(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1078,7 +1078,7 @@ void* max_fhe_uint8(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_max(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1089,7 +1089,7 @@ void* max_fhe_uint16(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_max(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1100,7 +1100,7 @@ void* max_fhe_uint32(void* ct1, void* ct2, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_max(ct1, ct2, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1111,7 +1111,7 @@ void* scalar_max_fhe_uint8(void* ct, uint8_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint8_scalar_max(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1122,7 +1122,7 @@ void* scalar_max_fhe_uint16(void* ct, uint16_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint16_scalar_max(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1133,7 +1133,7 @@ void* scalar_max_fhe_uint32(void* ct, uint32_t pt, void* sks) checked_set_server_key(sks); const int r = fhe_uint32_scalar_max(ct, pt, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1143,7 +1143,7 @@ void* neg_fhe_uint8(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint8_neg(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1153,7 +1153,7 @@ void* neg_fhe_uint16(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint16_neg(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1163,7 +1163,7 @@ void* neg_fhe_uint32(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint32_neg(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1173,7 +1173,7 @@ void* not_fhe_uint8(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint8_not(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1183,7 +1183,7 @@ void* not_fhe_uint16(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint16_not(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1193,32 +1193,26 @@ void* not_fhe_uint32(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint32_not(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } -uint8_t decrypt_fhe_uint8(void* cks, void* ct) +int decrypt_fhe_uint8(void* cks, void* ct, uint8_t* res) { - uint8_t res = 0; - const int r = fhe_uint8_decrypt(ct, cks, &res); - assert(r == 0); - return res; + *res = 0; + return fhe_uint8_decrypt(ct, cks, res); } -uint16_t decrypt_fhe_uint16(void* cks, void* ct) +int decrypt_fhe_uint16(void* cks, void* ct, uint16_t* res) { - uint16_t res = 0; - const int r = fhe_uint16_decrypt(ct, cks, &res); - assert(r == 0); - return res; + *res = 0; + return fhe_uint16_decrypt(ct, cks, res); } -uint32_t decrypt_fhe_uint32(void* cks, void* ct) +int decrypt_fhe_uint32(void* cks, void* ct, uint32_t* res) { - uint32_t res = 0; - const int r = fhe_uint32_decrypt(ct, cks, &res); - assert(r == 0); - return res; + *res = 0; + return fhe_uint32_decrypt(ct, cks, res); } void* public_key_encrypt_fhe_uint8(void* pks, uint8_t value) { @@ -1338,7 +1332,7 @@ void* cast_8_16(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint8_cast_into_fhe_uint16(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1348,7 +1342,7 @@ void* cast_8_32(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint8_cast_into_fhe_uint32(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1358,7 +1352,7 @@ void* cast_16_8(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint16_cast_into_fhe_uint8(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1368,7 +1362,7 @@ void* cast_16_32(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint16_cast_into_fhe_uint32(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1378,7 +1372,7 @@ void* cast_32_8(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint32_cast_into_fhe_uint8(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1388,7 +1382,7 @@ void* cast_32_16(void* ct, void* sks) { checked_set_server_key(sks); const int r = fhe_uint32_cast_into_fhe_uint16(ct, &result); - assert(r == 0); + if(r != 0) return NULL; return result; } @@ -1400,9 +1394,6 @@ import ( "fmt" "math/big" "os" - "runtime" - "sync/atomic" - "time" "unsafe" "github.com/ethereum/go-ethereum/common" @@ -1424,10 +1415,12 @@ func homeDir() string { return home } -// TFHE ciphertext sizes by type, in bytes. -// Note: These sizes are for expanded (non-compacted) ciphertexts. +// Expanded TFHE ciphertext sizes by type, in bytes. var expandedFheCiphertextSize map[fheUintType]uint +// Compact TFHE ciphertext sizes by type, in bytes. +var compactFheCiphertextSize map[fheUintType]uint + var sks unsafe.Pointer var cks unsafe.Pointer var pks unsafe.Pointer @@ -1436,23 +1429,9 @@ var pksHash common.Hash var networkKeysDir string var usersKeysDir string -var allocatedCiphertexts uint64 - -// TODO: We assume that contracts.go's init() runs before the init() in this file, -// making the TOML configuration available here. -func runGc() { - for range time.Tick(time.Duration(tomlConfig.Tfhe.CiphertextsGarbageCollectIntervalSecs) * time.Second) { - if atomic.LoadUint64(&allocatedCiphertexts) >= tomlConfig.Tfhe.CiphertextsToGarbageCollect { - atomic.StoreUint64(&allocatedCiphertexts, 0) - runtime.GC() - } - } -} - func init() { expandedFheCiphertextSize = make(map[fheUintType]uint) - - go runGc() + compactFheCiphertextSize = make(map[fheUintType]uint) home := homeDir() networkKeysDir = home + "/.evmosd/zama/keys/network-fhe-keys/" @@ -1469,13 +1448,6 @@ func init() { expandedFheCiphertextSize[FheUint16] = uint(len(new(tfheCiphertext).trivialEncrypt(*big.NewInt(0), FheUint16).serialize())) expandedFheCiphertextSize[FheUint32] = uint(len(new(tfheCiphertext).trivialEncrypt(*big.NewInt(0), FheUint32).serialize())) - cksBytes, err := os.ReadFile(networkKeysDir + "cks") - if err != nil { - fmt.Println("WARNING: file cks not found.") - return - } - cks = C.deserialize_client_key(toBufferView(cksBytes)) - pksBytes, err = os.ReadFile(networkKeysDir + "pks") if err != nil { pksBytes = nil @@ -1484,6 +1456,38 @@ func init() { } pksHash = crypto.Keccak256Hash(pksBytes) pks = C.deserialize_compact_public_key(toBufferView(pksBytes)) + + compactFheCiphertextSize[FheUint8] = uint(len(encryptAndSerializeCompact(0, FheUint8))) + compactFheCiphertextSize[FheUint16] = uint(len(encryptAndSerializeCompact(0, FheUint16))) + compactFheCiphertextSize[FheUint32] = uint(len(encryptAndSerializeCompact(0, FheUint32))) + + cksBytes, err := os.ReadFile(networkKeysDir + "cks") + if err != nil { + fmt.Println("WARNING: file cks not found.") + return + } + cks = C.deserialize_client_key(toBufferView(cksBytes)) +} + +func serialize(ptr unsafe.Pointer, t fheUintType) ([]byte, error) { + out := &C.Buffer{} + var ret C.int + switch t { + case FheUint8: + ret = C.serialize_fhe_uint8(ptr, out) + case FheUint16: + ret = C.serialize_fhe_uint16(ptr, out) + case FheUint32: + ret = C.serialize_fhe_uint32(ptr, out) + default: + panic("serialize: unexpected ciphertext type") + } + if ret != 0 { + return nil, errors.New("serialize: failed to serialize a ciphertext") + } + ser := C.GoBytes(unsafe.Pointer(out.pointer), C.int(out.length)) + C.destroy_buffer(out) + return ser, nil } // Represents a TFHE ciphertext type, i.e. its bit capacity. @@ -1496,796 +1500,770 @@ const ( ) // Represents an expanded TFHE ciphertext. -// -// Once a ciphertext has a value (i.e. from deserialization), it must not be set -// another value. If that is needed, a new ciphertext must be created. type tfheCiphertext struct { - ptr unsafe.Pointer serialization []byte hash *common.Hash - value *big.Int fheUintType fheUintType } // Deserializes a TFHE ciphertext. func (ct *tfheCiphertext) deserialize(in []byte, t fheUintType) error { - if ct.initialized() { - panic("cannot deserialize to an existing ciphertext") - } - var ptr unsafe.Pointer switch t { case FheUint8: - ptr = C.deserialize_fhe_uint8(toBufferView((in))) + ptr := C.deserialize_fhe_uint8(toBufferView((in))) + if ptr == nil { + return errors.New("FheUint8 ciphertext deserialization failed") + } + C.destroy_fhe_uint8(ptr) case FheUint16: - ptr = C.deserialize_fhe_uint16(toBufferView((in))) + ptr := C.deserialize_fhe_uint16(toBufferView((in))) + if ptr == nil { + return errors.New("FheUint16 ciphertext deserialization failed") + } + C.destroy_fhe_uint16(ptr) case FheUint32: - ptr = C.deserialize_fhe_uint32(toBufferView((in))) - } - if ptr == nil { - return errors.New("TFHE ciphertext deserialization failed") + ptr := C.deserialize_fhe_uint32(toBufferView((in))) + if ptr == nil { + return errors.New("FheUint32 ciphertext deserialization failed") + } + C.destroy_fhe_uint32(ptr) + default: + panic("deserialize: unexpected ciphertext type") } - ct.setPtr(ptr) ct.fheUintType = t ct.serialization = in + ct.computeHash() return nil } // Deserializes a compact TFHE ciphetext. -// Note: After the compact thfe ciphertext has been serialized, subsequent calls to serialize() +// Note: After the compact TFHE ciphertext has been serialized, subsequent calls to serialize() // will produce non-compact ciphertext serialziations. func (ct *tfheCiphertext) deserializeCompact(in []byte, t fheUintType) error { - if ct.initialized() { - panic("cannot deserialize to an existing ciphertext") - } - var ptr unsafe.Pointer switch t { case FheUint8: - ptr = C.deserialize_compact_fhe_uint8(toBufferView((in))) + ptr := C.deserialize_compact_fhe_uint8(toBufferView((in))) + if ptr == nil { + return errors.New("compact FheUint8 ciphertext deserialization failed") + } + var err error + ct.serialization, err = serialize(ptr, t) + if err != nil { + return err + } + C.destroy_fhe_uint8(ptr) case FheUint16: - ptr = C.deserialize_compact_fhe_uint16(toBufferView((in))) + ptr := C.deserialize_compact_fhe_uint16(toBufferView((in))) + if ptr == nil { + return errors.New("compact FheUint16 ciphertext deserialization failed") + } + var err error + ct.serialization, err = serialize(ptr, t) + if err != nil { + return err + } + C.destroy_fhe_uint16(ptr) case FheUint32: - ptr = C.deserialize_compact_fhe_uint32(toBufferView((in))) - } - if ptr == nil { - return errors.New("TFHE ciphertext deserialization failed") + ptr := C.deserialize_compact_fhe_uint32(toBufferView((in))) + if ptr == nil { + return errors.New("compact FheUint32 ciphertext deserialization failed") + } + var err error + ct.serialization, err = serialize(ptr, t) + if err != nil { + return err + } + C.destroy_fhe_uint32(ptr) + default: + panic("deserializeCompact: unexpected ciphertext type") } - ct.setPtr(ptr) ct.fheUintType = t + ct.computeHash() return nil } // Encrypts a value as a TFHE ciphertext, using the compact public FHE key. // The resulting ciphertext is automaticaly expanded. func (ct *tfheCiphertext) encrypt(value big.Int, t fheUintType) *tfheCiphertext { - if ct.initialized() { - panic("cannot encrypt to an existing ciphertext") - } - + var ptr unsafe.Pointer switch t { case FheUint8: - ct.setPtr(C.public_key_encrypt_fhe_uint8(pks, C.uint8_t(value.Uint64()))) + ptr = C.public_key_encrypt_fhe_uint8(pks, C.uint8_t(value.Uint64())) case FheUint16: - ct.setPtr(C.public_key_encrypt_fhe_uint16(pks, C.uint16_t(value.Uint64()))) + ptr = C.public_key_encrypt_fhe_uint16(pks, C.uint16_t(value.Uint64())) case FheUint32: - ct.setPtr(C.public_key_encrypt_fhe_uint32(pks, C.uint32_t(value.Uint64()))) + ptr = C.public_key_encrypt_fhe_uint32(pks, C.uint32_t(value.Uint64())) + default: + panic("encrypt: unexpected ciphertext type") + } + var err error + ct.serialization, err = serialize(ptr, t) + if err != nil { + panic(err) } ct.fheUintType = t - ct.value = &value + ct.computeHash() return ct } func (ct *tfheCiphertext) trivialEncrypt(value big.Int, t fheUintType) *tfheCiphertext { - if ct.initialized() { - panic("cannot encrypt to an existing ciphertext") - } - + var ptr unsafe.Pointer switch t { case FheUint8: - ct.setPtr(C.trivial_encrypt_fhe_uint8(sks, C.uint8_t(value.Uint64()))) + ptr = C.trivial_encrypt_fhe_uint8(sks, C.uint8_t(value.Uint64())) case FheUint16: - ct.setPtr(C.trivial_encrypt_fhe_uint16(sks, C.uint16_t(value.Uint64()))) + ptr = C.trivial_encrypt_fhe_uint16(sks, C.uint16_t(value.Uint64())) case FheUint32: - ct.setPtr(C.trivial_encrypt_fhe_uint32(sks, C.uint32_t(value.Uint64()))) + ptr = C.trivial_encrypt_fhe_uint32(sks, C.uint32_t(value.Uint64())) + default: + panic("trivialEncrypt: unexpected ciphertext type") + } + var err error + ct.serialization, err = serialize(ptr, ct.fheUintType) + if err != nil { + panic(err) } ct.fheUintType = t - ct.value = &value + ct.computeHash() return ct } func (ct *tfheCiphertext) serialize() []byte { - if !ct.initialized() { - panic("cannot serialize a non-initialized ciphertext") - } else if ct.serialization != nil { - return ct.serialization - } - out := &C.Buffer{} + return ct.serialization +} + +func (ct *tfheCiphertext) executeUnaryCiphertextOperation(rhs *tfheCiphertext, + op8 func(ct unsafe.Pointer) unsafe.Pointer, + op16 func(ct unsafe.Pointer) unsafe.Pointer, + op32 func(ct unsafe.Pointer) unsafe.Pointer) (*tfheCiphertext, error) { + + res := new(tfheCiphertext) + res.fheUintType = ct.fheUintType + res_ser := &C.Buffer{} switch ct.fheUintType { case FheUint8: - C.serialize_fhe_uint8(ct.ptr, out) + ct_ptr := C.deserialize_fhe_uint8(toBufferView((ct.serialization))) + if ct_ptr == nil { + return nil, errors.New("8 bit unary op deserialization failed") + } + res_ptr := op8(ct_ptr) + C.destroy_fhe_uint8(ct_ptr) + if res_ptr == nil { + return nil, errors.New("8 bit unary op failed") + } + ret := C.serialize_fhe_uint8(res_ptr, res_ser) + C.destroy_fhe_uint8(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit unary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) case FheUint16: - C.serialize_fhe_uint16(ct.ptr, out) + ct_ptr := C.deserialize_fhe_uint16(toBufferView((ct.serialization))) + if ct_ptr == nil { + return nil, errors.New("16 bit unary op deserialization failed") + } + res_ptr := op16(ct_ptr) + C.destroy_fhe_uint16(ct_ptr) + if res_ptr == nil { + return nil, errors.New("16 bit op failed") + } + ret := C.serialize_fhe_uint16(res_ptr, res_ser) + C.destroy_fhe_uint16(res_ptr) + if ret != 0 { + return nil, errors.New("16 bit unary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) case FheUint32: - C.serialize_fhe_uint32(ct.ptr, out) + ct_ptr := C.deserialize_fhe_uint32(toBufferView((ct.serialization))) + if ct_ptr == nil { + return nil, errors.New("32 bit unary op deserialization failed") + } + res_ptr := op32(ct_ptr) + C.destroy_fhe_uint32(ct_ptr) + if res_ptr == nil { + return nil, errors.New("32 bit op failed") + } + ret := C.serialize_fhe_uint32(res_ptr, res_ser) + C.destroy_fhe_uint32(res_ptr) + if ret != 0 { + return nil, errors.New("32 bit unary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) + default: + panic("unary op unexpected ciphertext type") } - ct.serialization = C.GoBytes(unsafe.Pointer(out.pointer), C.int(out.length)) - C.destroy_buffer(out) - return ct.serialization + res.computeHash() + return res, nil } -func (lhs *tfheCiphertext) add(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot add on a non-initialized ciphertext") - } - +func (lhs *tfheCiphertext) executeBinaryCiphertextOperation(rhs *tfheCiphertext, + op8 func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, + op16 func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, + op32 func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer) (*tfheCiphertext, error) { if lhs.fheUintType != rhs.fheUintType { return nil, errors.New("binary operations are only well-defined for identical types") } res := new(tfheCiphertext) res.fheUintType = lhs.fheUintType + res_ser := &C.Buffer{} switch lhs.fheUintType { case FheUint8: - res.setPtr(C.add_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + lhs_ptr := C.deserialize_fhe_uint8(toBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("8 bit binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_uint8(toBufferView((rhs.serialization))) + if rhs_ptr == nil { + C.destroy_fhe_uint8(lhs_ptr) + return nil, errors.New("8 bit binary op deserialization failed") + } + res_ptr := op8(lhs_ptr, rhs_ptr) + C.destroy_fhe_uint8(lhs_ptr) + C.destroy_fhe_uint8(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("8 bit binary op failed") + } + ret := C.serialize_fhe_uint8(res_ptr, res_ser) + C.destroy_fhe_uint8(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit binary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) case FheUint16: - res.setPtr(C.add_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + lhs_ptr := C.deserialize_fhe_uint16(toBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("16 bit binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_uint16(toBufferView((rhs.serialization))) + if rhs_ptr == nil { + C.destroy_fhe_uint16(lhs_ptr) + return nil, errors.New("16 bit binary op deserialization failed") + } + res_ptr := op16(lhs_ptr, rhs_ptr) + C.destroy_fhe_uint16(lhs_ptr) + C.destroy_fhe_uint16(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("16 bit binary op failed") + } + ret := C.serialize_fhe_uint16(res_ptr, res_ser) + C.destroy_fhe_uint16(res_ptr) + if ret != 0 { + return nil, errors.New("16 bit binary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) case FheUint32: - res.setPtr(C.add_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + lhs_ptr := C.deserialize_fhe_uint32(toBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("32 bit binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_uint32(toBufferView((rhs.serialization))) + if rhs_ptr == nil { + C.destroy_fhe_uint32(lhs_ptr) + return nil, errors.New("32 bit binary op deserialization failed") + } + res_ptr := op32(lhs_ptr, rhs_ptr) + C.destroy_fhe_uint32(lhs_ptr) + C.destroy_fhe_uint32(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("32 bit binary op failed") + } + ret := C.serialize_fhe_uint32(res_ptr, res_ser) + C.destroy_fhe_uint32(res_ptr) + if ret != 0 { + return nil, errors.New("32 bit binary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) + default: + panic("binary op unexpected ciphertext type") } + res.computeHash() return res, nil } -func (lhs *tfheCiphertext) scalarAdd(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar add on a non-initialized ciphertext") - } - +func (lhs *tfheCiphertext) executeBinaryScalarOperation(rhs uint64, + op8 func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer, + op16 func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer, + op32 func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer) (*tfheCiphertext, error) { res := new(tfheCiphertext) res.fheUintType = lhs.fheUintType + res_ser := &C.Buffer{} switch lhs.fheUintType { case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_add_fhe_uint8(lhs.ptr, pt, sks)) + lhs_ptr := C.deserialize_fhe_uint8(toBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("8 bit scalar op deserialization failed") + } + scalar := C.uint8_t(rhs) + res_ptr := op8(lhs_ptr, scalar) + C.destroy_fhe_uint8(lhs_ptr) + if res_ptr == nil { + return nil, errors.New("8 bit scalar op failed") + } + ret := C.serialize_fhe_uint8(res_ptr, res_ser) + C.destroy_fhe_uint8(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit scalar op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_add_fhe_uint16(lhs.ptr, pt, sks)) + lhs_ptr := C.deserialize_fhe_uint16(toBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("16 bit scalar op deserialization failed") + } + scalar := C.uint16_t(rhs) + res_ptr := op16(lhs_ptr, scalar) + C.destroy_fhe_uint16(lhs_ptr) + if res_ptr == nil { + return nil, errors.New("16 bit scalar op failed") + } + ret := C.serialize_fhe_uint16(res_ptr, res_ser) + C.destroy_fhe_uint16(res_ptr) + if ret != 0 { + return nil, errors.New("16 bit scalar op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_add_fhe_uint32(lhs.ptr, pt, sks)) + lhs_ptr := C.deserialize_fhe_uint32(toBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("32 bit scalar op deserialization failed") + } + scalar := C.uint32_t(rhs) + res_ptr := op32(lhs_ptr, scalar) + C.destroy_fhe_uint32(lhs_ptr) + if res_ptr == nil { + return nil, errors.New("32 bit scalar op failed") + } + ret := C.serialize_fhe_uint32(res_ptr, res_ser) + C.destroy_fhe_uint32(res_ptr) + if ret != 0 { + return nil, errors.New("32 bit scalar op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) + default: + panic("scalar op unexpected ciphertext type") } + res.computeHash() return res, nil } -func (lhs *tfheCiphertext) sub(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot sub on a non-initialized ciphertext") - } +func (lhs *tfheCiphertext) add(rhs *tfheCiphertext) (*tfheCiphertext, error) { + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.add_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.add_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.add_fhe_uint32(lhs, rhs, sks) + }) +} - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } +func (lhs *tfheCiphertext) scalarAdd(rhs uint64) (*tfheCiphertext, error) { + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_add_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_add_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_add_fhe_uint32(lhs, rhs, sks) + }) +} - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.sub_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.sub_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.sub_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil +func (lhs *tfheCiphertext) sub(rhs *tfheCiphertext) (*tfheCiphertext, error) { + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.sub_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.sub_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.sub_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarSub(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar sub on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_sub_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_sub_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_sub_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_sub_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_sub_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_sub_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) mul(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot mul on a non-initialized ciphertext") - } + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.mul_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.mul_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.mul_fhe_uint32(lhs, rhs, sks) + }) +} - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.mul_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.mul_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.mul_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil -} - -func (lhs *tfheCiphertext) scalarMul(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar mul on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_mul_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_mul_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_mul_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil -} +func (lhs *tfheCiphertext) scalarMul(rhs uint64) (*tfheCiphertext, error) { + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_mul_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_mul_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_mul_fhe_uint32(lhs, rhs, sks) + }) +} func (lhs *tfheCiphertext) bitand(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot bitwise AND on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.bitand_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.bitand_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.bitand_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.bitand_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.bitand_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.bitand_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) bitor(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot bitwise OR on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.bitor_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.bitor_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.bitor_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.bitor_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.bitor_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.bitor_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) bitxor(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot bitwise XOR on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.bitxor_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.bitxor_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.bitxor_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.bitxor_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.bitxor_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.bitxor_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) shl(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot shl on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.shl_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.shl_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.shl_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.shl_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.shl_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.shl_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarShl(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar shl on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_shl_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_shl_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_shl_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_shl_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_shl_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_shl_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) shr(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot shr on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.shr_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.shr_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.shr_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.shr_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.shr_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.shr_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarShr(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar shr on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_shr_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_shr_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_shr_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_shr_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_shr_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_shr_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) eq(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot eq on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.eq_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.eq_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.eq_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.eq_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.eq_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.eq_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarEq(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar eq on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_eq_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_eq_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_eq_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_eq_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_eq_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_eq_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) ne(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot ne on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.ne_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.ne_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.ne_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.ne_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.ne_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.ne_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarNe(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar ne on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_ne_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_ne_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_ne_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_ne_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_ne_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_ne_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) ge(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot ge on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.ge_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.ge_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.ge_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.ge_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.ge_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.ge_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarGe(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar ge on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_ge_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_ge_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_ge_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_ge_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_ge_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_ge_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) gt(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot gt on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.gt_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.gt_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.gt_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.gt_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.gt_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.gt_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarGt(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar gt on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_gt_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_gt_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_gt_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_gt_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_gt_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_gt_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) le(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot le on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.le_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.le_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.le_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.le_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.le_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.le_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarLe(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar le on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_le_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_le_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_le_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_le_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_le_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_le_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot lt on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.lt_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.lt_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.lt_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.lt_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.lt_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.lt_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarLt(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar lt on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_lt_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_lt_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_lt_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_lt_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_lt_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_lt_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) min(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot min on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.min_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.min_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.min_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.min_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.min_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.min_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarMin(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar min on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_min_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_min_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_min_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_min_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_min_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_min_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) max(rhs *tfheCiphertext) (*tfheCiphertext, error) { - if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot max on a non-initialized ciphertext") - } - - if lhs.fheUintType != rhs.fheUintType { - return nil, errors.New("binary operations are only well-defined for identical types") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.max_fhe_uint8(lhs.ptr, rhs.ptr, sks)) - case FheUint16: - res.setPtr(C.max_fhe_uint16(lhs.ptr, rhs.ptr, sks)) - case FheUint32: - res.setPtr(C.max_fhe_uint32(lhs.ptr, rhs.ptr, sks)) - } - return res, nil + return lhs.executeBinaryCiphertextOperation(rhs, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.max_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.max_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.max_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) scalarMax(rhs uint64) (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot scalar max on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - pt := C.uint8_t(rhs) - res.setPtr(C.scalar_max_fhe_uint8(lhs.ptr, pt, sks)) - case FheUint16: - pt := C.uint16_t(rhs) - res.setPtr(C.scalar_max_fhe_uint16(lhs.ptr, pt, sks)) - case FheUint32: - pt := C.uint32_t(rhs) - res.setPtr(C.scalar_max_fhe_uint32(lhs.ptr, pt, sks)) - } - return res, nil + return lhs.executeBinaryScalarOperation(rhs, + func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { + return C.scalar_max_fhe_uint8(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { + return C.scalar_max_fhe_uint16(lhs, rhs, sks) + }, + func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { + return C.scalar_max_fhe_uint32(lhs, rhs, sks) + }) } func (lhs *tfheCiphertext) neg() (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot neg on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.neg_fhe_uint8(lhs.ptr, sks)) - case FheUint16: - res.setPtr(C.neg_fhe_uint16(lhs.ptr, sks)) - case FheUint32: - res.setPtr(C.neg_fhe_uint32(lhs.ptr, sks)) - } - return res, nil + return lhs.executeUnaryCiphertextOperation(lhs, + func(lhs unsafe.Pointer) unsafe.Pointer { + return C.neg_fhe_uint8(lhs, sks) + }, + func(lhs unsafe.Pointer) unsafe.Pointer { + return C.neg_fhe_uint16(lhs, sks) + }, + func(lhs unsafe.Pointer) unsafe.Pointer { + return C.neg_fhe_uint32(lhs, sks) + }) } func (lhs *tfheCiphertext) not() (*tfheCiphertext, error) { - if !lhs.availableForOps() { - panic("cannot not on a non-initialized ciphertext") - } - - res := new(tfheCiphertext) - res.fheUintType = lhs.fheUintType - switch lhs.fheUintType { - case FheUint8: - res.setPtr(C.not_fhe_uint8(lhs.ptr, sks)) - case FheUint16: - res.setPtr(C.not_fhe_uint16(lhs.ptr, sks)) - case FheUint32: - res.setPtr(C.not_fhe_uint32(lhs.ptr, sks)) - } - return res, nil + return lhs.executeUnaryCiphertextOperation(lhs, + func(lhs unsafe.Pointer) unsafe.Pointer { + return C.not_fhe_uint8(lhs, sks) + }, + func(lhs unsafe.Pointer) unsafe.Pointer { + return C.not_fhe_uint16(lhs, sks) + }, + func(lhs unsafe.Pointer) unsafe.Pointer { + return C.not_fhe_uint32(lhs, sks) + }) } func (ct *tfheCiphertext) castTo(castToType fheUintType) (*tfheCiphertext, error) { - if !ct.availableForOps() { - panic("cannot cast a non-initialized ciphertext") - } - if ct.fheUintType == castToType { return nil, errors.New("casting to same type is not supported") } - if !castToType.isValid() { - return nil, errors.New("invalid type to cast to") - } - res := new(tfheCiphertext) res.fheUintType = castToType @@ -2293,95 +2271,179 @@ func (ct *tfheCiphertext) castTo(castToType fheUintType) (*tfheCiphertext, error case FheUint8: switch castToType { case FheUint16: - res.setPtr(C.cast_8_16(ct.ptr, sks)) + from_ptr := C.deserialize_fhe_uint8(toBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint8 ciphertext") + } + to_ptr := C.cast_8_16(from_ptr, sks) + C.destroy_fhe_uint8(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint8 to FheUint16") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint16(to_ptr) + if err != nil { + return nil, err + } case FheUint32: - res.setPtr(C.cast_8_32(ct.ptr, sks)) + from_ptr := C.deserialize_fhe_uint8(toBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint8 ciphertext") + } + to_ptr := C.cast_8_32(from_ptr, sks) + C.destroy_fhe_uint8(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint8 to FheUint32") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint32(to_ptr) + if err != nil { + return nil, err + } + default: + panic("castTo: unexpected type to cast to") } case FheUint16: switch castToType { case FheUint8: - res.setPtr(C.cast_16_8(ct.ptr, sks)) + from_ptr := C.deserialize_fhe_uint16(toBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint16 ciphertext") + } + to_ptr := C.cast_16_8(from_ptr, sks) + C.destroy_fhe_uint16(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint16 to FheUint8") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint8(to_ptr) + if err != nil { + return nil, err + } case FheUint32: - res.setPtr(C.cast_16_32(ct.ptr, sks)) + from_ptr := C.deserialize_fhe_uint16(toBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint16 ciphertext") + } + to_ptr := C.cast_16_32(from_ptr, sks) + C.destroy_fhe_uint16(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint16 to FheUint32") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint32(to_ptr) + if err != nil { + return nil, err + } + default: + panic("castTo: unexpected type to cast to") } case FheUint32: switch castToType { case FheUint8: - res.setPtr(C.cast_32_8(ct.ptr, sks)) + from_ptr := C.deserialize_fhe_uint32(toBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint32 ciphertext") + } + to_ptr := C.cast_32_8(from_ptr, sks) + C.destroy_fhe_uint32(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint32 to FheUint8") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint8(to_ptr) + if err != nil { + return nil, err + } case FheUint16: - res.setPtr(C.cast_32_16(ct.ptr, sks)) + from_ptr := C.deserialize_fhe_uint32(toBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint32 ciphertext") + } + to_ptr := C.cast_32_16(from_ptr, sks) + C.destroy_fhe_uint32(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint32 to FheUint16") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint16(to_ptr) + if err != nil { + return nil, err + } + default: + panic("castTo: unexpected type to cast to") } } - + res.computeHash() return res, nil } -func (ct *tfheCiphertext) decrypt() big.Int { - if !ct.availableForOps() { - panic("cannot decrypt a null ciphertext") - } else if ct.value != nil { - return *ct.value - } +func (ct *tfheCiphertext) decrypt() (big.Int, error) { var value uint64 + var ret C.int switch ct.fheUintType { case FheUint8: - value = uint64(C.decrypt_fhe_uint8(cks, ct.ptr)) + ptr := C.deserialize_fhe_uint8(toBufferView(ct.serialization)) + if ptr == nil { + return *new(big.Int).SetUint64(0), errors.New("failed to deserialize FheUint8") + } + var result C.uint8_t + ret = C.decrypt_fhe_uint8(cks, ptr, &result) + C.destroy_fhe_uint8(ptr) + value = uint64(result) case FheUint16: - value = uint64(C.decrypt_fhe_uint16(cks, ct.ptr)) + ptr := C.deserialize_fhe_uint16(toBufferView(ct.serialization)) + if ptr == nil { + return *new(big.Int).SetUint64(0), errors.New("failed to deserialize FheUint16") + } + var result C.uint16_t + ret = C.decrypt_fhe_uint16(cks, ptr, &result) + C.destroy_fhe_uint16(ptr) + value = uint64(result) case FheUint32: - value = uint64(C.decrypt_fhe_uint32(cks, ct.ptr)) + ptr := C.deserialize_fhe_uint32(toBufferView(ct.serialization)) + if ptr == nil { + return *new(big.Int).SetUint64(0), errors.New("failed to deserialize FheUint32") + } + var result C.uint32_t + ret = C.decrypt_fhe_uint32(cks, ptr, &result) + C.destroy_fhe_uint32(ptr) + value = uint64(result) + default: + panic("decrypt: unexpected ciphertext type") + } + if ret != 0 { + return *new(big.Int).SetUint64(0), errors.New("decrypt failed") } - ct.value = new(big.Int).SetUint64(value) - return *ct.value + return *new(big.Int).SetUint64(value), nil } -func (ct *tfheCiphertext) setPtr(ptr unsafe.Pointer) { - if ptr == nil { - panic("setPtr called with nil") - } - ct.ptr = ptr - atomic.AddUint64(&allocatedCiphertexts, 1) - switch ct.fheUintType { - case FheUint8: - runtime.SetFinalizer(ct, func(ct *tfheCiphertext) { - C.destroy_fhe_uint8(ct.ptr) - }) - case FheUint16: - runtime.SetFinalizer(ct, func(ct *tfheCiphertext) { - C.destroy_fhe_uint16(ct.ptr) - }) - case FheUint32: - runtime.SetFinalizer(ct, func(ct *tfheCiphertext) { - C.destroy_fhe_uint32(ct.ptr) - }) - } +func (ct *tfheCiphertext) computeHash() { + hash := common.BytesToHash(crypto.Keccak256(ct.serialization)) + ct.hash = &hash } func (ct *tfheCiphertext) getHash() common.Hash { if ct.hash != nil { return *ct.hash } - if !ct.initialized() { - panic("cannot get hash of non-initialized ciphertext") - } - hash := common.BytesToHash(crypto.Keccak256(ct.serialize())) - ct.hash = &hash + ct.computeHash() return *ct.hash } -func (ct *tfheCiphertext) availableForOps() bool { - return (ct.initialized() && ct.ptr != nil) -} - -func (ct *tfheCiphertext) initialized() bool { - return (ct.ptr != nil) -} - -func (t *fheUintType) isValid() bool { - return (*t <= 2) +func isValidType(t byte) bool { + if uint8(t) < uint8(FheUint8) || uint8(t) > uint8(FheUint32) { + return false + } + return true } -// Used for testing. func encryptAndSerializeCompact(value uint32, fheUintType fheUintType) []byte { out := &C.Buffer{} switch fheUintType { diff --git a/core/vm/tfhe_test.go b/core/vm/tfhe_test.go index 3231d5e9692d..74ff04be352b 100644 --- a/core/vm/tfhe_test.go +++ b/core/vm/tfhe_test.go @@ -38,8 +38,8 @@ func TfheEncryptDecrypt(t *testing.T, fheUintType fheUintType) { } ct := new(tfheCiphertext) ct.encrypt(val, fheUintType) - res := ct.decrypt() - if res.Uint64() != val.Uint64() { + res, err := ct.decrypt() + if err != nil || res.Uint64() != val.Uint64() { t.Fatalf("%d != %d", val.Uint64(), res.Uint64()) } } @@ -56,8 +56,8 @@ func TfheTrivialEncryptDecrypt(t *testing.T, fheUintType fheUintType) { } ct := new(tfheCiphertext) ct.trivialEncrypt(val, fheUintType) - res := ct.decrypt() - if res.Uint64() != val.Uint64() { + res, err := ct.decrypt() + if err != nil || res.Uint64() != val.Uint64() { t.Fatalf("%d != %d", val.Uint64(), res.Uint64()) } } @@ -116,8 +116,8 @@ func TfheSerializeDeserializeCompact(t *testing.T, fheUintType fheUintType) { t.Fatalf("serialization is non-deterministic") } - decrypted := ct2.decrypt() - if uint32(decrypted.Uint64()) != val { + decrypted, err := ct2.decrypt() + if err != nil || uint32(decrypted.Uint64()) != val { t.Fatalf("decrypted value is incorrect") } } @@ -172,8 +172,8 @@ func TfheDeserializeCompact(t *testing.T, fheUintType fheUintType) { if err != nil { t.Fatalf("compact deserialization failed") } - decryptedVal := ct.decrypt() - if uint32(decryptedVal.Uint64()) != val { + decryptedVal, err := ct.decrypt() + if err != nil || uint32(decryptedVal.Uint64()) != val { t.Fatalf("compact deserialization wrong decryption") } } @@ -205,8 +205,8 @@ func TfheAdd(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.add(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -228,8 +228,8 @@ func TfheScalarAdd(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes, _ := ctA.scalarAdd(b.Uint64()) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -253,8 +253,8 @@ func TfheSub(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.sub(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -276,8 +276,8 @@ func TfheScalarSub(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes, _ := ctA.scalarSub(b.Uint64()) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -301,8 +301,8 @@ func TfheMul(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.mul(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -324,8 +324,8 @@ func TfheScalarMul(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes, _ := ctA.scalarMul(b.Uint64()) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -349,8 +349,8 @@ func TfheBitAnd(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.bitand(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) } } @@ -374,8 +374,8 @@ func TfheBitOr(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.bitor(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) } } @@ -399,8 +399,8 @@ func TfheBitXor(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.bitxor(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) } } @@ -424,8 +424,8 @@ func TfheShl(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.shl(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -447,8 +447,8 @@ func TfheScalarShl(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes, _ := ctA.scalarShl(b.Uint64()) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -472,8 +472,8 @@ func TfheShr(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.shr(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -495,8 +495,8 @@ func TfheScalarShr(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes, _ := ctA.scalarShr(b.Uint64()) - res := ctRes.decrypt() - if res.Uint64() != expected.Uint64() { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected.Uint64() { t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) } } @@ -526,8 +526,8 @@ func TfheEq(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.eq(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) } } @@ -555,8 +555,8 @@ func TfheScalarEq(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes, _ := ctA.scalarEq(b.Uint64()) - res := ctRes.decrypt() - if res.Uint64() != expected { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) } } @@ -586,8 +586,8 @@ func TfheNe(t *testing.T, fheUintType fheUintType) { ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) ctRes, _ := ctA.ne(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) } } @@ -615,8 +615,8 @@ func TfheScalarNe(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes, _ := ctA.scalarNe(b.Uint64()) - res := ctRes.decrypt() - if res.Uint64() != expected { + res, err := ctRes.decrypt() + if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", expected, res.Uint64()) } } @@ -640,12 +640,12 @@ func TfheGe(t *testing.T, fheUintType fheUintType) { ctB.encrypt(b, fheUintType) ctRes1, _ := ctA.ge(ctB) ctRes2, _ := ctB.ge(ctA) - res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1.Uint64() != 1 { - t.Fatalf("%d != %d", 0, res1.Uint64()) + res1, err1 := ctRes1.decrypt() + res2, err2 := ctRes2.decrypt() + if err1 != nil || res1.Uint64() != 1 { + t.Fatalf("%d != %d", 1, res1.Uint64()) } - if res2.Uint64() != 0 { + if err2 != nil || res2.Uint64() != 0 { t.Fatalf("%d != %d", 0, res2.Uint64()) } } @@ -666,8 +666,8 @@ func TfheScalarGe(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes1, _ := ctA.scalarGe(b.Uint64()) - res1 := ctRes1.decrypt() - if res1.Uint64() != 1 { + res1, err := ctRes1.decrypt() + if err != nil || res1.Uint64() != 1 { t.Fatalf("%d != %d", 0, res1.Uint64()) } } @@ -691,12 +691,12 @@ func TfheGt(t *testing.T, fheUintType fheUintType) { ctB.encrypt(b, fheUintType) ctRes1, _ := ctA.gt(ctB) ctRes2, _ := ctB.gt(ctA) - res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1.Uint64() != 1 { + res1, err1 := ctRes1.decrypt() + res2, err2 := ctRes2.decrypt() + if err1 != nil || res1.Uint64() != 1 { t.Fatalf("%d != %d", 0, res1.Uint64()) } - if res2.Uint64() != 0 { + if err2 != nil || res2.Uint64() != 0 { t.Fatalf("%d != %d", 0, res2.Uint64()) } } @@ -717,8 +717,8 @@ func TfheScalarGt(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes1, _ := ctA.scalarGt(b.Uint64()) - res1 := ctRes1.decrypt() - if res1.Uint64() != 1 { + res1, err := ctRes1.decrypt() + if err != nil || res1.Uint64() != 1 { t.Fatalf("%d != %d", 0, res1.Uint64()) } } @@ -742,12 +742,12 @@ func TfheLe(t *testing.T, fheUintType fheUintType) { ctB.encrypt(b, fheUintType) ctRes1, _ := ctA.le(ctB) ctRes2, _ := ctB.le(ctA) - res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1.Uint64() != 0 { + res1, err1 := ctRes1.decrypt() + res2, err2 := ctRes2.decrypt() + if err1 != nil || res1.Uint64() != 0 { t.Fatalf("%d != %d", 0, res1.Uint64()) } - if res2.Uint64() != 1 { + if err2 != nil || res2.Uint64() != 1 { t.Fatalf("%d != %d", 0, res2.Uint64()) } } @@ -768,8 +768,8 @@ func TfheScalarLe(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes1, _ := ctA.scalarLe(b.Uint64()) - res1 := ctRes1.decrypt() - if res1.Uint64() != 0 { + res1, err := ctRes1.decrypt() + if err != nil || res1.Uint64() != 0 { t.Fatalf("%d != %d", 0, res1.Uint64()) } } @@ -793,12 +793,12 @@ func TfheLt(t *testing.T, fheUintType fheUintType) { ctB.encrypt(b, fheUintType) ctRes1, _ := ctA.lt(ctB) ctRes2, _ := ctB.lt(ctA) - res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1.Uint64() != 0 { + res1, err1 := ctRes1.decrypt() + res2, err2 := ctRes2.decrypt() + if err1 != nil || res1.Uint64() != 0 { t.Fatalf("%d != %d", 0, res1.Uint64()) } - if res2.Uint64() != 1 { + if err2 != nil || res2.Uint64() != 1 { t.Fatalf("%d != %d", 0, res2.Uint64()) } } @@ -819,8 +819,8 @@ func TfheScalarLt(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes1, _ := ctA.scalarLt(b.Uint64()) - res1 := ctRes1.decrypt() - if res1.Uint64() != 0 { + res1, err := ctRes1.decrypt() + if err != nil || res1.Uint64() != 0 { t.Fatalf("%d != %d", 0, res1.Uint64()) } } @@ -844,12 +844,12 @@ func TfheMin(t *testing.T, fheUintType fheUintType) { ctB.encrypt(b, fheUintType) ctRes1, _ := ctA.min(ctB) ctRes2, _ := ctB.min(ctA) - res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1.Uint64() != b.Uint64() { + res1, err1 := ctRes1.decrypt() + res2, err2 := ctRes2.decrypt() + if err1 != nil || res1.Uint64() != b.Uint64() { t.Fatalf("%d != %d", b.Uint64(), res1.Uint64()) } - if res2.Uint64() != b.Uint64() { + if err2 != nil || res2.Uint64() != b.Uint64() { t.Fatalf("%d != %d", b.Uint64(), res2.Uint64()) } } @@ -870,8 +870,8 @@ func TfheScalarMin(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes1, _ := ctA.scalarMin(b.Uint64()) - res1 := ctRes1.decrypt() - if res1.Uint64() != b.Uint64() { + res1, err1 := ctRes1.decrypt() + if err1 != nil || res1.Uint64() != b.Uint64() { t.Fatalf("%d != %d", 0, res1.Uint64()) } } @@ -895,12 +895,12 @@ func TfheMax(t *testing.T, fheUintType fheUintType) { ctB.encrypt(b, fheUintType) ctRes1, _ := ctA.max(ctB) ctRes2, _ := ctB.max(ctA) - res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1.Uint64() != a.Uint64() { + res1, err1 := ctRes1.decrypt() + res2, err2 := ctRes2.decrypt() + if err1 != nil || res1.Uint64() != a.Uint64() { t.Fatalf("%d != %d", b.Uint64(), res1.Uint64()) } - if res2.Uint64() != a.Uint64() { + if err2 != nil || res2.Uint64() != a.Uint64() { t.Fatalf("%d != %d", b.Uint64(), res2.Uint64()) } } @@ -921,8 +921,8 @@ func TfheScalarMax(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes1, _ := ctA.scalarMax(b.Uint64()) - res1 := ctRes1.decrypt() - if res1.Uint64() != a.Uint64() { + res1, err1 := ctRes1.decrypt() + if err1 != nil || res1.Uint64() != a.Uint64() { t.Fatalf("%d != %d", 0, res1.Uint64()) } } @@ -945,8 +945,8 @@ func TfheNeg(t *testing.T, fheUintType fheUintType) { ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctRes1, _ := ctA.neg() - res1 := ctRes1.decrypt() - if res1.Uint64() != expected { + res1, err1 := ctRes1.decrypt() + if err1 != nil || res1.Uint64() != expected { t.Fatalf("%d != %d", res1.Uint64(), expected) } } @@ -969,8 +969,8 @@ func TfheNot(t *testing.T, fheUintType fheUintType) { ctA.encrypt(a, fheUintType) ctRes1, _ := ctA.not() - res1 := ctRes1.decrypt() - if res1.Uint64() != expected { + res1, err1 := ctRes1.decrypt() + if err1 != nil || res1.Uint64() != expected { t.Fatalf("%d != %d", res1.Uint64(), expected) } } @@ -1006,9 +1006,9 @@ func TfheCast(t *testing.T, fheUintTypeFrom fheUintType, fheUintTypeTo fheUintTy if ctRes.fheUintType != fheUintTypeTo { t.Fatalf("type %d != type %d", ctA.fheUintType, fheUintTypeTo) } - res := ctRes.decrypt() + res, err := ctRes.decrypt() expected := a.Uint64() % modulus - if res.Uint64() != expected { + if err != nil || res.Uint64() != expected { t.Fatalf("%d != %d", res.Uint64(), expected) } }