diff --git a/s2/writer.go b/s2/writer.go
index abdcad905..fd15078f7 100644
--- a/s2/writer.go
+++ b/s2/writer.go
@@ -83,11 +83,14 @@ type Writer struct {
 	snappy       bool
 	flushOnWrite bool
 	appendIndex  bool
+	bufferCB     func([]byte)
 	level        uint8
 }
 
 type result struct {
 	b []byte
+	// ret is handed back to the caller via bufferCB once the buffer is safe to reuse.
+	ret []byte
 	// Uncompressed start offset
 	startOffset int64
 }
@@ -146,6 +149,10 @@ func (w *Writer) Reset(writer io.Writer) {
 		for write := range toWrite {
 			// Wait for the data to be available.
 			input := <-write
+			if input.ret != nil && w.bufferCB != nil {
+				w.bufferCB(input.ret)
+				input.ret = nil
+			}
 			in := input.b
 			if len(in) > 0 {
 				if w.err(nil) == nil {
@@ -341,7 +348,8 @@ func (w *Writer) AddSkippableBlock(id uint8, data []byte) (err error) {
 // but the input buffer cannot be written to by the caller
 // until Flush or Close has been called when concurrency != 1.
 //
-// If you cannot control that, use the regular Write function.
+// Use the WriterBufferDone option to receive a callback when the buffer is
+// done processing.
 //
 // Note that input is not buffered.
 // This means that each write will result in discrete blocks being created.
@@ -364,6 +372,9 @@ func (w *Writer) EncodeBuffer(buf []byte) (err error) {
 	}
 	if w.concurrency == 1 {
 		_, err := w.writeSync(buf)
+		if w.bufferCB != nil {
+			w.bufferCB(buf)
+		}
 		return err
 	}
 
@@ -378,7 +389,7 @@ func (w *Writer) EncodeBuffer(buf []byte) (err error) {
 			hWriter <- result{startOffset: w.uncompWritten, b: magicChunkBytes}
 		}
 	}
-
+	orgBuf := buf
 	for len(buf) > 0 {
 		// Cut input.
 		uncompressed := buf
@@ -397,6 +408,9 @@ func (w *Writer) EncodeBuffer(buf []byte) (err error) {
 			startOffset: w.uncompWritten,
 		}
 		w.uncompWritten += int64(len(uncompressed))
+		if len(buf) == 0 && w.bufferCB != nil {
+			res.ret = orgBuf
+		}
 
 		go func() {
 			race.ReadSlice(uncompressed)
@@ -941,6 +955,17 @@ func WriterUncompressed() WriterOption {
 	}
 }
 
+// WriterBufferDone will perform a callback when EncodeBuffer has finished
+// writing a buffer to the output and the buffer can safely be reused.
+// If the buffer was split into several blocks, the callback is invoked after the last block.
+// Callbacks will not be invoked concurrently.
+func WriterBufferDone(fn func(b []byte)) WriterOption {
+	return func(w *Writer) error {
+		w.bufferCB = fn
+		return nil
+	}
+}
+
 // WriterBlockSize allows to override the default block size.
 // Blocks will be this size or smaller.
 // Minimum size is 4KB and maximum size is 4MB.
diff --git a/s2/writer_test.go b/s2/writer_test.go
index 470abbb80..a8b7585a1 100644
--- a/s2/writer_test.go
+++ b/s2/writer_test.go
@@ -576,6 +576,49 @@ func TestBigEncodeBufferSync(t *testing.T) {
 	t.Log(n)
 }
 
+func TestWriterBufferDone(t *testing.T) {
+	const blockSize = 1 << 20
+	var buffers [][]byte
+	for _, size := range []int{10, 100, 10000, blockSize, blockSize * 8} {
+		buffers = append(buffers, make([]byte, size))
+	}
+
+	dst := io.Discard
+	wantNextBuf := 0
+	var cbErr error
+	enc := NewWriter(dst, WriterBlockSize(blockSize), WriterConcurrency(4), WriterBufferDone(func(b []byte) {
+		if !bytes.Equal(b, buffers[wantNextBuf]) && cbErr == nil {
+			cbErr = fmt.Errorf("wrong buffer returned, want %v got %v", buffers[wantNextBuf], b)
+		}
+		// Detect races.
+		for i := range b[:] {
+			b[i] = 255
+		}
+		wantNextBuf++
+	}))
+	for n, buf := range buffers {
+		// Change the buffer to a new value.
+		for i := range buf[:] {
+			buf[i] = byte(n)
+		}
+		// Send the buffer.
+		err := enc.EncodeBuffer(buf)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	err := enc.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if wantNextBuf != len(buffers) {
+		t.Fatalf("want %d buffers, got %d", len(buffers), wantNextBuf)
+	}
+	if cbErr != nil {
+		t.Fatal(cbErr)
+	}
+}
+
 func BenchmarkWriterRandom(b *testing.B) {
 	rng := rand.New(rand.NewSource(1))
 	// Make max window so we never get matches.
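For context, a minimal usage sketch (not part of the diff) showing how the new WriterBufferDone option can be combined with EncodeBuffer and a sync.Pool to recycle input buffers once the writer is done with them. The compressStream helper and bufPool are illustrative names, not part of the s2 API; only NewWriter, EncodeBuffer, Close, WriterBlockSize, WriterConcurrency, and the WriterBufferDone option added here come from the package.

```go
package main

import (
	"io"
	"sync"

	"github.com/klauspost/compress/s2"
)

// compressStream reads src in blockSize chunks and compresses them with
// EncodeBuffer, recycling each input buffer when the writer reports it is done.
func compressStream(dst io.Writer, src io.Reader) error {
	const blockSize = 1 << 20

	// Pool of reusable input buffers (illustrative).
	bufPool := sync.Pool{
		New: func() any { return make([]byte, blockSize) },
	}

	enc := s2.NewWriter(dst,
		s2.WriterBlockSize(blockSize),
		s2.WriterConcurrency(4),
		// Called after the last block of each buffer has been processed,
		// so the buffer can safely be reused.
		s2.WriterBufferDone(func(b []byte) {
			bufPool.Put(b[:cap(b)])
		}),
	)

	for {
		buf := bufPool.Get().([]byte)[:blockSize]
		n, err := io.ReadFull(src, buf)
		if n > 0 {
			// The buffer must not be modified until the callback fires.
			if werr := enc.EncodeBuffer(buf[:n]); werr != nil {
				enc.Close()
				return werr
			}
		}
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			break
		}
		if err != nil {
			enc.Close()
			return err
		}
	}
	return enc.Close()
}
```

Because the callback fires only after the last block of a given buffer has been processed, the pool never hands out a buffer the encoder may still be reading, which is the property the new TestWriterBufferDone test exercises.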