Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add sha2 primitive #689

Merged
merged 13 commits into from
Jun 5, 2023
2 changes: 1 addition & 1 deletion internal/backend/circuits/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func init() {
},
}

addNewEntry("recursive_hint", &recursiveHint{}, good, bad, gnark.Curves(), make3, bits.NBits)
addNewEntry("recursive_hint", &recursiveHint{}, good, bad, gnark.Curves(), make3, bits.GetHints()[1])
}

{
Expand Down
6 changes: 3 additions & 3 deletions std/accumulator/merkle/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type MerkleProof struct {

// leafSum returns the hash created from data inserted to form a leaf.
// Without domain separation.
func leafSum(api frontend.API, h hash.Hash, data frontend.Variable) frontend.Variable {
func leafSum(api frontend.API, h hash.FieldHasher, data frontend.Variable) frontend.Variable {

h.Reset()
h.Write(data)
Expand All @@ -73,7 +73,7 @@ func leafSum(api frontend.API, h hash.Hash, data frontend.Variable) frontend.Var

// nodeSum returns the hash created from data inserted to form a leaf.
// Without domain separation.
func nodeSum(api frontend.API, h hash.Hash, a, b frontend.Variable) frontend.Variable {
func nodeSum(api frontend.API, h hash.FieldHasher, a, b frontend.Variable) frontend.Variable {

h.Reset()
h.Write(a, b)
Expand All @@ -86,7 +86,7 @@ func nodeSum(api frontend.API, h hash.Hash, a, b frontend.Variable) frontend.Var
// true if the first element of the proof set is a leaf of data in the Merkle
// root. False is returned if the proof set or Merkle root is nil, and if
// 'numLeaves' equals 0.
func (mp *MerkleProof) VerifyProof(api frontend.API, h hash.Hash, leaf frontend.Variable) {
func (mp *MerkleProof) VerifyProof(api frontend.API, h hash.FieldHasher, leaf frontend.Variable) {

depth := len(mp.Path) - 1
sum := leafSum(api, h, mp.Path[0])
Expand Down
4 changes: 2 additions & 2 deletions std/commitments/fri/fri.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type RadixTwoFri struct {

// hash function that is used for Fiat Shamir and for committing to
// the oracles.
h hash.Hash
h hash.FieldHasher

// nbSteps number of interactions between the prover and the verifier
nbSteps int
Expand All @@ -66,7 +66,7 @@ type RadixTwoFri struct {
// NewRadixTwoFri creates an FFT-like oracle proof of proximity.
// * h is the hash function that is used for the Merkle proofs
// * gen is the generator of the cyclic group of unity of size \rho * size
func NewRadixTwoFri(size uint64, h hash.Hash, gen big.Int) RadixTwoFri {
func NewRadixTwoFri(size uint64, h hash.FieldHasher, gen big.Int) RadixTwoFri {

var res RadixTwoFri

Expand Down
4 changes: 2 additions & 2 deletions std/fiat-shamir/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ type Settings struct {
Transcript *Transcript
Prefix string
BaseChallenges []frontend.Variable
Hash hash.Hash
Hash hash.FieldHasher
}

func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...frontend.Variable) Settings {
Expand All @@ -20,7 +20,7 @@ func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...fro
}
}

func WithHash(hash hash.Hash, baseChallenges ...frontend.Variable) Settings {
func WithHash(hash hash.FieldHasher, baseChallenges ...frontend.Variable) Settings {
return Settings{
BaseChallenges: baseChallenges,
Hash: hash,
Expand Down
4 changes: 2 additions & 2 deletions std/fiat-shamir/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var (
// Transcript handles the creation of challenges for Fiat Shamir.
type Transcript struct {
// hash function that is used.
h hash.Hash
h hash.FieldHasher

challenges map[string]challenge
previous *challenge
Expand All @@ -52,7 +52,7 @@ type challenge struct {
// NewTranscript returns a new transcript.
// h is the hash function that is used to compute the challenges.
// challenges are the name of the challenges. The order is important.
func NewTranscript(api frontend.API, h hash.Hash, challengesID ...string) Transcript {
func NewTranscript(api frontend.API, h hash.FieldHasher, challengesID ...string) Transcript {
n := len(challengesID)
t := Transcript{
challenges: make(map[string]challenge, n),
Expand Down
10 changes: 5 additions & 5 deletions std/gkr/gkr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (c *GkrVerifierCircuit) Define(api frontend.API) error {
}
assignment := makeInOutAssignment(testCase.Circuit, c.Input, c.Output)

var hsh hash.Hash
var hsh hash.FieldHasher
if c.ToFail {
hsh = NewMessageCounter(api, 1, 1)
} else {
Expand Down Expand Up @@ -414,7 +414,7 @@ func SliceEqual[T comparable](expected, seen []T) bool {

type HashDescription map[string]interface{}

func HashFromDescription(api frontend.API, d HashDescription) (hash.Hash, error) {
func HashFromDescription(api frontend.API, d HashDescription) (hash.FieldHasher, error) {
if _type, ok := d["type"]; ok {
switch _type {
case "const":
Expand Down Expand Up @@ -456,13 +456,13 @@ func (m *MessageCounter) Reset() {
m.state = m.startState
}

func NewMessageCounter(api frontend.API, startState, step int) hash.Hash {
func NewMessageCounter(api frontend.API, startState, step int) hash.FieldHasher {
transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step), api: api}
return transcript
}

func NewMessageCounterGenerator(startState, step int) func(frontend.API) hash.Hash {
return func(api frontend.API) hash.Hash {
func NewMessageCounterGenerator(startState, step int) func(frontend.API) hash.FieldHasher {
return func(api frontend.API) hash.FieldHasher {
return NewMessageCounter(api, startState, step)
}
}
Expand Down
41 changes: 36 additions & 5 deletions std/hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,47 @@ limitations under the License.
// Package hash provides an interface that hash functions (as gadget) should implement.
package hash

import "github.com/consensys/gnark/frontend"

type Hash interface {

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/uints"
)

// FieldHasher hashes inputs into a short digest. This interface mocks
// [BinaryHasher], but is more suitable in-circuit by assuming the inputs are
// scalar field elements and outputs digest as a field element. Such hash
// functions are for examle Poseidon, MiMC etc.
type FieldHasher interface {
// Sum computes the hash of the internal state of the hash function.
Sum() frontend.Variable

// Write populate the internal state of the hash function with data.
// Write populate the internal state of the hash function with data. The inputs are native field elements.
Write(data ...frontend.Variable)

// Reset empty the internal state and put the intermediate state to zero.
Reset()
}

// BinaryHasher hashes inputs into a short digest. It takes as inputs bytes and
// outputs byte array whose length depends on the underlying hash function. For
// SNARK-native hash functions use [FieldHasher].
type BinaryHasher interface {
// Sum finalises the current hash and returns the digest.
Sum() []uints.U8

// Write writes more bytes into the current hash state.
Write([]uints.U8)

// Size returns the number of bytes this hash function returns in a call to
// [BinaryHasher.Sum].
Size() int
}

// BinaryFixedLengthHasher is like [BinaryHasher], but assumes the length of the
// input is not full length as defined during compile time. This allows to
// compute digest of variable-length input, unlike [BinaryHasher] which assumes
// the length of the input is the total number of bytes written.
type BinaryFixedLengthHasher interface {
BinaryHasher
// FixedLengthSum returns digest of the first length bytes.
FixedLengthSum(length frontend.Variable) []uints.U8
}
89 changes: 89 additions & 0 deletions std/hash/sha2/sha2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Package sha2 implements SHA2 hash computation.
//
// This package extends the SHA2 permutation function [sha2] into a full SHA2
// hash.
package sha2

import (
"encoding/binary"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/std/permutation/sha2"
)

var _seed = uints.NewU32Array([]uint32{
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
})

type digest struct {
uapi *uints.BinaryField[uints.U32]
in []uints.U8
}

func New(api frontend.API) (hash.BinaryHasher, error) {
uapi, err := uints.New[uints.U32](api)
if err != nil {
return nil, err
}
return &digest{uapi: uapi}, nil
}

func (d *digest) Write(data []uints.U8) {
d.in = append(d.in, data...)
}

func (d *digest) padded(bytesLen int) []uints.U8 {
zeroPadLen := 55 - bytesLen%64
if zeroPadLen < 0 {
zeroPadLen += 64
}
if cap(d.in) < len(d.in)+9+zeroPadLen {
// in case this is the first time this method is called increase the
// capacity of the slice to fit the padding.
d.in = append(d.in, make([]uints.U8, 9+zeroPadLen)...)
d.in = d.in[:len(d.in)-9-zeroPadLen]
}
buf := d.in
buf = append(buf, uints.NewU8(0x80))
buf = append(buf, uints.NewU8Array(make([]uint8, zeroPadLen))...)
lenbuf := make([]uint8, 8)
binary.BigEndian.PutUint64(lenbuf, uint64(8*bytesLen))
buf = append(buf, uints.NewU8Array(lenbuf)...)
return buf
}

func (d *digest) Sum() []uints.U8 {
var runningDigest [8]uints.U32
var buf [64]uints.U8
copy(runningDigest[:], _seed)
padded := d.padded(len(d.in))
for i := 0; i < len(padded)/64; i++ {
copy(buf[:], padded[i*64:(i+1)*64])
runningDigest = sha2.Permute(d.uapi, runningDigest, buf)
}
var ret []uints.U8
for i := range runningDigest {
ret = append(ret, d.uapi.UnpackMSB(runningDigest[i])...)
}
return ret
}

func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 {
panic("TODO")
// we need to do two things here -- first the padding has to be put to the
// right place. For that we need to know how many blocks we have used. We
// need to fit at least 9 more bytes (padding byte and 8 bytes for input
// length). Knowing the block, we have to keep running track if the current
// block is the expected one.
//
// idea - have a mask for blocks where 1 is only for the block we want to
// use.
}

func (d *digest) Reset() {
d.in = nil
}

func (d *digest) Size() int { return 32 }
50 changes: 50 additions & 0 deletions std/hash/sha2/sha2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package sha2

import (
"crypto/sha256"
"fmt"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/test"
)

type sha2Circuit struct {
In []uints.U8
Expected [32]uints.U8
}

func (c *sha2Circuit) Define(api frontend.API) error {
h, err := New(api)
if err != nil {
return err
}
uapi, err := uints.New[uints.U32](api)
if err != nil {
return err
}
h.Write(c.In)
res := h.Sum()
if len(res) != 32 {
return fmt.Errorf("not 32 bytes")
}
for i := range c.Expected {
uapi.ByteAssertEq(c.Expected[i], res[i])
}
return nil
}

func TestSHA2(t *testing.T) {
bts := make([]byte, 310)
dgst := sha256.Sum256(bts)
witness := sha2Circuit{
In: uints.NewU8Array(bts),
}
copy(witness.Expected[:], uints.NewU8Array(dgst[:]))
err := test.IsSolved(&sha2Circuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField())
if err != nil {
t.Fatal(err)
}
}
7 changes: 3 additions & 4 deletions std/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/consensys/gnark/std/evmprecompiles"
"github.com/consensys/gnark/std/internal/logderivarg"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/bitslice"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/rangecheck"
"github.com/consensys/gnark/std/selector"
Expand All @@ -30,13 +31,11 @@ func registerHints() {
solver.RegisterHint(sw_bls12377.DecomposeScalarG1)
solver.RegisterHint(sw_bls24315.DecomposeScalarG2)
solver.RegisterHint(sw_bls12377.DecomposeScalarG2)
solver.RegisterHint(bits.NTrits)
solver.RegisterHint(bits.NNAF)
solver.RegisterHint(bits.IthBit)
solver.RegisterHint(bits.NBits)
solver.RegisterHint(bits.GetHints()...)
solver.RegisterHint(selector.GetHints()...)
solver.RegisterHint(emulated.GetHints()...)
solver.RegisterHint(rangecheck.GetHints()...)
solver.RegisterHint(evmprecompiles.GetHints()...)
solver.RegisterHint(logderivarg.GetHints()...)
solver.RegisterHint(bitslice.GetHints()...)
}
Loading