Skip to content

Commit

Permalink
ARROW-18111: [Go] Remaining scalar binary arithmetic (shifts, power, …
Browse files Browse the repository at this point in the history
…bitwise) (#14703)

Authored-by: Matt Topol <zotthewizard@gmail.com>
Signed-off-by: Matt Topol <zotthewizard@gmail.com>
  • Loading branch information
zeroshade authored Nov 23, 2022
1 parent ad54d6c commit 1121bbc
Show file tree
Hide file tree
Showing 11 changed files with 778 additions and 117 deletions.
133 changes: 133 additions & 0 deletions go/arrow/compute/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,41 @@ func (fn *arithmeticFloatingPointFunc) DispatchBest(vals ...arrow.DataType) (exe
return fn.DispatchExact(vals...)
}

type arithmeticDecimalToFloatingPointFunc struct {
arithmeticFunction
}

func (fn *arithmeticDecimalToFloatingPointFunc) Execute(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) {
return execInternal(ctx, fn, opts, -1, args...)
}

func (fn *arithmeticDecimalToFloatingPointFunc) DispatchBest(vals ...arrow.DataType) (exec.Kernel, error) {
if err := fn.checkArity(len(vals)); err != nil {
return nil, err
}

if kn, err := fn.DispatchExact(vals...); err == nil {
return kn, nil
}

ensureDictionaryDecoded(vals...)
if len(vals) == 2 {
replaceNullWithOtherType(vals...)
}

for i, t := range vals {
if arrow.IsDecimal(t.ID()) {
vals[i] = arrow.PrimitiveTypes.Float64
}
}

if dt := commonNumeric(vals...); dt != nil {
replaceTypes(dt, vals...)
}

return fn.DispatchExact(vals...)
}

var (
addDoc FunctionDoc
)
Expand Down Expand Up @@ -370,6 +405,68 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
}

reg.AddFunction(fn, false)

ops = []struct {
funcName string
op kernels.ArithmeticOp
decPromote decimalPromotion
}{
{"power_unchecked", kernels.OpPower, decPromoteNone},
{"power", kernels.OpPowerChecked, decPromoteNone},
}

for _, o := range ops {
fn := &arithmeticDecimalToFloatingPointFunc{arithmeticFunction{*NewScalarFunction(o.funcName, Binary(), EmptyFuncDoc), o.decPromote}}
kns := kernels.GetArithmeticBinaryKernels(o.op)
for _, k := range kns {
if err := fn.AddKernel(k); err != nil {
panic(err)
}
}
reg.AddFunction(fn, false)
}

bitWiseOps := []struct {
funcName string
op kernels.BitwiseOp
}{
{"bit_wise_and", kernels.OpBitAnd},
{"bit_wise_or", kernels.OpBitOr},
{"bit_wise_xor", kernels.OpBitXor},
}

for _, o := range bitWiseOps {
fn := &arithmeticFunction{*NewScalarFunction(o.funcName, Binary(), EmptyFuncDoc), decPromoteNone}
kns := kernels.GetBitwiseBinaryKernels(o.op)
for _, k := range kns {
if err := fn.AddKernel(k); err != nil {
panic(err)
}
}
reg.AddFunction(fn, false)
}

shiftOps := []struct {
funcName string
dir kernels.ShiftDir
checked bool
}{
{"shift_left", kernels.ShiftLeft, true},
{"shift_left_unchecked", kernels.ShiftLeft, false},
{"shift_right", kernels.ShiftRight, true},
{"shift_right_unchecked", kernels.ShiftRight, false},
}

for _, o := range shiftOps {
fn := &arithmeticFunction{*NewScalarFunction(o.funcName, Binary(), EmptyFuncDoc), decPromoteNone}
kns := kernels.GetShiftKernels(o.dir, o.checked)
for _, k := range kns {
if err := fn.AddKernel(k); err != nil {
panic(err)
}
}
reg.AddFunction(fn, false)
}
}

func impl(ctx context.Context, fn string, opts ArithmeticOptions, left, right Datum) (Datum, error) {
Expand Down Expand Up @@ -463,3 +560,39 @@ func Negate(ctx context.Context, opts ArithmeticOptions, input Datum) (Datum, er
func Sign(ctx context.Context, input Datum) (Datum, error) {
return CallFunction(ctx, "sign", nil, input)
}

// Power returns base**exp for each element in the input arrays. Should work
// for both Arrays and Scalars
func Power(ctx context.Context, opts ArithmeticOptions, base, exp Datum) (Datum, error) {
fn := "power"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, base, exp)
}

// ShiftLeft only accepts integral types and shifts each element of the
// first argument to the left by the value of the corresponding element
// in the second argument.
//
// The value to shift by should be >= 0 and < precision of the type.
func ShiftLeft(ctx context.Context, opts ArithmeticOptions, lhs, rhs Datum) (Datum, error) {
fn := "shift_left"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, lhs, rhs)
}

// ShiftRight only accepts integral types and shifts each element of the
// first argument to the right by the value of the corresponding element
// in the second argument.
//
// The value to shift by should be >= 0 and < precision of the type.
func ShiftRight(ctx context.Context, opts ArithmeticOptions, lhs, rhs Datum) (Datum, error) {
fn := "shift_right"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, lhs, rhs)
}
Loading

0 comments on commit 1121bbc

Please sign in to comment.