Skip to content

Commit 8abd495

Browse files
authored
inflate: Generate code for byte readers (#236)
More than 20% additional speedup: ``` λ benchcmp old.txt new.txt benchmark old ns/op new ns/op delta BenchmarkGunzipCopy-32 19698008 16125274 -18.14% BenchmarkGunzipNoWriteTo-32 19883807 16071205 -19.17% BenchmarkGunzipStdlib-32 23397455 23201339 -0.84% benchmark old MB/s new MB/s speedup BenchmarkGunzipCopy-32 242.31 295.99 1.22x BenchmarkGunzipNoWriteTo-32 240.04 296.99 1.24x BenchmarkGunzipStdlib-32 203.99 205.72 1.01x ```
1 parent fd5b254 commit 8abd495

File tree

3 files changed

+1201
-5
lines changed

3 files changed

+1201
-5
lines changed

flate/gen_inflate.go

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
// +build generate
2+
3+
//go:generate go run $GOFILE && gofmt -w inflate_gen.go
4+
5+
package main
6+
7+
import (
8+
"os"
9+
"strings"
10+
)
11+
12+
func main() {
13+
f, err := os.Create("inflate_gen.go")
14+
if err != nil {
15+
panic(err)
16+
}
17+
defer f.Close()
18+
types := []string{"*bytes.Buffer", "*bytes.Reader", "*bufio.Reader", "*strings.Reader"}
19+
names := []string{"BytesBuffer", "BytesReader", "BufioReader", "StringsReader"}
20+
imports := []string{"bytes", "bufio", "io", "strings", "math/bits"}
21+
f.WriteString(`// Code generated by go generate gen_inflate.go. DO NOT EDIT.
22+
23+
package flate
24+
25+
import (
26+
`)
27+
28+
for _, imp := range imports {
29+
f.WriteString("\t\"" + imp + "\"\n")
30+
}
31+
f.WriteString(")\n\n")
32+
33+
template := `
34+
35+
// Decode a single Huffman block from f.
36+
// hl and hd are the Huffman states for the lit/length values
37+
// and the distance values, respectively. If hd == nil, using the
38+
// fixed distance encoding associated with fixed Huffman blocks.
39+
func (f *decompressor) $FUNCNAME$() {
40+
const (
41+
stateInit = iota // Zero value must be stateInit
42+
stateDict
43+
)
44+
fr := f.r.($TYPE$)
45+
moreBits := func() error {
46+
c, err := fr.ReadByte()
47+
if err != nil {
48+
return noEOF(err)
49+
}
50+
f.roffset++
51+
f.b |= uint32(c) << f.nb
52+
f.nb += 8
53+
return nil
54+
}
55+
56+
switch f.stepState {
57+
case stateInit:
58+
goto readLiteral
59+
case stateDict:
60+
goto copyHistory
61+
}
62+
63+
readLiteral:
64+
// Read literal and/or (length, distance) according to RFC section 3.2.3.
65+
{
66+
var v int
67+
{
68+
// Inlined v, err := f.huffSym(f.hl)
69+
// Since a huffmanDecoder can be empty or be composed of a degenerate tree
70+
// with single element, huffSym must error on these two edge cases. In both
71+
// cases, the chunks slice will be 0 for the invalid sequence, leading it
72+
// satisfy the n == 0 check below.
73+
n := uint(f.hl.maxRead)
74+
// Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers,
75+
// but is smart enough to keep local variables in registers, so use nb and b,
76+
// inline call to moreBits and reassign b,nb back to f on return.
77+
nb, b := f.nb, f.b
78+
for {
79+
for nb < n {
80+
c, err := fr.ReadByte()
81+
if err != nil {
82+
f.b = b
83+
f.nb = nb
84+
f.err = noEOF(err)
85+
return
86+
}
87+
f.roffset++
88+
b |= uint32(c) << (nb & 31)
89+
nb += 8
90+
}
91+
chunk := f.hl.chunks[b&(huffmanNumChunks-1)]
92+
n = uint(chunk & huffmanCountMask)
93+
if n > huffmanChunkBits {
94+
chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask]
95+
n = uint(chunk & huffmanCountMask)
96+
}
97+
if n <= nb {
98+
if n == 0 {
99+
f.b = b
100+
f.nb = nb
101+
if debugDecode {
102+
fmt.Println("huffsym: n==0")
103+
}
104+
f.err = CorruptInputError(f.roffset)
105+
return
106+
}
107+
f.b = b >> (n & 31)
108+
f.nb = nb - n
109+
v = int(chunk >> huffmanValueShift)
110+
break
111+
}
112+
}
113+
}
114+
115+
var n uint // number of bits extra
116+
var length int
117+
var err error
118+
switch {
119+
case v < 256:
120+
f.dict.writeByte(byte(v))
121+
if f.dict.availWrite() == 0 {
122+
f.toRead = f.dict.readFlush()
123+
f.step = (*decompressor).$FUNCNAME$
124+
f.stepState = stateInit
125+
return
126+
}
127+
goto readLiteral
128+
case v == 256:
129+
f.finishBlock()
130+
return
131+
// otherwise, reference to older data
132+
case v < 265:
133+
length = v - (257 - 3)
134+
n = 0
135+
case v < 269:
136+
length = v*2 - (265*2 - 11)
137+
n = 1
138+
case v < 273:
139+
length = v*4 - (269*4 - 19)
140+
n = 2
141+
case v < 277:
142+
length = v*8 - (273*8 - 35)
143+
n = 3
144+
case v < 281:
145+
length = v*16 - (277*16 - 67)
146+
n = 4
147+
case v < 285:
148+
length = v*32 - (281*32 - 131)
149+
n = 5
150+
case v < maxNumLit:
151+
length = 258
152+
n = 0
153+
default:
154+
if debugDecode {
155+
fmt.Println(v, ">= maxNumLit")
156+
}
157+
f.err = CorruptInputError(f.roffset)
158+
return
159+
}
160+
if n > 0 {
161+
for f.nb < n {
162+
if err = moreBits(); err != nil {
163+
if debugDecode {
164+
fmt.Println("morebits n>0:", err)
165+
}
166+
f.err = err
167+
return
168+
}
169+
}
170+
length += int(f.b & uint32(1<<n-1))
171+
f.b >>= n
172+
f.nb -= n
173+
}
174+
175+
var dist int
176+
if f.hd == nil {
177+
for f.nb < 5 {
178+
if err = moreBits(); err != nil {
179+
if debugDecode {
180+
fmt.Println("morebits f.nb<5:", err)
181+
}
182+
f.err = err
183+
return
184+
}
185+
}
186+
dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3)))
187+
f.b >>= 5
188+
f.nb -= 5
189+
} else {
190+
if dist, err = f.huffSym(f.hd); err != nil {
191+
if debugDecode {
192+
fmt.Println("huffsym:", err)
193+
}
194+
f.err = err
195+
return
196+
}
197+
}
198+
199+
switch {
200+
case dist < 4:
201+
dist++
202+
case dist < maxNumDist:
203+
nb := uint(dist-2) >> 1
204+
// have 1 bit in bottom of dist, need nb more.
205+
extra := (dist & 1) << nb
206+
for f.nb < nb {
207+
if err = moreBits(); err != nil {
208+
if debugDecode {
209+
fmt.Println("morebits f.nb<nb:", err)
210+
}
211+
f.err = err
212+
return
213+
}
214+
}
215+
extra |= int(f.b & uint32(1<<nb-1))
216+
f.b >>= nb
217+
f.nb -= nb
218+
dist = 1<<(nb+1) + 1 + extra
219+
default:
220+
if debugDecode {
221+
fmt.Println("dist too big:", dist, maxNumDist)
222+
}
223+
f.err = CorruptInputError(f.roffset)
224+
return
225+
}
226+
227+
// No check on length; encoding can be prescient.
228+
if dist > f.dict.histSize() {
229+
if debugDecode {
230+
fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize())
231+
}
232+
f.err = CorruptInputError(f.roffset)
233+
return
234+
}
235+
236+
f.copyLen, f.copyDist = length, dist
237+
goto copyHistory
238+
}
239+
240+
copyHistory:
241+
// Perform a backwards copy according to RFC section 3.2.3.
242+
{
243+
cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen)
244+
if cnt == 0 {
245+
cnt = f.dict.writeCopy(f.copyDist, f.copyLen)
246+
}
247+
f.copyLen -= cnt
248+
249+
if f.dict.availWrite() == 0 || f.copyLen > 0 {
250+
f.toRead = f.dict.readFlush()
251+
f.step = (*decompressor).$FUNCNAME$ // We need to continue this work
252+
f.stepState = stateDict
253+
return
254+
}
255+
goto readLiteral
256+
}
257+
}
258+
259+
`
260+
for i, t := range types {
261+
s := strings.Replace(template, "$FUNCNAME$", "huffman"+names[i], -1)
262+
s = strings.Replace(s, "$TYPE$", t, -1)
263+
f.WriteString(s)
264+
}
265+
f.WriteString("func (f *decompressor) huffmanBlockDecoder() func() {\n")
266+
f.WriteString("\tswitch f.r.(type) {\n")
267+
for i, t := range types {
268+
f.WriteString("\t\tcase " + t + ":\n")
269+
f.WriteString("\t\t\treturn f.huffman" + names[i] + "\n")
270+
}
271+
f.WriteString("\t\tdefault:\n")
272+
f.WriteString("\t\t\treturn f.huffmanBlockGeneric")
273+
f.WriteString("\t}\n}\n")
274+
}

flate/inflate.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,15 +342,15 @@ func (f *decompressor) nextBlock() {
342342
// compressed, fixed Huffman tables
343343
f.hl = &fixedHuffmanDecoder
344344
f.hd = nil
345-
f.huffmanBlock()
345+
f.huffmanBlockDecoder()()
346346
case 2:
347347
// compressed, dynamic Huffman tables
348348
if f.err = f.readHuffman(); f.err != nil {
349349
break
350350
}
351351
f.hl = &f.h1
352352
f.hd = &f.h2
353-
f.huffmanBlock()
353+
f.huffmanBlockDecoder()()
354354
default:
355355
// 3 is reserved.
356356
if debugDecode {
@@ -564,7 +564,7 @@ func (f *decompressor) readHuffman() error {
564564
// hl and hd are the Huffman states for the lit/length values
565565
// and the distance values, respectively. If hd == nil, using the
566566
// fixed distance encoding associated with fixed Huffman blocks.
567-
func (f *decompressor) huffmanBlock() {
567+
func (f *decompressor) huffmanBlockGeneric() {
568568
const (
569569
stateInit = iota // Zero value must be stateInit
570570
stateDict
@@ -637,7 +637,7 @@ readLiteral:
637637
f.dict.writeByte(byte(v))
638638
if f.dict.availWrite() == 0 {
639639
f.toRead = f.dict.readFlush()
640-
f.step = (*decompressor).huffmanBlock
640+
f.step = (*decompressor).huffmanBlockGeneric
641641
f.stepState = stateInit
642642
return
643643
}
@@ -765,7 +765,7 @@ copyHistory:
765765

766766
if f.dict.availWrite() == 0 || f.copyLen > 0 {
767767
f.toRead = f.dict.readFlush()
768-
f.step = (*decompressor).huffmanBlock // We need to continue this work
768+
f.step = (*decompressor).huffmanBlockGeneric // We need to continue this work
769769
f.stepState = stateDict
770770
return
771771
}

0 commit comments

Comments
 (0)