diff --git a/std/hash/sha2/sha2.go b/std/hash/sha2/sha2.go index edb261e621..ea36f7f70d 100644 --- a/std/hash/sha2/sha2.go +++ b/std/hash/sha2/sha2.go @@ -6,9 +6,13 @@ package sha2 import ( "encoding/binary" + "math/big" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/math/bits" + "github.com/consensys/gnark/std/math/bitslice" + "github.com/consensys/gnark/std/math/cmp" "github.com/consensys/gnark/std/math/uints" "github.com/consensys/gnark/std/permutation/sha2" ) @@ -18,16 +22,17 @@ var _seed = uints.NewU32Array([]uint32{ }) type digest struct { + api frontend.API uapi *uints.BinaryField[uints.U32] in []uints.U8 } -func New(api frontend.API) (hash.BinaryHasher, error) { +func New(api frontend.API) (hash.BinaryFixedLengthHasher, error) { uapi, err := uints.New[uints.U32](api) if err != nil { return nil, err } - return &digest{uapi: uapi}, nil + return &digest{api: api, uapi: uapi}, nil } func (d *digest) Write(data []uints.U8) { @@ -71,7 +76,6 @@ func (d *digest) Sum() []uints.U8 { } func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { - panic("TODO") // we need to do two things here -- first the padding has to be put to the // right place. For that we need to know how many blocks we have used. We // need to fit at least 9 more bytes (padding byte and 8 bytes for input @@ -80,6 +84,69 @@ func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { // // idea - have a mask for blocks where 1 is only for the block we want to // use. + + data := make([]uints.U8, len(d.in)) + copy(data, d.in) + + comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(data)+64+8)), false) + + for i := 0; i < 64+8; i++ { + data = append(data, uints.NewU8(0)) + } + + lenMod64 := d.mod64(length) + lenMod64Less56 := comparator.IsLess(lenMod64, 56) + + paddingCount := d.api.Sub(64, lenMod64) + paddingCount = d.api.Select(lenMod64Less56, paddingCount, d.api.Add(paddingCount, 64)) + + totalLen := d.api.Add(length, paddingCount) + last8BytesPos := d.api.Sub(totalLen, 8) + + var dataLenBtyes [8]frontend.Variable + d.bigEndianPutUint64(dataLenBtyes[:], d.api.Mul(length, 8)) + + for i := range data { + isPaddingStartPos := d.api.IsZero(d.api.Sub(i, length)) + data[i].Val = d.api.Select(isPaddingStartPos, 0x80, data[i].Val) + + isPaddingPos := comparator.IsLess(length, i) + data[i].Val = d.api.Select(isPaddingPos, 0, data[i].Val) + } + + for i := range data { + isLast8BytesPos := d.api.IsZero(d.api.Sub(i, last8BytesPos)) + for j := 0; j < 8; j++ { + if i+j < len(data) { + data[i+j].Val = d.api.Select(isLast8BytesPos, dataLenBtyes[j], data[i+j].Val) + } + } + } + + var runningDigest [8]uints.U32 + var resultDigest [8]uints.U32 + var buf [64]uints.U8 + copy(runningDigest[:], _seed) + copy(resultDigest[:], _seed) + + for i := 0; i < len(data)/64; i++ { + copy(buf[:], data[i*64:(i+1)*64]) + runningDigest = sha2.Permute(d.uapi, runningDigest, buf) + + isInRange := comparator.IsLess(i*64, totalLen) + + for j := 0; j < 8; j++ { + for k := 0; k < 4; k++ { + resultDigest[j][k].Val = d.api.Select(isInRange, runningDigest[j][k].Val, resultDigest[j][k].Val) + } + } + } + + var ret []uints.U8 + for i := range resultDigest { + ret = append(ret, d.uapi.UnpackMSB(resultDigest[i])...) + } + return ret } func (d *digest) Reset() { @@ -87,3 +154,15 @@ func (d *digest) Reset() { } func (d *digest) Size() int { return 32 } + +func (d *digest) mod64(v frontend.Variable) frontend.Variable { + lower, _ := bitslice.Partition(d.api, v, 6, bitslice.WithNbDigits(64)) + return lower +} + +func (d *digest) bigEndianPutUint64(b []frontend.Variable, x frontend.Variable) { + bts := bits.ToBinary(d.api, x, bits.WithNbDigits(64)) + for i := 0; i < 8; i++ { + b[i] = bits.FromBinary(d.api, bts[(8-i-1)*8:(8-i)*8]) + } +} diff --git a/std/hash/sha2/sha2_test.go b/std/hash/sha2/sha2_test.go index d4acf5baf3..0093fddc43 100644 --- a/std/hash/sha2/sha2_test.go +++ b/std/hash/sha2/sha2_test.go @@ -48,3 +48,44 @@ func TestSHA2(t *testing.T) { t.Fatal(err) } } + +type sha2FixedLengthCircuit struct { + In []uints.U8 + Length frontend.Variable + Expected [32]uints.U8 +} + +func (c *sha2FixedLengthCircuit) Define(api frontend.API) error { + h, err := New(api) + if err != nil { + return err + } + uapi, err := uints.New[uints.U32](api) + if err != nil { + return err + } + h.Write(c.In) + res := h.FixedLengthSum(c.Length) + if len(res) != 32 { + return fmt.Errorf("not 32 bytes") + } + for i := range c.Expected { + uapi.ByteAssertEq(c.Expected[i], res[i]) + } + return nil +} + +func TestSHA2FixedLengthSum(t *testing.T) { + bts := make([]byte, 144) + length := 56 + dgst := sha256.Sum256(bts[:length]) + witness := sha2FixedLengthCircuit{ + In: uints.NewU8Array(bts), + Length: length, + } + copy(witness.Expected[:], uints.NewU8Array(dgst[:])) + err := test.IsSolved(&sha2FixedLengthCircuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField()) + if err != nil { + t.Fatal(err) + } +}