Skip to content

Commit

Permalink
[C++][Gandiva] Adding more functions that are required by native SQL …
Browse files Browse the repository at this point in the history
…engine

Signed-off-by: Yuan Zhou <yuan.zhou@intel.com>
  • Loading branch information
zhouyuan authored and zhztheplayer committed Feb 28, 2022
1 parent 9354fc4 commit 568e98d
Show file tree
Hide file tree
Showing 16 changed files with 921 additions and 15 deletions.
25 changes: 24 additions & 1 deletion cpp/src/gandiva/function_registry_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,15 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, float32, int64),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, float64, int64),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, date64, int64),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, int8, int32),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, int16, int32),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, int64, int32),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, date32, int32),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, float32, int32),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, float64, int32),
UNARY_SAFE_NULL_IF_NULL(castBYTE, {}, int16, int8),
UNARY_SAFE_NULL_IF_NULL(castBYTE, {}, int32, int8),
UNARY_SAFE_NULL_IF_NULL(castBYTE, {}, int64, int8),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, decimal128, int64),

// cast to float32
Expand All @@ -94,6 +99,10 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, decimal128, decimal128),
UNARY_UNSAFE_NULL_IF_NULL(castDECIMAL, {}, utf8, decimal128),

// isNaN
UNARY_SAFE_NULL_IF_NULL(isNaN, {}, float32, boolean),
UNARY_SAFE_NULL_IF_NULL(isNaN, {}, float64, boolean),

NativeFunction("castDECIMALNullOnOverflow", {}, DataTypeVector{decimal128()},
decimal128(), kResultNullInternal,
"castDECIMALNullOnOverflow_decimal128"),
Expand Down Expand Up @@ -160,6 +169,11 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
UNARY_SAFE_NULL_IF_NULL(round, {}, int64, int64),
BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, int32, int32, int32),
BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, int64, int32, int64),
// bitwise functions
BINARY_GENERIC_SAFE_NULL_IF_NULL(shift_left, {}, int32, int32, int32),
BINARY_GENERIC_SAFE_NULL_IF_NULL(shift_left, {}, int64, int32, int64),
BINARY_GENERIC_SAFE_NULL_IF_NULL(shift_right, {}, int32, int32, int32),
BINARY_GENERIC_SAFE_NULL_IF_NULL(shift_right, {}, int64, int32, int64),

