Support read/write to/from arrow.FLOAT16
benibus committed Sep 14, 2023
1 parent 5e8e392 commit 07eae6f
Showing 5 changed files with 86 additions and 8 deletions.
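
To make the feature concrete, here is a minimal round-trip sketch against the v14 pqarrow API (WriteTable/ReadTable); the column name "halves" and the sample values are illustrative, not taken from the commit:

package main

import (
	"bytes"
	"context"
	"fmt"

	"github.com/apache/arrow/go/v14/arrow"
	"github.com/apache/arrow/go/v14/arrow/array"
	"github.com/apache/arrow/go/v14/arrow/float16"
	"github.com/apache/arrow/go/v14/arrow/memory"
	"github.com/apache/arrow/go/v14/parquet"
	"github.com/apache/arrow/go/v14/parquet/pqarrow"
)

func main() {
	mem := memory.DefaultAllocator
	sc := arrow.NewSchema([]arrow.Field{
		{Name: "halves", Type: arrow.FixedWidthTypes.Float16, Nullable: true},
	}, nil)

	bldr := array.NewFloat16Builder(mem)
	defer bldr.Release()
	bldr.Append(float16.New(1.5))
	bldr.AppendNull()
	bldr.Append(float16.New(-2.5))
	arr := bldr.NewArray()
	defer arr.Release()

	rec := array.NewRecord(sc, []arrow.Array{arr}, int64(arr.Len()))
	defer rec.Release()
	tbl := array.NewTableFromRecords(sc, []arrow.Record{rec})
	defer tbl.Release()

	// Write: the FLOAT16 column is stored as FixedLenByteArray(2) annotated
	// with the Float16 logical type.
	var buf bytes.Buffer
	if err := pqarrow.WriteTable(tbl, &buf, 1024,
		parquet.NewWriterProperties(), pqarrow.DefaultWriterProps()); err != nil {
		panic(err)
	}

	// Read the file back into an arrow.Table.
	out, err := pqarrow.ReadTable(context.Background(), bytes.NewReader(buf.Bytes()),
		parquet.NewReaderProperties(mem), pqarrow.ArrowReadProperties{}, mem)
	if err != nil {
		panic(err)
	}
	defer out.Release()
	fmt.Println(out.Column(0).Data().Chunks()[0]) // [1.5 (null) -2.5]
}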
14 changes: 12 additions & 2 deletions go/parquet/file/column_writer_types.gen.go

Some generated files are not rendered by default.
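
Since the generated diff is hidden, the following is a hedged sketch of what the new FixedLenByteArray branch of writeValues in the template below should expand to (illustrative; the authoritative code is in column_writer_types.gen.go):

// A FixedLenByteArray column annotated as Float16 carries *metadata.Float16Statistics
// rather than *metadata.FixedLenByteArrayStatistics, so the generated writer
// dispatches on the concrete statistics type at runtime.
func (w *FixedLenByteArrayColumnChunkWriter) writeValues(values []parquet.FixedLenByteArray, numNulls int64) {
	w.currentEncoder.(encoding.FixedLenByteArrayEncoder).Put(values)
	if w.pageStatistics != nil {
		s, ok := w.pageStatistics.(*metadata.FixedLenByteArrayStatistics)
		if ok {
			s.Update(values, numNulls)
		} else {
			w.pageStatistics.(*metadata.Float16Statistics).Update(values, numNulls)
		}
	}
}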

30 changes: 24 additions & 6 deletions go/parquet/file/column_writer_types.gen.go.tmpl
@@ -18,7 +18,7 @@ package file

import (
	"fmt"

	"github.com/apache/arrow/go/v14/parquet"
	"github.com/apache/arrow/go/v14/parquet/metadata"
	"github.com/apache/arrow/go/v14/parquet/internal/encoding"
@@ -83,7 +83,7 @@ func (w *{{.Name}}ColumnChunkWriter) WriteBatch(values []{{.name}}, defLevels, r
	// writes a large number of values, the DataPage size can be much above the limit.
	// The purpose of this chunking is to bound this. Even if a user writes large number
	// of values, the chunking will ensure the AddDataPage() is called at a reasonable
	// pagesize limit
	var n int64
	switch {
	case defLevels != nil:
@@ -107,7 +107,7 @@ func (w *{{.Name}}ColumnChunkWriter) WriteBatch(values []{{.name}}, defLevels, r
		valueOffset += toWrite
		w.checkDictionarySizeLimit()
	})
	return
}

// WriteBatchSpaced writes a batch of repetition levels, definition levels, and values to the
@@ -132,7 +132,7 @@ func (w *{{.Name}}ColumnChunkWriter) WriteBatchSpaced(values []{{.name}}, defLev
		length = len(values)
	}
	doBatches(int64(length), w.props.WriteBatchSize(), func(offset, batch int64) {
		var vals []{{.name}}
		info := w.maybeCalculateValidityBits(levelSliceOrNil(defLevels, offset, batch), batch)

		w.writeLevelsSpaced(batch, levelSliceOrNil(defLevels, offset, batch), levelSliceOrNil(repLevels, offset, batch))
@@ -165,7 +165,7 @@ func (w *{{.Name}}ColumnChunkWriter) WriteDictIndices(indices arrow.Array, defLe
			}
		}
	}()

	valueOffset := int64(0)
	length := len(defLevels)
	if defLevels == nil {
@@ -193,14 +193,23 @@ func (w *{{.Name}}ColumnChunkWriter) WriteDictIndices(indices arrow.Array, defLe

		valueOffset += info.numSpaced()
	})

	return
}

func (w *{{.Name}}ColumnChunkWriter) writeValues(values []{{.name}}, numNulls int64) {
	w.currentEncoder.(encoding.{{.Name}}Encoder).Put(values)
	if w.pageStatistics != nil {
{{- if ne .Name "FixedLenByteArray"}}
		w.pageStatistics.(*metadata.{{.Name}}Statistics).Update(values, numNulls)
{{- else}}
		s, ok := w.pageStatistics.(*metadata.{{.Name}}Statistics)
		if ok {
			s.Update(values, numNulls)
		} else {
			w.pageStatistics.(*metadata.Float16Statistics).Update(values, numNulls)
		}
{{- end}}
	}
}

@@ -212,7 +221,16 @@ func (w *{{.Name}}ColumnChunkWriter) writeValuesSpaced(spacedValues []{{.name}},
	}
	if w.pageStatistics != nil {
		nulls := numValues - numRead
{{- if ne .Name "FixedLenByteArray"}}
		w.pageStatistics.(*metadata.{{.Name}}Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls)
{{- else}}
		s, ok := w.pageStatistics.(*metadata.{{.Name}}Statistics)
		if ok {
			s.UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls)
		} else {
			w.pageStatistics.(*metadata.Float16Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls)
		}
{{- end}}
	}
}

16 changes: 16 additions & 0 deletions go/parquet/pqarrow/column_readers.go
@@ -517,6 +517,14 @@ func transferColumnData(rdr file.RecordReader, valueType arrow.DataType, descr *
		default:
			return nil, errors.New("time unit not supported")
		}
	case arrow.FLOAT16:
		if descr.PhysicalType() != parquet.Types.FixedLenByteArray {
			return nil, errors.New("physical type for float16 must be fixed len byte array")
		}
		if len := arrow.Float16SizeBytes; descr.TypeLength() != len {
			return nil, fmt.Errorf("fixed len byte array length for float16 must be %d", len)
		}
		return transferBinary(rdr, valueType), nil
	default:
		return nil, fmt.Errorf("no support for reading columns of type: %s", valueType.Name())
	}
@@ -563,6 +571,14 @@ func transferBinary(rdr file.RecordReader, dt arrow.DataType) *arrow.Chunked {
			chunks[idx] = array.MakeFromData(chunk.Data())
			chunk.Release()
		}
	case *arrow.Float16Type:
		for idx, chunk := range chunks {
			data := chunk.Data()
			f16_data := array.NewData(dt, data.Len(), data.Buffers(), nil, data.NullN(), data.Offset())
			defer f16_data.Release()
			chunks[idx] = array.NewFloat16Data(f16_data)
			chunk.Release()
		}
	}
	return arrow.NewChunked(dt, chunks)
}
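
The Float16 case above re-types each FixedLenByteArray(2) chunk in place rather than copying, since a Float16 array uses the same two-bytes-per-value buffer layout. A small self-contained sketch of that layout equivalence, using Float16Traits on hand-written bytes (illustrative values, not from the commit):

package main

import (
	"fmt"

	"github.com/apache/arrow/go/v14/arrow"
)

func main() {
	// Two little-endian IEEE 754 binary16 values, exactly as they would sit
	// in a FixedLenByteArray(2) column's data buffer.
	raw := []byte{0x00, 0x3e, 0x00, 0xc1}
	vals := arrow.Float16Traits.CastFromBytes(raw)    // zero-copy []float16.Num view
	fmt.Println(vals[0].Float32(), vals[1].Float32()) // 1.5 -2.5
}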
25 changes: 25 additions & 0 deletions go/parquet/pqarrow/encode_arrow.go
@@ -582,6 +582,31 @@ func writeDenseArrow(ctx *arrowWriteContext, cw file.ColumnChunkWriter, leafArr
			}
			wr.WriteBatchSpaced(data, defLevels, repLevels, arr.NullBitmapBytes(), int64(arr.Data().Offset()))
		}
	case *arrow.Float16Type:
		typeLen := wr.Descr().TypeLength()
		if typeLen != arrow.Float16SizeBytes {
			return fmt.Errorf("%w: invalid FixedLenByteArray length to write from float16 column: %d", arrow.ErrInvalid, typeLen)
		}

		arr := leafArr.(*array.Float16)
		rawValues := arrow.Float16Traits.CastToBytes(arr.Values())
		data := make([]parquet.FixedLenByteArray, arr.Len())

		if arr.NullN() == 0 {
			for idx := range data {
				offset := idx * typeLen
				data[idx] = rawValues[offset : offset+typeLen]
			}
			_, err = wr.WriteBatch(data, defLevels, repLevels)
		} else {
			for idx := range data {
				if arr.IsValid(idx) {
					offset := idx * typeLen
					data[idx] = rawValues[offset : offset+typeLen]
				}
			}
			wr.WriteBatchSpaced(data, defLevels, repLevels, arr.NullBitmapBytes(), int64(arr.Data().Offset()))
		}
	default:
		return fmt.Errorf("%w: invalid column type to write to FixedLenByteArray: %s", arrow.ErrInvalid, leafArr.DataType().Name())
	}
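
The write path is the mirror image: Float16Traits.CastToBytes reinterprets the arrow values as a raw byte buffer, and each FixedLenByteArray element aliases a two-byte window of it, so no values are copied. A standalone sketch with illustrative inputs:

package main

import (
	"fmt"

	"github.com/apache/arrow/go/v14/arrow"
	"github.com/apache/arrow/go/v14/arrow/float16"
	"github.com/apache/arrow/go/v14/parquet"
)

func main() {
	vals := []float16.Num{float16.New(1.5), float16.New(-2.5)}
	raw := arrow.Float16Traits.CastToBytes(vals) // 2 bytes per value, little-endian
	flba := make([]parquet.FixedLenByteArray, len(vals))
	for i := range flba {
		// Each element is a sub-slice aliasing raw, not a copy.
		flba[i] = raw[i*arrow.Float16SizeBytes : (i+1)*arrow.Float16SizeBytes]
	}
	fmt.Println(flba) // [[0 62] [0 193]]
}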
9 changes: 9 additions & 0 deletions go/parquet/pqarrow/encode_arrow_test.go
@@ -450,6 +450,8 @@ func getLogicalType(typ arrow.DataType) schema.LogicalType {
		return schema.DateLogicalType{}
	case arrow.DATE64:
		return schema.DateLogicalType{}
	case arrow.FLOAT16:
		return schema.Float16LogicalType{}
	case arrow.TIMESTAMP:
		ts := typ.(*arrow.TimestampType)
		adjustedUTC := len(ts.TimeZone) == 0
@@ -496,6 +498,8 @@ func getPhysicalType(typ arrow.DataType) parquet.Type {
		return parquet.Types.Float
	case arrow.FLOAT64:
		return parquet.Types.Double
	case arrow.FLOAT16:
		return parquet.Types.FixedLenByteArray
	case arrow.BINARY, arrow.LARGE_BINARY, arrow.STRING, arrow.LARGE_STRING:
		return parquet.Types.ByteArray
	case arrow.FIXED_SIZE_BINARY, arrow.DECIMAL:
@@ -555,13 +559,17 @@ func (ps *ParquetIOTestSuite) makeSimpleSchema(typ arrow.DataType, rep parquet.R
		byteWidth = int32(typ.ByteWidth)
	case arrow.DecimalType:
		byteWidth = pqarrow.DecimalSize(typ.GetPrecision())
	case *arrow.Float16Type:
		byteWidth = int32(typ.Bytes())
	case *arrow.DictionaryType:
		valuesType := typ.ValueType
		switch dt := valuesType.(type) {
		case *arrow.FixedSizeBinaryType:
			byteWidth = int32(dt.ByteWidth)
		case arrow.DecimalType:
			byteWidth = pqarrow.DecimalSize(dt.GetPrecision())
		case *arrow.Float16Type:
			byteWidth = int32(dt.Bytes())
		}
	}
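
For reference, the parquet field these tests build corresponds to a FixedLenByteArray(2) node with the Float16 logical type. A hedged sketch using the low-level schema API (the field name and the -1 field-id are illustrative, and the NewPrimitiveNodeLogical signature is assumed from the v14 schema package):

package main

import (
	"github.com/apache/arrow/go/v14/arrow"
	"github.com/apache/arrow/go/v14/parquet"
	"github.com/apache/arrow/go/v14/parquet/schema"
)

func main() {
	// Float16 is annotated as FixedLenByteArray(2) with the Float16 logical type.
	node, err := schema.NewPrimitiveNodeLogical("f16_col", parquet.Repetitions.Optional,
		schema.Float16LogicalType{}, parquet.Types.FixedLenByteArray,
		arrow.Float16SizeBytes, -1)
	if err != nil {
		panic(err)
	}
	_ = node
}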

@@ -1068,6 +1076,7 @@ var fullTypeList = []arrow.DataType{
	arrow.FixedWidthTypes.Date32,
	arrow.PrimitiveTypes.Float32,
	arrow.PrimitiveTypes.Float64,
	arrow.FixedWidthTypes.Float16,
	arrow.BinaryTypes.String,
	arrow.BinaryTypes.Binary,
	&arrow.FixedSizeBinaryType{ByteWidth: 10},
