Skip to content

Commit

Permalink
Merge pull request #924 from Consensys/feat/bypass-compression
Browse files Browse the repository at this point in the history
Feat/bypass compression
  • Loading branch information
Tabaie committed Nov 18, 2023
2 parents 5b37273 + 5c94c29 commit 9fb591e
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 191 deletions.
11 changes: 11 additions & 0 deletions std/compress/lzss/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Compressor struct {
type Level uint8

const (
NoCompression Level = 0
// BestCompression allows the compressor to produce a stream of bit-level granularity,
// giving the compressor this freedom helps it achieve better compression ratios but
// will impose a high number of constraints on the SNARK decompressor
Expand Down Expand Up @@ -79,6 +80,11 @@ func initBackRefTypes(dictLen int, level Level) (short, long, dict backrefType)
wordAlign := func(a int) uint8 {
return (uint8(a) + uint8(level) - 1) / uint8(level) * uint8(level)
}
if level == NoCompression {
wordAlign = func(a int) uint8 {
return uint8(a)
}
}
short = newBackRefType(symbolShort, wordAlign(14), 8, false)
long = newBackRefType(symbolLong, wordAlign(19), 8, false)
dict = newBackRefType(symbolDict, wordAlign(bits.Len(uint(dictLen))), 8, true)
Expand All @@ -94,6 +100,11 @@ func (compressor *Compressor) Compress(d []byte) (c []byte, err error) {

// reset output buffer
compressor.buf.Reset()
compressor.buf.WriteByte(byte(compressor.level))
if compressor.level == NoCompression {
compressor.buf.Write(d)
return compressor.buf.Bytes(), nil
}
compressor.bw = bitio.NewWriter(&compressor.buf)

// build the index
Expand Down
26 changes: 22 additions & 4 deletions std/compress/lzss/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func testCompressionRoundTrip(t *testing.T, d []byte) {
c, err := compressor.Compress(d)
require.NoError(t, err)

dBack, err := DecompressGo(c, getDictionary(), BestCompression)
dBack, err := DecompressGo(c, getDictionary())
require.NoError(t, err)

if !bytes.Equal(d, dBack) {
Expand All @@ -37,6 +37,24 @@ func TestNoCompression(t *testing.T) {
testCompressionRoundTrip(t, []byte{'h', 'i'})
}

// TestNoCompressionAttempt round-trips a small payload through a compressor
// built with the NoCompression level: Compress must pass the bytes through
// (stored, not compressed) and DecompressGo must return them unchanged.
func TestNoCompressionAttempt(t *testing.T) {

	input := []byte{253, 254, 255}

	cmp, err := NewCompressor(getDictionary(), NoCompression)
	require.NoError(t, err)

	stored, err := cmp.Compress(input)
	require.NoError(t, err)

	roundTripped, err := DecompressGo(stored, getDictionary())
	require.NoError(t, err)

	if !bytes.Equal(input, roundTripped) {
		t.Fatal("round trip failed")
	}
}

// Test9E checks the compression round trip on a short, highly repetitive
// input (mostly 1s with a single 2 in the middle).
func Test9E(t *testing.T) {
	input := []byte{1, 1, 1, 1, 2, 1, 1, 1, 1}
	testCompressionRoundTrip(t, input)
}
Expand Down Expand Up @@ -75,13 +93,13 @@ func FuzzCompress(f *testing.F) {
t.Fatal(err)
}

decompressedBytes, err := DecompressGo(compressedBytes, dict, level)
decompressedBytes, err := DecompressGo(compressedBytes, dict)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(input, decompressedBytes) {
t.Log("compression Mode:", level)
t.Log("compression level:", level)
t.Log("original bytes:", hex.EncodeToString(input))
t.Log("decompressed bytes:", hex.EncodeToString(decompressedBytes))
t.Log("dict", hex.EncodeToString(dict))
Expand Down Expand Up @@ -163,7 +181,7 @@ type compressResult struct {
}

func decompresslzss_v1(data, dict []byte) ([]byte, error) {
return DecompressGo(data, dict, BestCompression)
return DecompressGo(data, dict)
}

func compresslzss_v1(compressor *Compressor, data []byte) (compressResult, error) {
Expand Down
24 changes: 18 additions & 6 deletions std/compress/lzss/decompress.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ import (
"github.com/icza/bitio"
)

func DecompressGo(data, dict []byte, level Level) (d []byte, err error) {
func DecompressGo(data, dict []byte) (d []byte, err error) {
// d[i < 0] = Settings.BackRefSettings.Symbol by convention
var out bytes.Buffer
out.Grow(len(data)*6 + len(dict))
in := bitio.NewReader(bytes.NewReader(data))

level := Level(in.TryReadByte())
if level == NoCompression {
return data[1:], nil
}

dict = augmentDict(dict)
shortBackRefType, longBackRefType, dictBackRefType := initBackRefTypes(len(dict), level)

Expand Down Expand Up @@ -54,18 +59,25 @@ func DecompressGo(data, dict []byte, level Level) (d []byte, err error) {
func ReadIntoStream(data, dict []byte, level Level) compress.Stream {
in := bitio.NewReader(bytes.NewReader(data))

wordLen := int(level)

dict = augmentDict(dict)
shortBackRefType, longBackRefType, dictBackRefType := initBackRefTypes(len(dict), level)

wordLen := int(level)
bDict := backref{bType: dictBackRefType}
bShort := backref{bType: shortBackRefType}
bLong := backref{bType: longBackRefType}

levelFromData := Level(in.TryReadByte())
if levelFromData != NoCompression && levelFromData != level {
panic("compression mode mismatch")
}

out := compress.Stream{
NbSymbs: 1 << wordLen,
}

bDict := backref{bType: dictBackRefType}
bShort := backref{bType: shortBackRefType}
bLong := backref{bType: longBackRefType}
out.WriteNum(int(levelFromData), 8/wordLen)

s := in.TryReadByte()

Expand All @@ -84,7 +96,7 @@ func ReadIntoStream(data, dict []byte, level Level) compress.Stream {
// dict back ref
b = &bDict
}
if b != nil {
if b != nil && levelFromData != NoCompression {
b.readFrom(in)
address := b.address
if b != &bDict {
Expand Down
Binary file added std/compress/lzss/dict_naive
Binary file not shown.
32 changes: 21 additions & 11 deletions std/compress/lzss/snark.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab
dictBrNbWords := int(dictBackRefType.nbBitsBackRef) / wordLen
byteNbWords := 8 / wordLen

fileCompressionMode := readNum(api, c, byteNbWords, wordLen)
c = c[byteNbWords:]
cLength = api.Sub(cLength, byteNbWords)
api.AssertIsEqual(api.Mul(fileCompressionMode, fileCompressionMode), api.Mul(fileCompressionMode, wordLen)) // if fcm!=0, then fcm=wordLen
decompressionNotBypassed := api.Sub(1, api.IsZero(fileCompressionMode))

// assert that c are within range
cRangeTable := logderivlookup.New(api)
for i := 0; i < 1<<wordLen; i++ {
Expand Down Expand Up @@ -48,9 +54,10 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab

curr := bytesTable.Lookup(inI)[0]

currIndicatesLongBr := api.IsZero(api.Sub(curr, symbolLong))
currIndicatesShortBr := api.IsZero(api.Sub(curr, symbolShort))
currIndicatesDr := api.IsZero(api.Sub(curr, symbolDict))
currMinusLong := api.Sub(api.Mul(curr, decompressionNotBypassed), symbolLong) // if bypassing decompression, currIndicatesXX = 0
currIndicatesLongBr := api.IsZero(currMinusLong)
currIndicatesShortBr := api.IsZero(api.Sub(currMinusLong, symbolShort-symbolLong))
currIndicatesDr := api.IsZero(api.Sub(currMinusLong, symbolDict-symbolLong))
currIndicatesBr := api.Add(currIndicatesLongBr, currIndicatesShortBr)
currIndicatesCp := api.Add(currIndicatesBr, currIndicatesDr)

Expand Down Expand Up @@ -159,14 +166,7 @@ type numReader struct {
func newNumReader(api frontend.API, c []frontend.Variable, numNbBits, wordNbBits int) *numReader {
nbWords := numNbBits / wordNbBits
stepCoeff := 1 << wordNbBits
nxt := frontend.Variable(0)
coeff := frontend.Variable(1)
if len(c) >= nbWords {
for i := 0; i < nbWords; i++ {
nxt = api.MulAcc(nxt, coeff, c[i])
coeff = api.Mul(coeff, stepCoeff)
}
}
nxt := readNum(api, c, nbWords, stepCoeff)
return &numReader{
api: api,
c: c,
Expand All @@ -176,6 +176,16 @@ func newNumReader(api frontend.API, c []frontend.Variable, numNbBits, wordNbBits
}
}

// readNum folds the first nbWords entries of c into a single circuit
// variable, treating c[0] as the least-significant digit in base stepCoeff.
// If c holds fewer than nbWords entries, only the available ones contribute.
func readNum(api frontend.API, c []frontend.Variable, nbWords, stepCoeff int) frontend.Variable {
	acc := frontend.Variable(0)
	weight := frontend.Variable(1)
	limit := nbWords
	if len(c) < limit {
		limit = len(c)
	}
	for i := 0; i < limit; i++ {
		// acc += weight * c[i]; weight advances by one base-stepCoeff digit.
		acc = api.MulAcc(acc, weight, c[i])
		weight = api.Mul(weight, stepCoeff)
	}
	return acc
}

// next returns the next number in the sequence. returns 0 upon EOF
func (nr *numReader) next() frontend.Variable {
res := nr.nxt
Expand Down
Loading

0 comments on commit 9fb591e

Please sign in to comment.