diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc index 8f99d9de4536e..b8ee15cfe6b1f 100644 --- a/cpp/src/gandiva/function_registry_arithmetic.cc +++ b/cpp/src/gandiva/function_registry_arithmetic.cc @@ -91,9 +91,12 @@ std::vector GetArithmeticFunctionRegistry() { BINARY_SYMMETRIC_FN(add, {}), BINARY_SYMMETRIC_FN(subtract, {}), BINARY_SYMMETRIC_FN(multiply, {}), NUMERIC_TYPES(BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL, divide, {}), - BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int64, int32, int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int8, int8, int8), + BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int16, int16, int16), + BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int32, int32, int32), BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int64, int64, int64), BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, decimal128), + BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, float32), BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, float64), BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int32), BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int64), diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops.cc b/cpp/src/gandiva/precompiled/arithmetic_ops.cc index 1ffd0a97dad6b..f6b727231f80c 100644 --- a/cpp/src/gandiva/precompiled/arithmetic_ops.cc +++ b/cpp/src/gandiva/precompiled/arithmetic_ops.cc @@ -95,11 +95,20 @@ gdv_boolean isNaN_float32(gdv_float32 val) { return isnan(val) || isinf(val); } FORCE_INLINE gdv_boolean isNaN_float64(gdv_float64 val) { return isnan(val) || isinf(val); } -MOD_OP(mod, int64, int32, int32) +MOD_OP(mod, int32, int32, int32) MOD_OP(mod, int64, int64, int64) #undef MOD_OP +gdv_float32 mod_float32_float32(int64_t context, gdv_float32 x, gdv_float32 y) { + if (y == 0.0) { + // char const* err_msg = "divide by zero error"; + // gdv_fn_context_set_error_msg(context, err_msg); + return 0.0; + } + return fmod(x, y); +} + gdv_float64 mod_float64_float64(int64_t context, gdv_float64 x, gdv_float64 y) { if (y == 0.0) { // Setting error msg can cause unexpected runtime exception. diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc index 23b290d6b91fe..c900eddcb1935 100644 --- a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc +++ b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc @@ -36,7 +36,7 @@ TEST(TestArithmeticOps, TestIsDistinctFrom) { TEST(TestArithmeticOps, TestMod) { gandiva::ExecutionContext context; - EXPECT_EQ(mod_int64_int32(10, 0), 10); + EXPECT_EQ(mod_int32_int32(10, 0), 10); const double acceptable_abs_error = 0.00000000001; // 1e-10 diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index c6d6ea6c9ea40..0f7ce8afea44f 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -144,7 +144,10 @@ double months_between_timestamp_timestamp(gdv_uint64, gdv_uint64); gdv_int32 mem_compare(const char* left, gdv_int32 left_len, const char* right, gdv_int32 right_len); -gdv_int32 mod_int64_int32(gdv_int64 left, gdv_int32 right); +gdv_int32 mod_int8_int8(gdv_int8 left, gdv_int8 right); +gdv_int32 mod_int16_int16(gdv_int16 left, gdv_int16 right); +gdv_int32 mod_int32_int32(gdv_int32 left, gdv_int32 right); +gdv_float32 mod_float32_float32(gdv_int64 context, gdv_float32 left, gdv_float32 right); gdv_float64 mod_float64_float64(gdv_int64 context, gdv_float64 left, gdv_float64 right); gdv_int8 pmod_int8_int8(gdv_int8 in1, bool in1_valid, gdv_int8 in2, bool in2_valid, bool* out_valid);