// bround functions
NativeFunction("bround", {}, DataTypeVector{float64()}, float64(),
Expand All @@ -172,12 +186,21 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
BINARY_RELATIONAL_BOOL_DATE_FN(less_than_or_equal_to, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(greater_than, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(greater_than_or_equal_to, {}),

// compare functions with nan
BINARY_RELATIONAL_BOOL_FN(equal_with_nan, ({"eq_with_nan", "same_with_nan"})),
BINARY_RELATIONAL_BOOL_FN(not_equal_with_nan, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(less_than_with_nan, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(less_than_or_equal_to_with_nan, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(greater_than_with_nan, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(greater_than_or_equal_to_with_nan, {})};
BINARY_RELATIONAL_BOOL_DATE_FN(greater_than_or_equal_to, {}),
BASE_NUMERIC_TYPES(MULTIPLE_SAFE_NULL_IF_NULL, greatest, {}),
BASE_NUMERIC_TYPES(MULTIPLE_SAFE_NULL_IF_NULL, least, {}),

// binary representation of integer values
UNARY_UNSAFE_NULL_IF_NULL(bin, {}, int32, utf8),
UNARY_UNSAFE_NULL_IF_NULL(bin, {}, int64, utf8)};
UNARY_UNSAFE_NULL_IF_NULL(bin, {}, int64, utf8);

return arithmetic_fn_registry_;
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/gandiva/function_registry_datetime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ std::vector<NativeFunction> GetDateTimeFunctionRegistry() {
kResultNullIfNull, "castTIMESTAMP_utf8",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("castVARCHAR", {}, DataTypeVector{date32(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_date32_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{timestamp(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_timestamp_int64",
NativeFunction::kNeedsContext),
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/gandiva/function_registry_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ namespace gandiva {

std::vector<NativeFunction> GetHashFunctionRegistry() {
static std::vector<NativeFunction> hash_fn_registry_ = {
NativeFunction("hash32_spark", {}, DataTypeVector{boolean(), int32()}, int32(),
kResultNullNever, "hash32_spark_boolean_int32"),
NativeFunction("hash32_spark", {}, DataTypeVector{int8(), int32()}, int32(),
kResultNullNever, "hash32_spark_int8_int32"),
NativeFunction("hash32_spark", {}, DataTypeVector{int16(), int32()}, int32(),
kResultNullNever, "hash32_spark_int16_int32"),
NativeFunction("hash32_spark", {}, DataTypeVector{int32(), int32()}, int32(),
kResultNullNever, "hash32_spark_int32_int32"),
NativeFunction("hash32_spark", {}, DataTypeVector{date32(), int32()}, int32(),
kResultNullNever, "hash32_spark_date32_int32"),
NativeFunction("hash32_spark", {}, DataTypeVector{float32(), int32()}, int32(),
kResultNullNever, "hash32_spark_float32_int32"),
NativeFunction("hash64_spark", {}, DataTypeVector{int64(), int32()}, int32(),
kResultNullNever, "hash64_spark_int64_int32"),
NativeFunction("hash64_spark", {}, DataTypeVector{float64(), int32()}, int32(),
kResultNullNever, "hash64_spark_float64_int32"),
NativeFunction("hashbuf_spark", {}, DataTypeVector{utf8(), int32()}, int32(),
kResultNullNever, "hashbuf_spark_utf8_int32"),

HASH32_SAFE_NULL_NEVER_FN(hash, {}),
HASH32_SAFE_NULL_NEVER_FN(hash32, {}),
HASH32_SAFE_NULL_NEVER_FN(hash32AsDouble, {}),
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/gandiva/function_registry_math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ std::vector<NativeFunction> GetMathOpsFunctionRegistry() {
UNARY_SAFE_NULL_IF_NULL(abs, {}, float32, float32),
UNARY_SAFE_NULL_IF_NULL(abs, {}, float64, float64),

BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, float64, int32, float64),
BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, float64, int64, float64),

// decimal functions
UNARY_SAFE_NULL_IF_NULL(abs, {}, decimal128, decimal128),
UNARY_SAFE_NULL_IF_NULL(ceil, {}, decimal128, decimal128),
Expand Down
23 changes: 23 additions & 0 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,29 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
NativeFunction("castFLOAT8", {}, DataTypeVector{utf8()}, float64(),
kResultNullIfNull, "gdv_fn_castFLOAT8_utf8",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
NativeFunction("castVARCHAR", {}, DataTypeVector{int8(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_int8_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{int16(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_int16_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{int32(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_int32_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{int64(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_int64_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{float32(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_float32_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{float64(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_float64_int64",
NativeFunction::kNeedsContext),

NativeFunction("castINT", {}, DataTypeVector{binary()}, int32(), kResultNullIfNull,
"gdv_fn_castINT_varbinary",
Expand Down
109 changes: 102 additions & 7 deletions cpp/src/gandiva/precompiled/arithmetic_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
extern "C" {

#include <math.h>
#include <stdio.h>
#include <string.h>
#include <cfloat>
#include "./types.h"

// Expand inner macro for all numeric types.
Expand Down Expand Up @@ -91,6 +94,12 @@ BINARY_SYMMETRIC(bitwise_xor, int64, ^)

#undef BINARY_SYMMETRIC

FORCE_INLINE
gdv_boolean isNaN_float32(gdv_float32 val) { return isnan(val) || isinf(val); }

FORCE_INLINE
gdv_boolean isNaN_float64(gdv_float64 val) { return isnan(val) || isinf(val); }

MOD_OP(mod, int64, int32, int32)
MOD_OP(mod, int64, int64, int64)

Expand Down Expand Up @@ -188,6 +197,32 @@ NUMERIC_DATE_TYPES(COMPARE_SIX_VALUES, least, <)
#undef COMPARE_FIVE_VALUES
#undef COMPARE_SIX_VALUES

// Relational binary fns : left, right params are same, return is bool.
#define BINARY_RELATIONAL_NAN(NAME, TYPE, OP) \
FORCE_INLINE \
bool NAME##_##TYPE##_##TYPE(gdv_##TYPE left, gdv_##TYPE right) { \
const double infinity = 1.0 / 0.0; \
bool left_is_nan = isnan(left); \
bool right_is_nan = isnan(right); \
if (left_is_nan && right_is_nan) { \
return infinity OP infinity; \
} else if (left_is_nan) { \
return infinity OP right; \
} else if (right_is_nan) { \
return left OP infinity; \
} \
return left OP right; \
}

NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL_NAN, equal_with_nan, ==)
NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL_NAN, not_equal_with_nan, !=)
NUMERIC_DATE_TYPES(BINARY_RELATIONAL_NAN, less_than_with_nan, <)
NUMERIC_DATE_TYPES(BINARY_RELATIONAL_NAN, less_than_or_equal_to_with_nan, <=)
NUMERIC_DATE_TYPES(BINARY_RELATIONAL_NAN, greater_than_with_nan, >)
NUMERIC_DATE_TYPES(BINARY_RELATIONAL_NAN, greater_than_or_equal_to_with_nan, >=)

#undef BINARY_RELATIONAL_NAN

// cast fns : takes one param type, returns another type.
#define CAST_UNARY(NAME, IN_TYPE, OUT_TYPE) \
FORCE_INLINE \
Expand All @@ -197,12 +232,13 @@ NUMERIC_DATE_TYPES(COMPARE_SIX_VALUES, least, <)

CAST_UNARY(castBIGINT, int32, int64)
CAST_UNARY(castBIGINT, date64, int64)
CAST_UNARY(castBIGINT, float32, int64)
CAST_UNARY(castBIGINT, float64, int64)
CAST_UNARY(castINT, int8, int32)
CAST_UNARY(castINT, int16, int32)
CAST_UNARY(castINT, int64, int32)
CAST_UNARY(castINT, date32, int32)
CAST_UNARY(castINT, float32, int32)
CAST_UNARY(castINT, float64, int32)
CAST_UNARY(castBYTE, int16, int8)
CAST_UNARY(castBYTE, int32, int8)
CAST_UNARY(castBYTE, int64, int8)
CAST_UNARY(castFLOAT4, int32, float32)
CAST_UNARY(castFLOAT4, int64, float32)
CAST_UNARY(castFLOAT8, int32, float64)
Expand All @@ -211,6 +247,46 @@ CAST_UNARY(castFLOAT8, float32, float64)
CAST_UNARY(castFLOAT4, float64, float32)

#undef CAST_UNARY
#define nothing
#define PRINT(DIGSF, DIGS, FMT) PRINT_##DIGSF(DIGS, FMT)
#define PRINT_NOFMT(DIGS, FMT) int res = snprintf(char_buffer, length, FMT, in);
#define PRINT_FMT(DIGS, FMT) int res = snprintf(char_buffer, length, FMT, DIGS, in);

#define CAST_UNARY_UTF8(NAME, IN_TYPE, OUT_TYPE, FMT, DIGSF, DIGS) \
FORCE_INLINE \
const char* NAME##_##IN_TYPE##_int64(gdv_int64 context, gdv_##IN_TYPE in, \
gdv_int64 length, gdv_int32 * out_len) { \
const int32_t char_buffer_length = length; \
char char_buffer[char_buffer_length]; \
PRINT(DIGSF, DIGS, FMT) \
if (res < 0) { \
gdv_fn_context_set_error_msg(context, "Could not format the ##IN_TYPE"); \
return ""; \
} \
\
*out_len = strlen(char_buffer); \
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); \
if (ret == nullptr) { \
gdv_fn_context_set_error_msg(context, \
"Could not allocate memory for output string"); \
*out_len = 0; \
return ""; \
} \
\
memcpy(ret, char_buffer, *out_len); \
return ret; \
}

CAST_UNARY_UTF8(castVARCHAR, int8, utf8, "%d", NOFMT, nothing)
CAST_UNARY_UTF8(castVARCHAR, int16, utf8, "%d", NOFMT, nothing)
CAST_UNARY_UTF8(castVARCHAR, int32, utf8, "%d", NOFMT, nothing)
CAST_UNARY_UTF8(castVARCHAR, int64, utf8, "%ld", NOFMT, nothing)
// CAST_UNARY_UTF8(castVARCHAR, float32, utf8, "%.*f", FMT, FLT_DIG)
// CAST_UNARY_UTF8(castVARCHAR, float64, utf8, "%.*f", FMT, DBL_DIG)
CAST_UNARY_UTF8(castVARCHAR, float32, utf8, "%g", NOFMT, nothing)
CAST_UNARY_UTF8(castVARCHAR, float64, utf8, "%g", NOFMT, nothing)

#undef CAST_UNARY_UTF8

// cast float types to int types.
#define CAST_INT_FLOAT(NAME, IN_TYPE, OUT_TYPE) \
Expand Down Expand Up @@ -327,9 +403,7 @@ NUMERIC_BOOL_DATE_FUNCTION(IS_NOT_DISTINCT_FROM)
FORCE_INLINE \
gdv_##TYPE divide_##TYPE##_##TYPE(gdv_int64 context, gdv_##TYPE in1, gdv_##TYPE in2) { \
if (in2 == 0) { \
char const* err_msg = "divide by zero error"; \
gdv_fn_context_set_error_msg(context, err_msg); \
return 0; \
return static_cast<gdv_##TYPE>(NULL); \
} \
return static_cast<gdv_##TYPE>(in1 / in2); \
}
Expand Down Expand Up @@ -378,6 +452,27 @@ BITWISE_NOT(int32)
BITWISE_NOT(int64)

#undef BITWISE_NOT
#define SHIFT_LEFT_INT(LTYPE, RTYPE) \
FORCE_INLINE \
gdv_##LTYPE shift_left_##LTYPE##_##RTYPE(gdv_##LTYPE in1, gdv_##RTYPE in2) { \
return static_cast<gdv_##LTYPE>(in1 << in2); \
}

SHIFT_LEFT_INT(int32, int32)
SHIFT_LEFT_INT(int64, int32)

#undef SHIFT_RIGHT_INT

#define SHIFT_RIGHT_INT(LTYPE, RTYPE) \
FORCE_INLINE \
gdv_##LTYPE shift_right_##LTYPE##_##RTYPE(gdv_##LTYPE in1, gdv_##RTYPE in2) { \
return static_cast<gdv_##LTYPE>(in1 >> in2); \
}

SHIFT_RIGHT_INT(int32, int32)
SHIFT_RIGHT_INT(int64, int32)

#undef SHIFT_RIGHT_INT

#undef DATE_FUNCTION
#undef DATE_TYPES
Expand Down
14 changes: 14 additions & 0 deletions cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,4 +353,18 @@ TEST(TestArithmeticOps, TestBigIntCastFloatDouble) {
EXPECT_EQ(castBIGINT_float64(-2147483647), -2147483647);
}

TEST(TestArithmeticOps, TestCastVarhcar) {
gandiva::ExecutionContext ctx;
uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
gdv_int32 out_len = 0;

const char* out_str = castVARCHAR_int32_int64(ctx_ptr, 88, 11L, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "88");
EXPECT_FALSE(ctx.has_error());

out_str = castVARCHAR_float64_int64(ctx_ptr, 8.712128f, 21L, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "8.712128");
EXPECT_FALSE(ctx.has_error());
}

} // namespace gandiva
18 changes: 18 additions & 0 deletions cpp/src/gandiva/precompiled/extended_math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ extern "C" {

#include <math.h>
#include <stdio.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

Expand Down Expand Up @@ -96,6 +97,23 @@ ABS_TYPES_UNARY(int64, uint64)
ABS_FTYPES_UNARY(float32, float32)
ABS_FTYPES_UNARY(float64, float64)

// round
#define ROUND_TYPES_UNARY(IN_TYPE1, IN_TYPE2, OUT_TYPE) \
FORCE_INLINE \
gdv_##OUT_TYPE round_##IN_TYPE1##_##IN_TYPE2(gdv_##IN_TYPE1 val, gdv_##IN_TYPE2 dp) { \
int charsNeeded = 1 + snprintf(NULL, 0, "%.*f", (int) dp, val); \
char* buffer = reinterpret_cast<char*>(malloc(charsNeeded)); \
snprintf(buffer, charsNeeded, "%.*f", (int) dp, nextafter(val, val*2)); \
double result = atof(buffer); \
free(buffer); \
return static_cast<gdv_##OUT_TYPE>(result); \
}

ROUND_TYPES_UNARY(float64, int32, float64)
ROUND_TYPES_UNARY(float64, int64, float64)

#undef ROUND_TYPES_UNARY

FORCE_INLINE
void set_error_for_logbase(int64_t execution_context, double base) {
char const* prefix = "divide by zero error with log of base";
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/gandiva/precompiled/extended_math_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,4 +407,10 @@ TEST(TestExtendedMathOps, TestBinRepresentation) {
"1000000000000000000000000000000000000000000000000000000000000000");
EXPECT_FALSE(ctx.has_error());
}
TEST(TestExtendedMathOps, TestRound) {
EXPECT_EQ(round_float64_int32(1234.56789, 4), 1234.5679);
EXPECT_EQ(round_float64_int64(1234.56789, 4), 1234.5679);
EXPECT_EQ(round_float64_int32(-1234.56789, 4), -1234.5679);
EXPECT_EQ(round_float64_int64(-1234.56789, 4), -1234.5679);
}
} // namespace gandiva
Loading

0 comments on commit 568e98d

Please sign in to comment.