Skip to content

Commit

Permalink
Feat: implement FixedLengthSum of sha2 (#821)
Browse files Browse the repository at this point in the history
* feat: implement FixedLengthSum of sha2

* chore: use bitslice.Partition

* chore: use bits package for binary decomp

* chore: use d.api directly

---------

Co-authored-by: Ivo Kubjas <ivo.kubjas@consensys.net>
  • Loading branch information
liyue201 and ivokub committed May 9, 2024
1 parent 971577c commit 78e19f6
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 3 deletions.
85 changes: 82 additions & 3 deletions std/hash/sha2/sha2.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@ package sha2

import (
"encoding/binary"
"math/big"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/bitslice"
"github.com/consensys/gnark/std/math/cmp"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/std/permutation/sha2"
)
Expand All @@ -18,16 +22,17 @@ var _seed = uints.NewU32Array([]uint32{
})

type digest struct {
api frontend.API
uapi *uints.BinaryField[uints.U32]
in []uints.U8
}

func New(api frontend.API) (hash.BinaryHasher, error) {
func New(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
uapi, err := uints.New[uints.U32](api)
if err != nil {
return nil, err
}
return &digest{uapi: uapi}, nil
return &digest{api: api, uapi: uapi}, nil
}

func (d *digest) Write(data []uints.U8) {
Expand Down Expand Up @@ -71,7 +76,6 @@ func (d *digest) Sum() []uints.U8 {
}

func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 {
panic("TODO")
// we need to do two things here -- first the padding has to be put to the
// right place. For that we need to know how many blocks we have used. We
// need to fit at least 9 more bytes (padding byte and 8 bytes for input
Expand All @@ -80,10 +84,85 @@ func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 {
//
// idea - have a mask for blocks where 1 is only for the block we want to
// use.

data := make([]uints.U8, len(d.in))
copy(data, d.in)

comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(data)+64+8)), false)

for i := 0; i < 64+8; i++ {
data = append(data, uints.NewU8(0))
}

lenMod64 := d.mod64(length)
lenMod64Less56 := comparator.IsLess(lenMod64, 56)

paddingCount := d.api.Sub(64, lenMod64)
paddingCount = d.api.Select(lenMod64Less56, paddingCount, d.api.Add(paddingCount, 64))

totalLen := d.api.Add(length, paddingCount)
last8BytesPos := d.api.Sub(totalLen, 8)

var dataLenBtyes [8]frontend.Variable
d.bigEndianPutUint64(dataLenBtyes[:], d.api.Mul(length, 8))

for i := range data {
isPaddingStartPos := d.api.IsZero(d.api.Sub(i, length))
data[i].Val = d.api.Select(isPaddingStartPos, 0x80, data[i].Val)

isPaddingPos := comparator.IsLess(length, i)
data[i].Val = d.api.Select(isPaddingPos, 0, data[i].Val)
}

for i := range data {
isLast8BytesPos := d.api.IsZero(d.api.Sub(i, last8BytesPos))
for j := 0; j < 8; j++ {
if i+j < len(data) {
data[i+j].Val = d.api.Select(isLast8BytesPos, dataLenBtyes[j], data[i+j].Val)
}
}
}

var runningDigest [8]uints.U32
var resultDigest [8]uints.U32
var buf [64]uints.U8
copy(runningDigest[:], _seed)
copy(resultDigest[:], _seed)

for i := 0; i < len(data)/64; i++ {
copy(buf[:], data[i*64:(i+1)*64])
runningDigest = sha2.Permute(d.uapi, runningDigest, buf)

isInRange := comparator.IsLess(i*64, totalLen)

for j := 0; j < 8; j++ {
for k := 0; k < 4; k++ {
resultDigest[j][k].Val = d.api.Select(isInRange, runningDigest[j][k].Val, resultDigest[j][k].Val)
}
}
}

var ret []uints.U8
for i := range resultDigest {
ret = append(ret, d.uapi.UnpackMSB(resultDigest[i])...)
}
return ret
}

func (d *digest) Reset() {
d.in = nil
}

func (d *digest) Size() int { return 32 }

func (d *digest) mod64(v frontend.Variable) frontend.Variable {
lower, _ := bitslice.Partition(d.api, v, 6, bitslice.WithNbDigits(64))
return lower
}

func (d *digest) bigEndianPutUint64(b []frontend.Variable, x frontend.Variable) {
bts := bits.ToBinary(d.api, x, bits.WithNbDigits(64))
for i := 0; i < 8; i++ {
b[i] = bits.FromBinary(d.api, bts[(8-i-1)*8:(8-i)*8])
}
}
41 changes: 41 additions & 0 deletions std/hash/sha2/sha2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,44 @@ func TestSHA2(t *testing.T) {
t.Fatal(err)
}
}

type sha2FixedLengthCircuit struct {
In []uints.U8
Length frontend.Variable
Expected [32]uints.U8
}

func (c *sha2FixedLengthCircuit) Define(api frontend.API) error {
h, err := New(api)
if err != nil {
return err
}
uapi, err := uints.New[uints.U32](api)
if err != nil {
return err
}
h.Write(c.In)
res := h.FixedLengthSum(c.Length)
if len(res) != 32 {
return fmt.Errorf("not 32 bytes")
}
for i := range c.Expected {
uapi.ByteAssertEq(c.Expected[i], res[i])
}
return nil
}

func TestSHA2FixedLengthSum(t *testing.T) {
bts := make([]byte, 144)
length := 56
dgst := sha256.Sum256(bts[:length])
witness := sha2FixedLengthCircuit{
In: uints.NewU8Array(bts),
Length: length,
}
copy(witness.Expected[:], uints.NewU8Array(dgst[:]))
err := test.IsSolved(&sha2FixedLengthCircuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField())
if err != nil {
t.Fatal(err)
}
}

0 comments on commit 78e19f6

Please sign in to comment.