Skip to content

Commit

Permalink
Merge pull request ethereum#60 from zama-ai/petar/ct-handle-in-call-a…
Browse files Browse the repository at this point in the history
…rg-depth-set-impl

Delegate handles found in call arguments
  • Loading branch information
dartdart26 authored Mar 16, 2023
2 parents da6a06e + 21b624e commit a3fc1a7
Show file tree
Hide file tree
Showing 5 changed files with 542 additions and 227 deletions.
169 changes: 90 additions & 79 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ func ActivePrecompiles(rules params.Rules) []common.Address {
// - the _remaining_ gas,
// - any error that occurred
func RunPrecompiledContract(p PrecompiledContract, accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, suppliedGas uint64, readOnly bool) (ret []byte, remainingGas uint64, err error) {
accessibleState.Interpreter().evm.depth++
defer func() { accessibleState.Interpreter().evm.depth-- }()
gasCost := p.RequiredGas(input)
if suppliedGas < gasCost {
return nil, 0, ErrOutOfGas
Expand Down Expand Up @@ -1220,24 +1222,56 @@ func init() {
}
}

func getVerifiedCiphertext(accessibleState PrecompileAccessibleState, ciphertextHash common.Hash) (*tfheCiphertext, bool) {
ct, ok := accessibleState.Interpreter().verifiedCiphertexts[ciphertextHash]
if ok && ct.depth <= accessibleState.Interpreter().evm.depth {
return ct.ciphertext, true
func isVerifiedAtCurrentDepth(interpreter *EVMInterpreter, ct *verifiedCiphertext) bool {
return ct.verifiedDepths.has(interpreter.evm.depth)
}

// Returns a pointer to the ciphertext if the given hash points to a verified ciphertext.
// Else, it returns nil.
func getVerifiedCiphertextFromEVM(interpreter *EVMInterpreter, ciphertextHash common.Hash) *verifiedCiphertext {
ct, ok := interpreter.verifiedCiphertexts[ciphertextHash]
if ok && isVerifiedAtCurrentDepth(interpreter, ct) {
return ct
}
return nil
}

// See getVerifiedCiphertextFromEVM().
func getVerifiedCiphertext(accessibleState PrecompileAccessibleState, ciphertextHash common.Hash) *verifiedCiphertext {
return getVerifiedCiphertextFromEVM(accessibleState.Interpreter(), ciphertextHash)
}

func importCiphertextToEVMAtDepth(interpreter *EVMInterpreter, ct *tfheCiphertext, depth int) *verifiedCiphertext {
existing, ok := interpreter.verifiedCiphertexts[ct.getHash()]
if ok {
existing.verifiedDepths.add(depth)
return existing
} else {
verifiedDepths := newDepthSet()
verifiedDepths.add(depth)
new := &verifiedCiphertext{
verifiedDepths,
ct,
}
interpreter.verifiedCiphertexts[ct.getHash()] = new
return new
}
return nil, false
}

func importCiphertextToEVM(interpreter *EVMInterpreter, ct *tfheCiphertext) *verifiedCiphertext {
return importCiphertextToEVMAtDepth(interpreter, ct, interpreter.evm.depth)
}

func importCiphertext(accessibleState PrecompileAccessibleState, ct *tfheCiphertext) *verifiedCiphertext {
return importCiphertextToEVM(accessibleState.Interpreter(), ct)
}

// Used when we want to skip FHE computation, e.g. gas estimation.
func importRandomCiphertext(accessibleState PrecompileAccessibleState) []byte {
ct := new(tfheCiphertext)
ct.makeRandom()
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: ct,
}
importCiphertext(accessibleState, ct)
ctHash := ct.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext
return ctHash[:]
}

Expand All @@ -1253,12 +1287,12 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("input needs to contain two 256-bit sized values")
}

a, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
b, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1267,20 +1301,16 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return importRandomCiphertext(accessibleState), nil
}

result := a.add(b)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.add(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/add_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/add_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := verifiedCiphertext.ciphertext.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext
ctHash := result.getHash()
return ctHash[:], nil
}

Expand Down Expand Up @@ -1362,7 +1392,7 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller
return nil, err
}
ctHash := ct.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = &verifiedCiphertext{accessibleState.Interpreter().evm.depth, ct}
importCiphertext(accessibleState, ct)
return ctHash.Bytes(), nil
}

Expand Down Expand Up @@ -1392,8 +1422,8 @@ func (e *reencrypt) Run(accessibleState PrecompileAccessibleState, caller common
if len(input) != 32 {
return nil, errors.New("invalid ciphertext handle")
}
ct, ok := accessibleState.Interpreter().verifiedCiphertexts[common.BytesToHash(input)]
if ok && ct.depth <= accessibleState.Interpreter().evm.depth {
ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input))
if ct != nil {
decryptedValue := ct.ciphertext.decrypt()
reencryptedValue, err := fheEncryptToUserKey(decryptedValue, accessibleState.Interpreter().evm.Origin)
if err != nil {
Expand All @@ -1415,9 +1445,9 @@ func (e *delegateCiphertext) Run(accessibleState PrecompileAccessibleState, call
if len(input) != 32 {
return nil, errors.New("invalid ciphertext handle")
}
ct, ok := accessibleState.Interpreter().verifiedCiphertexts[common.BytesToHash(input)]
if ok {
ct.depth = minInt(ct.depth, accessibleState.Interpreter().evm.depth-1)
ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input))
if ct != nil {
ct.verifiedDepths.add(accessibleState.Interpreter().evm.depth + 1)
return nil, nil
}
return nil, errors.New("unverified ciphertext handle")
Expand Down Expand Up @@ -1591,12 +1621,12 @@ func (e *fheLte) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1605,20 +1635,16 @@ func (e *fheLte) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.lte(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.lte(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/lte_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/lte_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
Expand All @@ -1635,12 +1661,12 @@ func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1649,20 +1675,16 @@ func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.sub(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.sub(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/sub_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/sub_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
Expand All @@ -1679,12 +1701,12 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1693,20 +1715,16 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.mul(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.mul(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/mul_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/mul_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
Expand All @@ -1723,12 +1741,12 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1737,20 +1755,16 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.lt(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.lt(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/lt_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/lt_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
Expand Down Expand Up @@ -1822,17 +1836,14 @@ func (e *fheRand) Run(accessibleState PrecompileAccessibleState, caller common.A
randInt := binary.BigEndian.Uint64(randBytes) % fheMessageModulus
randCt := new(tfheCiphertext)
randCt.trivialEncrypt(randInt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: randCt,
}
importCiphertext(accessibleState, randCt)

// TODO: for testing
err = os.WriteFile("/tmp/rand_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err = os.WriteFile("/tmp/rand_result", randCt.serialize(), 0644)
if err != nil {
return nil, err
}
ctHash := randCt.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext
return ctHash[:], nil
}

Expand All @@ -1843,6 +1854,6 @@ func (e *faucet) RequiredGas(input []byte) uint64 {
}

func (e *faucet) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
accessibleState.Interpreter().evm.StateDB.AddBalance(common.BytesToAddress(input[0:20]), big.NewInt(10000000000000000000))
accessibleState.Interpreter().evm.StateDB.AddBalance(common.BytesToAddress(input[0:20]), big.NewInt(1000000000000000000))
return input, nil
}
Loading

0 comments on commit a3fc1a7

Please sign in to comment.