Skip to content

Commit

Permalink
GH-43837: [Go][IPC] Consolidate StreamWriter and FileWriter, ensuring…
Browse files Browse the repository at this point in the history
… that EOS indicator is written in file (#43890)

### Rationale for this change

Fixes: #43837

Much of the logic shared by the IPC stream writer and the file writer was implemented separately in each. This PR changes the file writer so that it uses a stream writer internally, ensuring that a valid stream is embedded within the file.

**TODO**
- [x] Remove @bkietz's commits

### What changes are included in this PR?

- Refactor `fileWriter` to embed `streamWriter` and delegate the relevant methods to it
- Add test

### Are these changes tested?

Yes

### Are there any user-facing changes?

Go-generated IPC files will contain the EOS indicator

* GitHub Issue: #43837

Authored-by: Joel Lubinitsky <joellubi@gmail.com>
Signed-off-by: Joel Lubinitsky <joellubi@gmail.com>
  • Loading branch information
joellubi authored Aug 30, 2024
1 parent 07420b0 commit 63b34c9
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 69 deletions.
40 changes: 40 additions & 0 deletions go/arrow/ipc/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@
package ipc_test

import (
"bytes"
"fmt"
"os"
"testing"

"github.com/apache/arrow/go/v18/arrow/array"
"github.com/apache/arrow/go/v18/arrow/internal/arrdata"
"github.com/apache/arrow/go/v18/arrow/internal/flatbuf"
"github.com/apache/arrow/go/v18/arrow/ipc"
"github.com/apache/arrow/go/v18/arrow/memory"
"github.com/stretchr/testify/require"
)

func TestFile(t *testing.T) {
Expand Down Expand Up @@ -75,3 +79,39 @@ func TestFileCompressed(t *testing.T) {
}
}
}

// TestFileEmbedsStream checks that the bytes produced by the IPC file writer
// contain a readable IPC stream between the leading magic bytes and the
// footer: a plain stream reader started just past the file header must yield
// exactly the records that were written and then stop without error.
func TestFileEmbedsStream(t *testing.T) {
	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
	defer mem.AssertSize(t, 0)

	recs := arrdata.Records["primitives"]
	schema := recs[0].Schema()

	var buf bytes.Buffer
	w, err := ipc.NewFileWriter(&buf, ipc.WithSchema(schema), ipc.WithAllocator(mem))
	require.NoError(t, err)
	// Safety net in case an assertion below aborts the test early; Close
	// returns nil without writing again once the footer has been written.
	defer w.Close()

	for _, rec := range recs {
		require.NoError(t, w.Write(rec))
	}

	// Close explicitly so the EOS indicator and footer are flushed into buf
	// before it is read back.
	require.NoError(t, w.Close())

	// we should be able to read a valid ipc stream within the ipc file

	// create an ipc stream reader, skipping the file magic+padding bytes
	rdr, err := ipc.NewReader(bytes.NewReader(buf.Bytes()[8:]), ipc.WithSchema(schema), ipc.WithAllocator(mem))
	require.NoError(t, err)
	defer rdr.Release()

	// the stream reader should know to stop before the footer if the EOS indicator is properly written
	var i int
	for rdr.Next() {
		// Guard the index so an over-long stream fails the test cleanly
		// instead of panicking on recs[i].
		require.Lessf(t, i, len(recs), "stream yielded more than the %d records written", len(recs))
		rec := rdr.Record()
		require.Truef(t, array.RecordEqual(rec, recs[i]), "records[%d] differ", i)
		i++
	}

	require.NoError(t, rdr.Err())
	// Without this check a reader that stops early (or yields nothing) with a
	// nil error would let the test pass vacuously.
	require.Equalf(t, len(recs), i, "expected %d records from embedded stream, got %d", len(recs), i)
}
82 changes: 19 additions & 63 deletions go/arrow/ipc/file_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,17 @@ type PayloadWriter interface {
Close() error
}

