Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: compute fixes for extension types #171

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion arrow/compute/exprs/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,25 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
return nil, err
}

var newArgs []compute.Datum
// cast arguments if necessary
for i, arg := range args {
if !arrow.TypeEqual(argTypes[i], arg.(compute.ArrayLikeDatum).Type()) {
if newArgs == nil {
newArgs = make([]compute.Datum, len(args))
copy(newArgs, args)
}
newArgs[i], err = compute.CastDatum(ctx, arg, compute.SafeCastOptions(argTypes[i]))
if err != nil {
return nil, err
}
defer newArgs[i].Release()
}
}
if newArgs != nil {
args = newArgs
}

kctx := &exec.KernelCtx{Ctx: ctx, Kernel: k}
init := k.GetInitFn()
kinitArgs := exec.KernelInitArgs{Kernel: k, Inputs: argTypes, Options: opts}
Expand Down Expand Up @@ -613,7 +632,7 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
result.Release()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need `result = nil` here? (Otherwise, could we end up returning an already Release()-ed result?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, you're right: it's probably safer to set it to nil, just in case.

}

return result, nil
return result, err
}

return nil, arrow.ErrNotImplemented
Expand Down
13 changes: 11 additions & 2 deletions arrow/compute/exprs/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import (
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/compute/exprs"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/expr"
Expand Down Expand Up @@ -135,8 +137,12 @@ func TestComparisons(t *testing.T) {
one = scalar.MakeScalar(int32(1))
two = scalar.MakeScalar(int32(2))

str = scalar.MakeScalar("hello")
bin = scalar.MakeScalar([]byte("hello"))
str = scalar.MakeScalar("hello")
bin = scalar.MakeScalar([]byte("hello"))
exampleUUID = uuid.MustParse("102cb62f-e6f8-4eb0-9973-d9b012ff0967")
uidStorage, _ = scalar.MakeScalarParam(exampleUUID[:],
&arrow.FixedSizeBinaryType{ByteWidth: 16})
uid = scalar.NewExtensionScalar(uidStorage, extensions.NewUUIDType())
Comment on lines +143 to +145
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uid -> uuid?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uuid is used as the name of the uuid package, so I'd end up with a conflict / issue if I use the name for the variable here. I can rename it as something else though if we prefer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! uuidStorage and uuidScalar may be better. (I feel that uid is "user id"...)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do!

)

getArgType := func(dt arrow.DataType) types.Type {
Expand All @@ -147,6 +153,8 @@ func TestComparisons(t *testing.T) {
return &types.StringType{}
case arrow.BINARY:
return &types.BinaryType{}
case arrow.EXTENSION:
return &types.UUIDType{}
}
panic("wtf")
}
Expand Down Expand Up @@ -183,6 +191,7 @@ func TestComparisons(t *testing.T) {

expect(t, "equal", one, one, true)
expect(t, "equal", one, two, false)
expect(t, "equal", uid, uid, true)
expect(t, "less", one, two, true)
expect(t, "less", one, zero, false)
expect(t, "greater", one, zero, true)
Expand Down
8 changes: 6 additions & 2 deletions arrow/compute/exprs/extension_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (ef *simpleExtensionTypeFactory[P]) ExtensionEquals(other arrow.ExtensionTy
return ef.params == rhs.params
}
func (ef *simpleExtensionTypeFactory[P]) ArrayType() reflect.Type {
return reflect.TypeOf(array.ExtensionArrayBase{})
return reflect.TypeOf(simpleExtensionArrayFactory[P]{})
}

func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType {
Expand All @@ -91,10 +91,14 @@ func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType {
}
}

// simpleExtensionArrayFactory is the array struct whose reflect.Type is
// reported by simpleExtensionTypeFactory[P].ArrayType. It embeds
// array.ExtensionArrayBase and declares no fields of its own; the type
// parameter P is not referenced in the struct body, so it only serves to make
// each instantiation a distinct Go type.
type simpleExtensionArrayFactory[P comparable] struct {
array.ExtensionArrayBase
}

// uuidExtParams is the (empty) parameter type used to instantiate
// simpleExtensionTypeFactory for the uuid extension type; it carries no data.
type uuidExtParams struct{}

var uuidType = simpleExtensionTypeFactory[uuidExtParams]{
name: "uuid", getStorage: func(uuidExtParams) arrow.DataType {
name: "arrow.uuid", getStorage: func(uuidExtParams) arrow.DataType {
return &arrow.FixedSizeBinaryType{ByteWidth: 16}
}}

Expand Down
1 change: 1 addition & 0 deletions arrow/compute/scalar_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func (fn *compareFunction) DispatchBest(vals ...arrow.DataType) (exec.Kernel, er
}

ensureDictionaryDecoded(vals...)
ensureNotExtensionType(vals...)
replaceNullWithOtherType(vals...)

if dt := commonNumeric(vals...); dt != nil {
Expand Down
3 changes: 3 additions & 0 deletions arrow/compute/scalar_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/compute/exec"
"github.com/apache/arrow-go/v18/arrow/compute/internal/kernels"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/internal/testing/gen"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
Expand Down Expand Up @@ -1289,6 +1290,8 @@ func TestCompareKernelsDispatchBest(t *testing.T) {
&arrow.Decimal128Type{Precision: 3, Scale: 2}, &arrow.Decimal128Type{Precision: 21, Scale: 2}},
{arrow.PrimitiveTypes.Int64, &arrow.Decimal128Type{Precision: 3, Scale: 2},
&arrow.Decimal128Type{Precision: 21, Scale: 2}, &arrow.Decimal128Type{Precision: 3, Scale: 2}},

{extensions.NewUUIDType(), extensions.NewUUIDType(), &arrow.FixedSizeBinaryType{ByteWidth: 16}, &arrow.FixedSizeBinaryType{ByteWidth: 16}},
}

for _, name := range []string{"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"} {
Expand Down
8 changes: 8 additions & 0 deletions arrow/compute/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ func ensureDictionaryDecoded(vals ...arrow.DataType) {
}
}

// ensureNotExtensionType rewrites vals in place: every entry whose type ID is
// arrow.EXTENSION is replaced by the storage type reported by its
// arrow.ExtensionType implementation. Entries with any other type ID are left
// untouched. Because the replacement happens on the slice elements, a caller
// that spreads an existing slice (vals...) observes the updated types.
func ensureNotExtensionType(vals ...arrow.DataType) {
	for idx := range vals {
		dt := vals[idx]
		if dt.ID() != arrow.EXTENSION {
			// Not an extension type: nothing to unwrap.
			continue
		}
		vals[idx] = dt.(arrow.ExtensionType).StorageType()
	}
}

func replaceNullWithOtherType(vals ...arrow.DataType) {
debug.Assert(len(vals) == 2, "should be length 2")

Expand Down
4 changes: 4 additions & 0 deletions arrow/extensions/uuid.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ func (*UUIDType) ExtensionName() string {
return "arrow.uuid"
}

// Bytes returns the width of a UUID value in bytes, which is always 16.
func (*UUIDType) Bytes() int {
	const uuidByteWidth = 16
	return uuidByteWidth
}

// BitWidth returns the width of a UUID value in bits, which is always 128.
func (*UUIDType) BitWidth() int {
	const uuidBitWidth = 128
	return uuidBitWidth
}

// String renders the type as "extension<NAME>", where NAME is the value
// returned by ExtensionName.
func (e *UUIDType) String() string {
	name := e.ExtensionName()
	return fmt.Sprintf("extension<%s>", name)
}
Expand Down
2 changes: 2 additions & 0 deletions parquet/pqarrow/encode_arrow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2057,6 +2057,8 @@ func (ps *ParquetIOTestSuite) TestArrowExtensionTypeRoundTrip() {
defer tbl.Release()

ps.roundTripTable(mem, tbl, true)
// ensure we get UUID back even without storing the schema
ps.roundTripTable(mem, tbl, false)
}

func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() {
Expand Down
19 changes: 12 additions & 7 deletions parquet/pqarrow/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/flight"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/parquet"
Expand Down Expand Up @@ -514,8 +515,10 @@ func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, erro
switch logtype := logical.(type) {
case schema.DecimalLogicalType:
return arrowDecimal(logtype), nil
case schema.NoLogicalType, schema.IntervalLogicalType, schema.UUIDLogicalType:
case schema.NoLogicalType, schema.IntervalLogicalType:
return &arrow.FixedSizeBinaryType{ByteWidth: int(length)}, nil
case schema.UUIDLogicalType:
return extensions.NewUUIDType(), nil
case schema.Float16LogicalType:
return &arrow.Float16Type{}, nil
default:
Expand Down Expand Up @@ -984,13 +987,15 @@ func applyOriginalStorageMetadata(origin arrow.Field, inferred *SchemaField) (mo
return
}

if !arrow.TypeEqual(extType.StorageType(), inferred.Field.Type) {
return modified, fmt.Errorf("%w: mismatch storage type '%s' for extension type '%s'",
arrow.ErrInvalid, inferred.Field.Type, extType)
}
if modified || !arrow.TypeEqual(extType, inferred.Field.Type) {
if !arrow.TypeEqual(extType.StorageType(), inferred.Field.Type) {
return modified, fmt.Errorf("%w: mismatch storage type '%s' for extension type '%s'",
arrow.ErrInvalid, inferred.Field.Type, extType)
}

inferred.Field.Type = extType
modified = true
inferred.Field.Type = extType
modified = true
}
case arrow.SPARSE_UNION, arrow.DENSE_UNION:
err = xerrors.New("unimplemented type")
case arrow.STRUCT:
Expand Down
Loading