diff --git a/s2/cmd/s2d/main.go b/s2/cmd/s2d/main.go index ef594356fc..cebfab67f3 100644 --- a/s2/cmd/s2d/main.go +++ b/s2/cmd/s2d/main.go @@ -10,6 +10,7 @@ import ( "io/ioutil" "net/http" "os" + "runtime" "runtime/debug" "strconv" "strings" @@ -34,6 +35,7 @@ var ( help = flag.Bool("help", false, "Display help") out = flag.String("o", "", "Write output to another file. Single input file only") block = flag.Bool("block", false, "Decompress as a single block. Will load content into memory.") + cpu = flag.Int("cpu", runtime.NumCPU(), "Decompress streams using this amount of threads") version = "(dev)" date = "(unknown)" @@ -160,7 +162,11 @@ Options:`) output = int64(len(dec)) } else { r.Reset(bytes.NewBuffer(b)) - output, err = io.Copy(ioutil.Discard, r) + if *cpu > 1 { + output, err = r.DecodeConcurrent(ioutil.Discard, *cpu) + } else { + output, err = io.Copy(ioutil.Discard, r) + } exitErr(err) } if !*quiet { @@ -286,7 +292,13 @@ Options:`) } decoded = r } - output, err := io.Copy(out, decoded) + var err error + var output int64 + if dec, ok := decoded.(*s2.Reader); ok && tailBytes == 0 && offset == 0 { + output, err = dec.DecodeConcurrent(out, *cpu) + } else { + output, err = io.Copy(out, decoded) + } exitErr(err) if !*quiet { elapsed := time.Since(start) diff --git a/s2/decode.go b/s2/decode.go index 042a329493..0fa8f6deeb 100644 --- a/s2/decode.go +++ b/s2/decode.go @@ -11,6 +11,8 @@ import ( "fmt" "io" "io/ioutil" + "runtime" + "sync" ) var ( @@ -196,13 +198,13 @@ type Reader struct { // ensureBufferSize will ensure that the buffer can take at least n bytes. // If false is returned the buffer exceeds maximum allowed size. func (r *Reader) ensureBufferSize(n int) bool { - if len(r.buf) >= n { - return true - } if n > r.maxBufSize { r.err = ErrCorrupt return false } + if cap(r.buf) >= n { + return true + } // Realloc buffer. r.buf = make([]byte, n) return true @@ -220,6 +222,7 @@ func (r *Reader) Reset(reader io.Reader) { r.err = nil r.i = 0 r.j = 0 + r.blockStart = 0 r.readHeader = r.ignoreStreamID } @@ -435,6 +438,253 @@ func (r *Reader) Read(p []byte) (int, error) { } } +// DecodeConcurrent will decode the full stream to w. +// This function should not be combined with reading, seeking or other operations. +// Up to 'concurrent' goroutines will be used. +// If <= 0, runtime.NumCPU will be used. +// On success the number of bytes decompressed nil and is returned. +// This is mainly intended for bigger streams. +func (r *Reader) DecodeConcurrent(w io.Writer, concurrent int) (written int64, err error) { + if r.i > 0 || r.j > 0 || r.blockStart > 0 { + return 0, errors.New("DecodeConcurrent called after ") + } + if concurrent <= 0 { + concurrent = runtime.NumCPU() + } + + // Write to output + var errMu sync.Mutex + setErr := func(e error) (ok bool) { + errMu.Lock() + defer errMu.Unlock() + if e == nil { + return err == nil + } + if err == nil { + err = e + } + return false + } + hasErr := func() (ok bool) { + errMu.Lock() + v := err != nil + errMu.Unlock() + return v + } + + toRead := make(chan []byte, concurrent) + writtenBlocks := make(chan []byte, concurrent) + queue := make(chan chan []byte, concurrent) + reUse := make(chan chan []byte, concurrent) + for i := 0; i < concurrent; i++ { + toRead <- make([]byte, 0, r.maxBufSize) + writtenBlocks <- make([]byte, 0, r.maxBufSize) + reUse <- make(chan []byte, 1) + } + // Writer + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for toWrite := range queue { + entry := <-toWrite + reUse <- toWrite + if hasErr() { + writtenBlocks <- entry + continue + } + n, err := w.Write(entry) + want := len(entry) + writtenBlocks <- entry + if err != nil { + setErr(err) + continue + } + if n != want { + setErr(io.ErrShortWrite) + continue + } + written += int64(n) + } + }() + + // Reader + defer func() { + close(queue) + wg.Wait() + if r.err != nil { + err = r.err + } + r.err = err + }() + + for !hasErr() { + if !r.readFull(r.buf[:4], true) { + if r.err == io.EOF { + r.err = nil + } + return written, r.err + } + chunkType := r.buf[0] + if !r.readHeader { + if chunkType != chunkTypeStreamIdentifier { + r.err = ErrCorrupt + return 0, r.err + } + r.readHeader = true + } + chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16 + + // The chunk types are specified at + // https://github.com/google/snappy/blob/master/framing_format.txt + switch chunkType { + case chunkTypeCompressedData: + r.blockStart += int64(r.j) + // Section 4.2. Compressed data (chunk type 0x00). + if chunkLen < checksumSize { + r.err = ErrCorrupt + return 0, r.err + } + if chunkLen > r.maxBufSize { + r.err = ErrCorrupt + return 0, r.err + } + orgBuf := <-toRead + buf := orgBuf[:chunkLen] + + if !r.readFull(buf, false) { + return 0, r.err + } + + checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24 + buf = buf[checksumSize:] + + n, err := DecodedLen(buf) + if err != nil { + r.err = err + return 0, r.err + } + if r.snappyFrame && n > maxSnappyBlockSize { + r.err = ErrCorrupt + return 0, r.err + } + + if n > r.maxBlock { + r.err = ErrCorrupt + return 0, r.err + } + wg.Add(1) + + decoded := <-writtenBlocks + entry := <-reUse + queue <- entry + go func() { + defer wg.Done() + decoded = decoded[:n] + _, err := Decode(decoded, buf) + toRead <- orgBuf + if err != nil { + writtenBlocks <- decoded + setErr(err) + return + } + if crc(decoded) != checksum { + writtenBlocks <- decoded + setErr(ErrCRC) + return + } + entry <- decoded + }() + continue + + case chunkTypeUncompressedData: + + // Section 4.3. Uncompressed data (chunk type 0x01). + if chunkLen < checksumSize { + r.err = ErrCorrupt + return 0, r.err + } + if chunkLen > r.maxBufSize { + r.err = ErrCorrupt + return 0, r.err + } + // Grab write buffer + orgBuf := <-writtenBlocks + buf := orgBuf[:checksumSize] + if !r.readFull(buf, false) { + return 0, r.err + } + checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24 + // Read content. + n := chunkLen - checksumSize + + if r.snappyFrame && n > maxSnappyBlockSize { + r.err = ErrCorrupt + return 0, r.err + } + if n > r.maxBlock { + r.err = ErrCorrupt + return 0, r.err + } + // Read uncompressed + buf = orgBuf[:n] + if !r.readFull(buf, false) { + return 0, r.err + } + + if crc(buf) != checksum { + r.err = ErrCRC + return 0, r.err + } + entry := <-reUse + queue <- entry + entry <- buf + continue + + case chunkTypeStreamIdentifier: + // Section 4.1. Stream identifier (chunk type 0xff). + if chunkLen != len(magicBody) { + r.err = ErrCorrupt + return 0, r.err + } + if !r.readFull(r.buf[:len(magicBody)], false) { + return 0, r.err + } + if string(r.buf[:len(magicBody)]) != magicBody { + if string(r.buf[:len(magicBody)]) != magicBodySnappy { + r.err = ErrCorrupt + return 0, r.err + } else { + r.snappyFrame = true + } + } else { + r.snappyFrame = false + } + continue + } + + if chunkType <= 0x7f { + // Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f). + // fmt.Printf("ERR chunktype: 0x%x\n", chunkType) + r.err = ErrUnsupported + return 0, r.err + } + // Section 4.4 Padding (chunk type 0xfe). + // Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd). + if chunkLen > maxChunkSize { + // fmt.Printf("ERR chunkLen: 0x%x\n", chunkLen) + r.err = ErrUnsupported + return 0, r.err + } + + // fmt.Printf("skippable: ID: 0x%x, len: 0x%x\n", chunkType, chunkLen) + if !r.skippable(r.buf, chunkLen, false, chunkType) { + return 0, r.err + } + } + return 0, r.err +} + // Skip will skip n bytes forward in the decompressed output. // For larger skips this consumes less CPU and is faster than reading output and discarding it. // CRC is not checked on skipped blocks.