type pwriter struct {
w io.WriteSeeker
pos int64
type fileWriter struct {
streamWriter

schema *arrow.Schema
dicts []fileBlock
recs []fileBlock
}

func (w *pwriter) Start() error {
func (w *fileWriter) Start() error {
var err error

err = w.updatePos()
if err != nil {
return fmt.Errorf("arrow/ipc: could not update position while in start: %w", err)
}

// only necessary to align to 8-byte boundary at the start of the file
_, err = w.Write(Magic)
if err != nil {
Expand All @@ -65,10 +59,10 @@ func (w *pwriter) Start() error {
return fmt.Errorf("arrow/ipc: could not align start block: %w", err)
}

return err
return w.streamWriter.Start()
}

func (w *pwriter) WritePayload(p Payload) error {
func (w *fileWriter) WritePayload(p Payload) error {
blk := fileBlock{Offset: w.pos, Meta: 0, Body: p.size}
n, err := writeIPCPayload(w, p)
if err != nil {
Expand All @@ -77,11 +71,6 @@ func (w *pwriter) WritePayload(p Payload) error {

blk.Meta = int32(n)

err = w.updatePos()
if err != nil {
return fmt.Errorf("arrow/ipc: could not update position while in write-payload: %w", err)
}

switch flatbuf.MessageHeader(p.msg) {
case flatbuf.MessageHeaderDictionaryBatch:
w.dicts = append(w.dicts, blk)
Expand All @@ -92,27 +81,18 @@ func (w *pwriter) WritePayload(p Payload) error {
return nil
}

func (w *pwriter) Close() error {
func (w *fileWriter) Close() error {
var err error

// write file footer
err = w.updatePos()
if err != nil {
return fmt.Errorf("arrow/ipc: could not update position while in close: %w", err)
if err = w.streamWriter.Close(); err != nil {
return err
}

pos := w.pos
err = writeFileFooter(w.schema, w.dicts, w.recs, w)
if err != nil {
if err = writeFileFooter(w.schema, w.dicts, w.recs, w); err != nil {
return fmt.Errorf("arrow/ipc: could not write file footer: %w", err)
}

// write file footer length
err = w.updatePos() // not strictly needed as we passed w to writeFileFooter...
if err != nil {
return fmt.Errorf("arrow/ipc: could not compute file footer length: %w", err)
}

size := w.pos - pos
if size <= 0 {
return fmt.Errorf("arrow/ipc: invalid file footer size (size=%d)", size)
Expand All @@ -133,13 +113,7 @@ func (w *pwriter) Close() error {
return nil
}

func (w *pwriter) updatePos() error {
var err error
w.pos, err = w.w.Seek(0, io.SeekCurrent)
return err
}

func (w *pwriter) align(align int32) error {
func (w *fileWriter) align(align int32) error {
remainder := paddedLength(w.pos, align) - w.pos
if remainder == 0 {
return nil
Expand All @@ -149,12 +123,6 @@ func (w *pwriter) align(align int32) error {
return err
}

func (w *pwriter) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.pos += int64(n)
return n, err
}

func writeIPCPayload(w io.Writer, p Payload) (int, error) {
n, err := writeMessage(p.meta, kArrowIPCAlignment, w)
if err != nil {
Expand Down Expand Up @@ -259,18 +227,12 @@ func (ps payloads) Release() {

// FileWriter is an Arrow file writer.
type FileWriter struct {
w io.WriteSeeker
w io.Writer

mem memory.Allocator

header struct {
started bool
offset int64
}

footer struct {
written bool
}
headerStarted bool
footerWritten bool

pw PayloadWriter

Expand All @@ -289,15 +251,15 @@ type FileWriter struct {
}

// NewFileWriter opens an Arrow file using the provided writer w.
func NewFileWriter(w io.WriteSeeker, opts ...Option) (*FileWriter, error) {
func NewFileWriter(w io.Writer, opts ...Option) (*FileWriter, error) {
var (
cfg = newConfig(opts...)
err error
)

f := FileWriter{
w: w,
pw: &pwriter{w: w, schema: cfg.schema, pos: -1},
pw: &fileWriter{streamWriter: streamWriter{w: w}, schema: cfg.schema},
mem: cfg.alloc,
schema: cfg.schema,
codec: cfg.codec,
Expand All @@ -306,12 +268,6 @@ func NewFileWriter(w io.WriteSeeker, opts ...Option) (*FileWriter, error) {
compressors: make([]compressor, cfg.compressNP),
}

pos, err := f.w.Seek(0, io.SeekCurrent)
if err != nil {
return nil, fmt.Errorf("arrow/ipc: could not seek current position: %w", err)
}
f.header.offset = pos

return &f, err
}

Expand All @@ -321,15 +277,15 @@ func (f *FileWriter) Close() error {
return fmt.Errorf("arrow/ipc: could not write empty file: %w", err)
}

if f.footer.written {
if f.footerWritten {
return nil
}

err = f.pw.Close()
if err != nil {
return fmt.Errorf("arrow/ipc: could not close payload writer: %w", err)
}
f.footer.written = true
f.footerWritten = true

return nil
}
Expand Down Expand Up @@ -367,14 +323,14 @@ func (f *FileWriter) Write(rec arrow.Record) error {
}

func (f *FileWriter) checkStarted() error {
if !f.header.started {
if !f.headerStarted {
return f.start()
}
return nil
}

func (f *FileWriter) start() error {
f.header.started = true
f.headerStarted = true
err := f.pw.Start()
if err != nil {
return err
Expand Down
12 changes: 6 additions & 6 deletions go/arrow/ipc/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,26 @@ import (
"github.com/apache/arrow/go/v18/internal/utils"
)

type swriter struct {
type streamWriter struct {
w io.Writer
pos int64
}

func (w *swriter) Start() error { return nil }
func (w *swriter) Close() error {
func (w *streamWriter) Start() error { return nil }
func (w *streamWriter) Close() error {
_, err := w.Write(kEOS[:])
return err
}

func (w *swriter) WritePayload(p Payload) error {
func (w *streamWriter) WritePayload(p Payload) error {
_, err := writeIPCPayload(w, p)
if err != nil {
return err
}
return nil
}

func (w *swriter) Write(p []byte) (int, error) {
func (w *streamWriter) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.pos += int64(n)
return n, err
Expand Down Expand Up @@ -118,7 +118,7 @@ func NewWriter(w io.Writer, opts ...Option) *Writer {
return &Writer{
w: w,
mem: cfg.alloc,
pw: &swriter{w: w},
pw: &streamWriter{w: w},
schema: cfg.schema,
codec: cfg.codec,
emitDictDeltas: cfg.emitDictDeltas,
Expand Down

0 comments on commit 63b34c9

Please sign in to comment.