From 07eae6f6111bbe156919463662280e5e2eac6113 Mon Sep 17 00:00:00 2001
From: benibus
Date: Thu, 14 Sep 2023 12:27:41 -0400
Subject: [PATCH] Support read/write to/from `arrow.FLOAT16`

---
 go/parquet/file/column_writer_types.gen.go    | 14 +++++++--
 .../file/column_writer_types.gen.go.tmpl      | 30 +++++++++++++++----
 go/parquet/pqarrow/column_readers.go          | 16 ++++++++++
 go/parquet/pqarrow/encode_arrow.go            | 25 ++++++++++++++++
 go/parquet/pqarrow/encode_arrow_test.go       |  9 ++++++
 5 files changed, 86 insertions(+), 8 deletions(-)

diff --git a/go/parquet/file/column_writer_types.gen.go b/go/parquet/file/column_writer_types.gen.go
index 5594f63249fb8..d0d042bcfbeb7 100644
--- a/go/parquet/file/column_writer_types.gen.go
+++ b/go/parquet/file/column_writer_types.gen.go
@@ -1629,7 +1629,12 @@ func (w *FixedLenByteArrayColumnChunkWriter) WriteDictIndices(indices arrow.Arra
 func (w *FixedLenByteArrayColumnChunkWriter) writeValues(values []parquet.FixedLenByteArray, numNulls int64) {
 	w.currentEncoder.(encoding.FixedLenByteArrayEncoder).Put(values)
 	if w.pageStatistics != nil {
-		w.pageStatistics.(*metadata.FixedLenByteArrayStatistics).Update(values, numNulls)
+		s, ok := w.pageStatistics.(*metadata.FixedLenByteArrayStatistics)
+		if ok {
+			s.Update(values, numNulls)
+		} else {
+			w.pageStatistics.(*metadata.Float16Statistics).Update(values, numNulls)
+		}
 	}
 }
 
@@ -1641,7 +1646,12 @@ func (w *FixedLenByteArrayColumnChunkWriter) writeValuesSpaced(spacedValues []pa
 	}
 	if w.pageStatistics != nil {
 		nulls := numValues - numRead
-		w.pageStatistics.(*metadata.FixedLenByteArrayStatistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls)
+		s, ok := w.pageStatistics.(*metadata.FixedLenByteArrayStatistics)
+		if ok {
+			s.UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls)
+		} else {
+			w.pageStatistics.(*metadata.Float16Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls)
+		}
 	}
 }
 
diff --git a/go/parquet/file/column_writer_types.gen.go.tmpl b/go/parquet/file/column_writer_types.gen.go.tmpl
index c00e1dabb5fe6..dfc107ae5eef2 100644
--- a/go/parquet/file/column_writer_types.gen.go.tmpl
+++ b/go/parquet/file/column_writer_types.gen.go.tmpl
@@ -18,7 +18,7 @@ package file
 
 import (
 	"fmt"
-	
+
 	"github.com/apache/arrow/go/v14/parquet"
 	"github.com/apache/arrow/go/v14/parquet/metadata"
 	"github.com/apache/arrow/go/v14/parquet/internal/encoding"
@@ -83,7 +83,7 @@ func (w *{{.Name}}ColumnChunkWriter) WriteBatch(values []{{.name}}, defLevels, r
 	// writes a large number of values, the DataPage size can be much above the limit.
 	// The purpose of this chunking is to bound this. Even if a user writes large number
 	// of values, the chunking will ensure the AddDataPage() is called at a reasonable
-	// pagesize limit 
+	// pagesize limit
 	var n int64
 	switch {
 	case defLevels != nil:
@@ -107,7 +107,7 @@ func (w *{{.Name}}ColumnChunkWriter) WriteBatch(values []{{.name}}, defLevels, r
 		valueOffset += toWrite
 		w.checkDictionarySizeLimit()
 	})
-	return 
+	return
 }
 
 // WriteBatchSpaced writes a batch of repetition levels, definition levels, and values to the
@@ -132,7 +132,7 @@ func (w *{{.Name}}ColumnChunkWriter) WriteBatchSpaced(values []{{.name}}, defLev
 		length = len(values)
 	}
 	doBatches(int64(length), w.props.WriteBatchSize(), func(offset, batch int64) {
-		var vals []{{.name}} 
+		var vals []{{.name}}
 		info := w.maybeCalculateValidityBits(levelSliceOrNil(defLevels, offset, batch), batch)
 
 		w.writeLevelsSpaced(batch, levelSliceOrNil(defLevels, offset, batch), levelSliceOrNil(repLevels, offset, batch))
@@ -165,7 +165,7 @@ func (w *{{.Name}}ColumnChunkWriter) WriteDictIndices(indices arrow.Array, defLe
 			}
 		}
 	}()
-	
+
 	valueOffset := int64(0)
 	length := len(defLevels)
 	if defLevels == nil {
@@ -193,14 +193,23 @@ func (w *{{.Name}}ColumnChunkWriter) WriteDictIndices(indices arrow.Array, defLe
 
 		valueOffset += info.numSpaced()
 	})
-	
+
 	return
 }
 
 func (w *{{.Name}}ColumnChunkWriter) writeValues(values []{{.name}}, numNulls int64) {
 	w.currentEncoder.(encoding.{{.Name}}Encoder).Put(values)
 	if w.pageStatistics != nil {
+{{- if ne .Name "FixedLenByteArray"}}
 		w.pageStatistics.(*metadata.{{.Name}}Statistics).Update(values, numNulls)
+{{- else}}
+		s, ok := w.pageStatistics.(*metadata.{{.Name}}Statistics)
+		if ok {
+			s.Update(values, numNulls)
+		} else {
+			w.pageStatistics.(*metadata.Float16Statistics).Update(values, numNulls)
+		}
+{{- end}}
 	}
 }
 
@@ -212,7 +221,16 @@ func (w *{{.Name}}ColumnChunkWriter) writeValuesSpaced(spacedValues []{{.name}},
 	}
 	if w.pageStatistics != nil {
 		nulls := numValues - numRead
+{{- if ne .Name "FixedLenByteArray"}}
 		w.pageStatistics.(*metadata.{{.Name}}Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls)
+{{- else}}
+		s, ok := w.pageStatistics.(*metadata.{{.Name}}Statistics)
+		if ok {
+			s.UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls)
+		} else {
+			w.pageStatistics.(*metadata.Float16Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls)
+		}
+{{- end}}
 	}
 }
 
diff --git a/go/parquet/pqarrow/column_readers.go b/go/parquet/pqarrow/column_readers.go
index 759a3d8675927..ebd7cc5388a36 100644
--- a/go/parquet/pqarrow/column_readers.go
+++ b/go/parquet/pqarrow/column_readers.go
@@ -517,6 +517,14 @@ func transferColumnData(rdr file.RecordReader, valueType arrow.DataType, descr *
 		default:
 			return nil, errors.New("time unit not supported")
 		}
+	case arrow.FLOAT16:
+		if descr.PhysicalType() != parquet.Types.FixedLenByteArray {
+			return nil, errors.New("physical type for float16 must be fixed len byte array")
+		}
+		if size := arrow.Float16SizeBytes; descr.TypeLength() != size {
+			return nil, fmt.Errorf("fixed len byte array length for float16 must be %d", size)
+		}
+		return transferBinary(rdr, valueType), nil
 	default:
 		return nil, fmt.Errorf("no support for reading columns of type: %s", valueType.Name())
 	}
@@ -563,6 +571,14 @@ func transferBinary(rdr file.RecordReader, dt arrow.DataType) *arrow.Chunked {
 			chunks[idx] = array.MakeFromData(chunk.Data())
 			chunk.Release()
 		}
+	case *arrow.Float16Type:
+		for idx, chunk := range chunks {
+			data := chunk.Data()
+			f16Data := array.NewData(dt, data.Len(), data.Buffers(), nil, data.NullN(), data.Offset())
+			chunks[idx] = array.NewFloat16Data(f16Data)
+			f16Data.Release()
+			chunk.Release()
+		}
 	}
 	return arrow.NewChunked(dt, chunks)
 }
diff --git a/go/parquet/pqarrow/encode_arrow.go b/go/parquet/pqarrow/encode_arrow.go
index c3a0a50c43f45..2e25adce3c25b 100644
--- a/go/parquet/pqarrow/encode_arrow.go
+++ b/go/parquet/pqarrow/encode_arrow.go
@@ -582,6 +582,31 @@ func writeDenseArrow(ctx *arrowWriteContext, cw file.ColumnChunkWriter, leafArr
 			}
 			wr.WriteBatchSpaced(data, defLevels, repLevels, arr.NullBitmapBytes(), int64(arr.Data().Offset()))
 		}
+	case *arrow.Float16Type:
+		typeLen := wr.Descr().TypeLength()
+		if typeLen != arrow.Float16SizeBytes {
+			return fmt.Errorf("%w: invalid FixedLenByteArray length to write from float16 column: %d", arrow.ErrInvalid, typeLen)
+		}
+
+		arr := leafArr.(*array.Float16)
+		rawValues := arrow.Float16Traits.CastToBytes(arr.Values())
+		data := make([]parquet.FixedLenByteArray, arr.Len())
+
+		if arr.NullN() == 0 {
+			for idx := range data {
+				offset := idx * typeLen
+				data[idx] = rawValues[offset : offset+typeLen]
+			}
+			_, err = wr.WriteBatch(data, defLevels, repLevels)
+		} else {
+			for idx := range data {
+				if arr.IsValid(idx) {
+					offset := idx * typeLen
+					data[idx] = rawValues[offset : offset+typeLen]
+				}
+			}
+			wr.WriteBatchSpaced(data, defLevels, repLevels, arr.NullBitmapBytes(), int64(arr.Data().Offset()))
+		}
 	default:
 		return fmt.Errorf("%w: invalid column type to write to FixedLenByteArray: %s", arrow.ErrInvalid, leafArr.DataType().Name())
 	}
diff --git a/go/parquet/pqarrow/encode_arrow_test.go b/go/parquet/pqarrow/encode_arrow_test.go
index 654d3d813cf85..c96a4a56e70ef 100644
--- a/go/parquet/pqarrow/encode_arrow_test.go
+++ b/go/parquet/pqarrow/encode_arrow_test.go
@@ -450,6 +450,8 @@ func getLogicalType(typ arrow.DataType) schema.LogicalType {
 		return schema.DateLogicalType{}
 	case arrow.DATE64:
 		return schema.DateLogicalType{}
+	case arrow.FLOAT16:
+		return schema.Float16LogicalType{}
 	case arrow.TIMESTAMP:
 		ts := typ.(*arrow.TimestampType)
 		adjustedUTC := len(ts.TimeZone) == 0
@@ -496,6 +498,8 @@ func getPhysicalType(typ arrow.DataType) parquet.Type {
 		return parquet.Types.Float
 	case arrow.FLOAT64:
 		return parquet.Types.Double
+	case arrow.FLOAT16:
+		return parquet.Types.FixedLenByteArray
 	case arrow.BINARY, arrow.LARGE_BINARY, arrow.STRING, arrow.LARGE_STRING:
 		return parquet.Types.ByteArray
 	case arrow.FIXED_SIZE_BINARY, arrow.DECIMAL:
@@ -555,6 +559,8 @@ func (ps *ParquetIOTestSuite) makeSimpleSchema(typ arrow.DataType, rep parquet.R
 		byteWidth = int32(typ.ByteWidth)
 	case arrow.DecimalType:
 		byteWidth = pqarrow.DecimalSize(typ.GetPrecision())
+	case *arrow.Float16Type:
+		byteWidth = int32(typ.Bytes())
 	case *arrow.DictionaryType:
 		valuesType := typ.ValueType
 		switch dt := valuesType.(type) {
@@ -562,6 +568,8 @@ func (ps *ParquetIOTestSuite) makeSimpleSchema(typ arrow.DataType, rep parquet.R
 			byteWidth = int32(dt.ByteWidth)
 		case arrow.DecimalType:
 			byteWidth = pqarrow.DecimalSize(dt.GetPrecision())
+		case *arrow.Float16Type:
+			byteWidth = int32(dt.Bytes())
 		}
 	}
 
@@ -1068,6 +1076,7 @@ var fullTypeList = []arrow.DataType{
 	arrow.FixedWidthTypes.Date32,
 	arrow.PrimitiveTypes.Float32,
 	arrow.PrimitiveTypes.Float64,
+	arrow.FixedWidthTypes.Float16,
 	arrow.BinaryTypes.String,
 	arrow.BinaryTypes.Binary,
 	&arrow.FixedSizeBinaryType{ByteWidth: 10},
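
Reviewer note (illustrative, not part of the patch): the sketch below shows the round trip this change enables. It assumes the full Float16 mapping is in place on the v14 Go module, including the pqarrow schema conversion that is not part of this diff; the field name "f16" is invented for the example. All calls used (pqarrow.WriteTable, pqarrow.ReadTable, array.NewFloat16Builder, float16.New) are existing public APIs. The null entry is there to exercise the spaced write path and the Float16Statistics fallback added above.

```go
package main

import (
	"bytes"
	"context"
	"fmt"

	"github.com/apache/arrow/go/v14/arrow"
	"github.com/apache/arrow/go/v14/arrow/array"
	"github.com/apache/arrow/go/v14/arrow/float16"
	"github.com/apache/arrow/go/v14/arrow/memory"
	"github.com/apache/arrow/go/v14/parquet"
	"github.com/apache/arrow/go/v14/parquet/pqarrow"
)

func main() {
	mem := memory.DefaultAllocator
	sc := arrow.NewSchema([]arrow.Field{
		{Name: "f16", Type: arrow.FixedWidthTypes.Float16, Nullable: true},
	}, nil)

	// Build a half-float column with a null to hit writeValuesSpaced.
	bldr := array.NewFloat16Builder(mem)
	defer bldr.Release()
	bldr.Append(float16.New(1.5))
	bldr.AppendNull()
	bldr.Append(float16.New(-2.25))
	arr := bldr.NewArray()
	defer arr.Release()

	rec := array.NewRecord(sc, []arrow.Array{arr}, int64(arr.Len()))
	defer rec.Release()
	tbl := array.NewTableFromRecords(sc, []arrow.Record{rec})
	defer tbl.Release()

	// Write: the column is stored as FixedLenByteArray(2) carrying the
	// Float16 logical type.
	var buf bytes.Buffer
	if err := pqarrow.WriteTable(tbl, &buf, 1024,
		parquet.NewWriterProperties(), pqarrow.DefaultWriterProps()); err != nil {
		panic(err)
	}

	// Read back through the pqarrow reader, which now routes FLOAT16
	// columns through transferBinary.
	out, err := pqarrow.ReadTable(context.Background(),
		bytes.NewReader(buf.Bytes()), parquet.NewReaderProperties(mem),
		pqarrow.ArrowReadProperties{}, mem)
	if err != nil {
		panic(err)
	}
	defer out.Release()

	fmt.Println(out.Column(0).Data().Chunks()[0]) // [1.5 (null) -2.25]
}
```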