Skip to content

Commit

Permalink
feat: implement FixedLengthSum of sha2
Browse files Browse the repository at this point in the history
  • Loading branch information
liyue201 committed Oct 12, 2023
1 parent 36b0b58 commit 0f5c541
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 4 deletions.
86 changes: 82 additions & 4 deletions std/hash/sha2/sha2.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,30 @@ package sha2

import (
"encoding/binary"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/cmp"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/std/permutation/sha2"
"math/big"
)

var _seed = uints.NewU32Array([]uint32{
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
})

type digest struct {
api frontend.API
uapi *uints.BinaryField[uints.U32]
in []uints.U8
}

func New(api frontend.API) (hash.BinaryHasher, error) {
func New(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
uapi, err := uints.New[uints.U32](api)
if err != nil {
return nil, err
}
return &digest{uapi: uapi}, nil
return &digest{api: api, uapi: uapi}, nil
}

func (d *digest) Write(data []uints.U8) {
Expand Down Expand Up @@ -71,7 +73,6 @@ func (d *digest) Sum() []uints.U8 {
}

func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 {
panic("TODO")
// we need to do two things here -- first the padding has to be put to the
// right place. For that we need to know how many blocks we have used. We
// need to fit at least 9 more bytes (padding byte and 8 bytes for input
Expand All @@ -80,10 +81,87 @@ func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 {
//
// idea - have a mask for blocks where 1 is only for the block we want to
// use.

api := d.api

data := make([]uints.U8, len(d.in))
copy(data, d.in)

comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(data)+64+8)), false)

for i := 0; i < 64+8; i++ {
data = append(data, uints.NewU8(0))
}

lenMod64 := mod64(api, length)
lenMod64Less56 := comparator.IsLess(lenMod64, 56)

paddingCount := api.Sub(64, lenMod64)
paddingCount = api.Select(lenMod64Less56, paddingCount, api.Add(paddingCount, 64))

totalLen := api.Add(length, paddingCount)
last8BytesPos := api.Sub(totalLen, 8)

var dataLenBtyes [8]frontend.Variable
bigEndianPutUint64(api, dataLenBtyes[:], api.Mul(length, 8))

for i := range data {
isPaddingStartPos := api.IsZero(api.Sub(i, length))
data[i].Val = api.Select(isPaddingStartPos, 0x80, data[i].Val)

isPaddingPos := comparator.IsLess(length, i)
data[i].Val = api.Select(isPaddingPos, 0, data[i].Val)
}

for i := range data {
isLast8BytesPos := api.IsZero(api.Sub(i, last8BytesPos))
for j := 0; j < 8; j++ {
if i+j < len(data) {
data[i+j].Val = api.Select(isLast8BytesPos, dataLenBtyes[j], data[i+j].Val)
}
}
}

var runningDigest [8]uints.U32
var resultDigest [8]uints.U32
var buf [64]uints.U8
copy(runningDigest[:], _seed)
copy(resultDigest[:], _seed)

for i := 0; i < len(data)/64; i++ {
copy(buf[:], data[i*64:(i+1)*64])
runningDigest = sha2.Permute(d.uapi, runningDigest, buf)

isInRange := comparator.IsLess(i*64, totalLen)

for j := 0; j < 8; j++ {
for k := 0; k < 4; k++ {
resultDigest[j][k].Val = api.Select(isInRange, runningDigest[j][k].Val, resultDigest[j][k].Val)
}
}
}

var ret []uints.U8
for i := range resultDigest {
ret = append(ret, d.uapi.UnpackMSB(resultDigest[i])...)
}
return ret
}

func (d *digest) Reset() {
d.in = nil
}

func (d *digest) Size() int { return 32 }

func mod64(api frontend.API, v frontend.Variable) frontend.Variable {
bits := api.ToBinary(v)
return api.FromBinary(bits[:6]...)
}

func bigEndianPutUint64(api frontend.API, b []frontend.Variable, x frontend.Variable) {
bits := api.ToBinary(x, 64)
for i := 0; i < 8; i++ {
b[i] = api.FromBinary(bits[(8-i-1)*8 : (8-i)*8]...)
}
}
41 changes: 41 additions & 0 deletions std/hash/sha2/sha2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,44 @@ func TestSHA2(t *testing.T) {
t.Fatal(err)
}
}

type sha2FixedLengthCircuit struct {
In []uints.U8
Length frontend.Variable
Expected [32]uints.U8
}

func (c *sha2FixedLengthCircuit) Define(api frontend.API) error {
h, err := New(api)
if err != nil {
return err
}
uapi, err := uints.New[uints.U32](api)
if err != nil {
return err
}
h.Write(c.In)
res := h.FixedLengthSum(c.Length)
if len(res) != 32 {
return fmt.Errorf("not 32 bytes")
}
for i := range c.Expected {
uapi.ByteAssertEq(c.Expected[i], res[i])
}
return nil
}

func TestSHA2FixedLengthSum(t *testing.T) {
bts := make([]byte, 144)
length := 56
dgst := sha256.Sum256(bts[:length])
witness := sha2FixedLengthCircuit{
In: uints.NewU8Array(bts),
Length: length,
}
copy(witness.Expected[:], uints.NewU8Array(dgst[:]))
err := test.IsSolved(&sha2FixedLengthCircuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField())
if err != nil {
t.Fatal(err)
}
}

0 comments on commit 0f5c541

Please sign in to comment.