Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: implement FixedLengthSum of sha2 #821

Merged
merged 5 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 82 additions & 3 deletions std/hash/sha2/sha2.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@ package sha2

import (
"encoding/binary"
"math/big"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/bitslice"
"github.com/consensys/gnark/std/math/cmp"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/std/permutation/sha2"
)
Expand All @@ -18,16 +22,17 @@ var _seed = uints.NewU32Array([]uint32{
})

type digest struct {
api frontend.API
uapi *uints.BinaryField[uints.U32]
in []uints.U8
}

func New(api frontend.API) (hash.BinaryHasher, error) {
func New(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
uapi, err := uints.New[uints.U32](api)
if err != nil {
return nil, err
}
return &digest{uapi: uapi}, nil
return &digest{api: api, uapi: uapi}, nil
}

func (d *digest) Write(data []uints.U8) {
Expand Down Expand Up @@ -71,7 +76,6 @@ func (d *digest) Sum() []uints.U8 {
}

func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 {
panic("TODO")
// we need to do two things here -- first the padding has to be put to the
// right place. For that we need to know how many blocks we have used. We
// need to fit at least 9 more bytes (padding byte and 8 bytes for input
Expand All @@ -80,10 +84,85 @@ func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 {
//
// idea - have a mask for blocks where 1 is only for the block we want to
// use.

data := make([]uints.U8, len(d.in))
copy(data, d.in)

comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(data)+64+8)), false)

for i := 0; i < 64+8; i++ {
data = append(data, uints.NewU8(0))
}

lenMod64 := d.mod64(length)
lenMod64Less56 := comparator.IsLess(lenMod64, 56)

paddingCount := d.api.Sub(64, lenMod64)
paddingCount = d.api.Select(lenMod64Less56, paddingCount, d.api.Add(paddingCount, 64))

totalLen := d.api.Add(length, paddingCount)
last8BytesPos := d.api.Sub(totalLen, 8)

var dataLenBtyes [8]frontend.Variable
d.bigEndianPutUint64(dataLenBtyes[:], d.api.Mul(length, 8))

for i := range data {
isPaddingStartPos := d.api.IsZero(d.api.Sub(i, length))
data[i].Val = d.api.Select(isPaddingStartPos, 0x80, data[i].Val)

isPaddingPos := comparator.IsLess(length, i)
data[i].Val = d.api.Select(isPaddingPos, 0, data[i].Val)
}

for i := range data {
isLast8BytesPos := d.api.IsZero(d.api.Sub(i, last8BytesPos))
for j := 0; j < 8; j++ {
if i+j < len(data) {
data[i+j].Val = d.api.Select(isLast8BytesPos, dataLenBtyes[j], data[i+j].Val)
}
}
}

var runningDigest [8]uints.U32
var resultDigest [8]uints.U32
var buf [64]uints.U8
copy(runningDigest[:], _seed)
copy(resultDigest[:], _seed)

for i := 0; i < len(data)/64; i++ {
copy(buf[:], data[i*64:(i+1)*64])
runningDigest = sha2.Permute(d.uapi, runningDigest, buf)

isInRange := comparator.IsLess(i*64, totalLen)

for j := 0; j < 8; j++ {
for k := 0; k < 4; k++ {
resultDigest[j][k].Val = d.api.Select(isInRange, runningDigest[j][k].Val, resultDigest[j][k].Val)
}
}
}

var ret []uints.U8
for i := range resultDigest {
ret = append(ret, d.uapi.UnpackMSB(resultDigest[i])...)
}
return ret
}

func (d *digest) Reset() {
d.in = nil
}

func (d *digest) Size() int { return 32 }

func (d *digest) mod64(v frontend.Variable) frontend.Variable {
lower, _ := bitslice.Partition(d.api, v, 6, bitslice.WithNbDigits(64))
return lower
}

func (d *digest) bigEndianPutUint64(b []frontend.Variable, x frontend.Variable) {
bts := bits.ToBinary(d.api, x, bits.WithNbDigits(64))
for i := 0; i < 8; i++ {
b[i] = bits.FromBinary(d.api, bts[(8-i-1)*8:(8-i)*8])
}
}
41 changes: 41 additions & 0 deletions std/hash/sha2/sha2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,44 @@ func TestSHA2(t *testing.T) {
t.Fatal(err)
}
}

type sha2FixedLengthCircuit struct {
In []uints.U8
Length frontend.Variable
Expected [32]uints.U8
}

func (c *sha2FixedLengthCircuit) Define(api frontend.API) error {
h, err := New(api)
if err != nil {
return err
}
uapi, err := uints.New[uints.U32](api)
if err != nil {
return err
}
h.Write(c.In)
res := h.FixedLengthSum(c.Length)
if len(res) != 32 {
return fmt.Errorf("not 32 bytes")
}
for i := range c.Expected {
uapi.ByteAssertEq(c.Expected[i], res[i])
}
return nil
}

func TestSHA2FixedLengthSum(t *testing.T) {
bts := make([]byte, 144)
length := 56
dgst := sha256.Sum256(bts[:length])
witness := sha2FixedLengthCircuit{
In: uints.NewU8Array(bts),
Length: length,
}
copy(witness.Expected[:], uints.NewU8Array(dgst[:]))
err := test.IsSolved(&sha2FixedLengthCircuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField())
if err != nil {
t.Fatal(err)
}
}
Loading