From acd9fa28e988e124264ecbadd5d84b612b4e7f62 Mon Sep 17 00:00:00 2001 From: Xue Zhenliang Date: Thu, 22 Jul 2021 22:14:20 +0800 Subject: [PATCH] cherry pick #2462 to release-5.0 Signed-off-by: ti-srebot --- dbms/src/Functions/FunctionsArithmetic.h | 38 ++- .../tests/gtest_arithmetic_functions.cpp | 237 ++++++++++++++++++ 2 files changed, 274 insertions(+), 1 deletion(-) diff --git a/dbms/src/Functions/FunctionsArithmetic.h b/dbms/src/Functions/FunctionsArithmetic.h index 7f8b0930553..3d55b70e64e 100644 --- a/dbms/src/Functions/FunctionsArithmetic.h +++ b/dbms/src/Functions/FunctionsArithmetic.h @@ -525,6 +525,42 @@ struct ModuloImpl { using ResultType = typename NumberTraits::ResultOfModulo::Type; +<<<<<<< HEAD +======= + template + static make_unsigned_t to_unsigned(const From & value) + { + using ReturnType = make_unsigned_t; + + if constexpr (is_signed_v) + { + if constexpr (is_boost_number_v) + { + // assert that negation of std::numeric_limits::min() will not result in overflow. + // TODO: find credible source that describes numeric limits of boost multiprecision *checked* integers. + static_assert(-std::numeric_limits::max() == std::numeric_limits::min()); + return static_cast(boost::multiprecision::abs(value)); + } + else + { + if (value < 0) + { + // both signed to unsigned conversion [1] and negation of unsigned integers [2] are well defined in C++. + // + // see: + // [1]: https://en.cppreference.com/w/c/language/conversion#Integer_conversions + // [2]: https://en.cppreference.com/w/cpp/language/operator_arithmetic#Unary_arithmetic_operators + return -static_cast(value); + } + else + return static_cast(value); + } + } + else + return static_cast(value); + } + +>>>>>>> c407b167c... Fix function `mod` decimal scale overflow (#2462) template static inline Result apply(A a, B b) { @@ -1286,7 +1322,7 @@ struct DecimalBinaryOperation static constexpr bool is_plus_minus_compare = is_plus_minus || is_compare; static constexpr bool can_overflow = is_plus_minus || is_multiply; - static constexpr bool need_promote_type = (std::is_same_v || std::is_same_v) && (is_plus_minus_compare || is_division || is_multiply) ; // And is multiple / division + static constexpr bool need_promote_type = (std::is_same_v || std::is_same_v) && (is_plus_minus_compare || is_division || is_multiply || is_modulo) ; // And is multiple / division / modulo static constexpr bool check_overflow = need_promote_type && std::is_same_v; // Check if exceeds 10 * 66; using ResultType = ResultType_; diff --git a/dbms/src/Functions/tests/gtest_arithmetic_functions.cpp b/dbms/src/Functions/tests/gtest_arithmetic_functions.cpp index fe4c1c820be..fe1f90974bb 100644 --- a/dbms/src/Functions/tests/gtest_arithmetic_functions.cpp +++ b/dbms/src/Functions/tests/gtest_arithmetic_functions.cpp @@ -698,5 +698,242 @@ try } CATCH +<<<<<<< HEAD +======= +TEST_F(TestBinaryArithmeticFunctions, Modulo) +try +{ + const String func_name = "modulo"; + + using uint64_limits = std::numeric_limits; + using int64_limits = std::numeric_limits; + + // "{}" is similar to std::nullopt. + + // integer modulo + + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), + {5, 3, uint64_limits::max(), 1, 0, 0, {}, 0, {}}, + {3, 5, uint64_limits::max() - 1, 0, 1, 0, 0, {}, {}}, + {2, 3, 1, {}, 0, {}, {}, {}, {}}); + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), + {5, 5, uint64_limits::max(), uint64_limits::max(), uint64_limits::max(), 1, 0, 0, {}, 0, {}}, + {3, -3, int64_limits::max(), int64_limits::max() - 1, int64_limits::min(), 0, 1, 0, 0, {}, {}}, + {2, 2, 1, 3, int64_limits::max(), {}, 0, {}, {}, {}, {}}); + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), + {5, -5, int64_limits::max(), int64_limits::min(), 1, 0, 0, {}, 0, {}}, + {3, 3, 998244353, 998244353, 0, 1, 0, 0, {}, {}}, + {2, -2, 466025954, -466025955, {}, 0, {}, {}, {}, {}}); + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), + {5, -5, 5, -5, int64_limits::max(), int64_limits::min(), 1, 0, 0, {}, 0, {}}, + {3, 3, -3, -3, int64_limits::min(), int64_limits::max(), 0, 1, 0, 0, {}, {}}, + {2, -2, 2, -2, int64_limits::max(), -1, {}, 0, {}, {}, {}, {}}); + + // decimal modulo + + executeFunctionWithData(__LINE__, func_name, + makeDataType(7, 3), makeDataType(7, 3), + { + DecimalField32(3300, 3), DecimalField32(-3300, 3), DecimalField32(3300, 3), + DecimalField32(-3300, 3), DecimalField32(1000, 3), {}, DecimalField32(0,3), {} + }, + { + DecimalField32(1300, 3), DecimalField32(1300, 3), DecimalField32(-1300, 3), + DecimalField32(-1300, 3), DecimalField32(0, 3), DecimalField32(0, 3), {}, {} + }, + { + DecimalField32(700, 3), DecimalField32(-700, 3), DecimalField32(700, 3), + DecimalField32(-700, 3), {}, {}, {}, {} + }, 7, 3); + + // decimal overflow test. + + // for example, 999'999'999 % 1.0000'0000 can check whether Decimal32 is doing arithmetic on Int64. + // scaling 999'999'999 (Decimal(9, 0)) to Decimal(9, 8) needs to multiply it with 1'0000'0000. + // if it uses Int32, it will overflow and get wrong result (something like 0.69325056). + +#define MODULO_OVERFLOW_TESTCASE(Decimal, DecimalField, precision) \ + { \ + auto & builder = DecimalMaxValue::instance(); \ + auto max_scale = std::min(decimal_max_scale, static_cast(precision) - 1); \ + auto exp10_x = static_cast(builder.Get(max_scale)) + 1; /* exp10_x: 10^x */ \ + auto decimal_max = exp10_x * 10 - 1; \ + auto zero = static_cast(0); /* for Int256 */ \ + executeFunctionWithData(__LINE__, func_name, \ + makeDataType((precision), 0), makeDataType((precision), max_scale), \ + {DecimalField(decimal_max, 0)}, {DecimalField(exp10_x, max_scale)}, {DecimalField(zero, max_scale)}, (precision), max_scale); \ + executeFunctionWithData(__LINE__, func_name, \ + makeDataType((precision), max_scale), makeDataType((precision), 0), \ + {DecimalField(exp10_x, max_scale)}, {DecimalField(decimal_max, 0)}, {DecimalField(exp10_x, max_scale)}, (precision), max_scale); \ + } + + MODULO_OVERFLOW_TESTCASE(Decimal32, DecimalField32, 9); + MODULO_OVERFLOW_TESTCASE(Decimal64, DecimalField64, 18); + MODULO_OVERFLOW_TESTCASE(Decimal128, DecimalField128, 38); + MODULO_OVERFLOW_TESTCASE(Decimal256, DecimalField256, 65); + +#undef MODULO_OVERFLOW_TESTCASE + + Int128 large_number_1 = static_cast(std::numeric_limits::max()) * 100000; + executeFunctionWithData(__LINE__, func_name, + makeDataType(38, 5), makeDataType(38, 5), + {DecimalField128(large_number_1, 5), DecimalField128(large_number_1, 5), DecimalField128(large_number_1, 5)}, + {DecimalField128(100000, 5), DecimalField128(large_number_1 - 1, 5), DecimalField128(large_number_1 / 2 + 1, 5)}, + {DecimalField128(large_number_1 % 100000, 5), DecimalField128(1, 5), DecimalField128(large_number_1 / 2 - 1, 5)}, 38, 5); + + Int256 large_number_2 = static_cast(large_number_1) * large_number_1; + executeFunctionWithData(__LINE__, func_name, + makeDataType(65, 5), makeDataType(65, 5), + {DecimalField256(large_number_2, 5), DecimalField256(large_number_2, 5), DecimalField256(large_number_2, 5)}, + {DecimalField256(static_cast(100000), 5), DecimalField256(large_number_2 - 1, 5), DecimalField256(large_number_2 / 2 + 1, 5)}, + {DecimalField256(large_number_2 % 100000, 5), DecimalField256(static_cast(1), 5), DecimalField256(large_number_2 / 2 - 1, 5)}, 65, 5); + + // Int64 has a precision of 20, which is larger than Decimal64. + executeFunctionWithData(__LINE__, func_name, + makeDataType(7, 3), makeDataType(), + {DecimalField32(3300, 3), DecimalField32(3300, 3), {}}, {1, 0, {}}, + {DecimalField128(300, 3), {}, {}}, 20, 3); + + executeFunctionWithData(__LINE__, func_name, + makeDataType(7, 5), makeDataType(15, 3), + {DecimalField32(3223456, 5)}, {DecimalField64(9244, 3)}, {DecimalField64(450256, 5)}, 15, 5); + + // real modulo + + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), + {1.3, -1.3, 1.3, -1.3, 3.3, -3.3, 3.3, -3.3, 12.34, 0.0, 0.0, 0.0, {}, {}}, + {1.1, 1.1, -1.1, -1.1, 1.1, 1.1, -1.1, -1.1, 0.0, 12.34, 0.0, {}, 0.0, {}}, + { + 0.19999999999999996, -0.19999999999999996, 0.19999999999999996, -0.19999999999999996, + 1.0999999999999996, -1.0999999999999996, 1.0999999999999996, -1.0999999999999996, + {}, 0.0, {}, {}, {}, {} + }); + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), + {1.55, 1.55, {}, 0.0, {}}, {-1, 0, 0, {}, {}}, {0.55, {}, {}, {}, {}}); + executeFunctionWithData(__LINE__, func_name, + makeDataType(7, 3), makeDataType(), + {DecimalField32(1250, 3), DecimalField32(1250, 3), {}, DecimalField32(0, 3), {}}, + {1.0, 0.0, 0.0, {}, {}}, {0.25, {}, {}, {}, {}}); + + // const-vector modulo + + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), + {3}, {0, 1, 2, 3, 4, 5, 6}, {{}, 0, 1, 0, 3, 3, 3}); + + // vector-const modulo + + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), + {0, 1, 2, 3}, {0}, {{}, {}, {}, {}}); + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), + {0, 1, 2, 3, 4, 5, 6}, {3}, {0, 1, 2, 0, 1, 2, 0}); + + // const-const modulo + + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), {5}, {-3}, {2}); + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), {0}, {0}, {{}}); + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), {{}}, {0}, {{}}); + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), {0}, {{}}, {{}}); + executeFunctionWithData(__LINE__, func_name, + makeDataType(), makeDataType(), {{}}, {{}}, {{}}); +} +CATCH + +TEST_F(TestBinaryArithmeticFunctions, ModuloExtra) +try +{ + std::unordered_map data_type_map = + { + {"Int64", makeDataType()}, + {"UInt64", makeDataType()}, + {"Float64", makeDataType()}, + {"DecimalField32", makeDataType(9, 3)}, + {"DecimalField64", makeDataType(18, 6)}, + {"DecimalField128", makeDataType(38, 10)}, + {"DecimalField256", makeDataType(65, 20)}, + }; + +#define MODULO_TESTCASE(Left, Right, Result, precision, left_scale, right_scale, result_scale) \ + executeFunctionWithData(__LINE__, "modulo", \ + data_type_map[#Left], data_type_map[#Right], \ + {GetValue::Max(), GetValue::Zero(),GetValue::One(), GetValue::Zero(), {}, {}, {}}, \ + {GetValue::Zero(), GetValue::Max(), GetValue::One(), {}, GetValue::Zero(), GetValue::Max(), {}}, \ + {{}, GetValue::Zero(), GetValue::Zero(), {}, {}, {}, {}}, \ + (precision), (result_scale)); + + MODULO_TESTCASE(Int64, Int64, Int64, 0, 0, 0, 0); + MODULO_TESTCASE(Int64, UInt64, Int64, 0, 0, 0, 0); + MODULO_TESTCASE(Int64, Float64, Float64, 0, 0, 0, 0); + MODULO_TESTCASE(Int64, DecimalField32, DecimalField128, 20, 0, 3, 3); + MODULO_TESTCASE(Int64, DecimalField64, DecimalField128, 20, 0, 6, 6); + MODULO_TESTCASE(Int64, DecimalField128, DecimalField128, 38, 0, 10, 10); + MODULO_TESTCASE(Int64, DecimalField256, DecimalField256, 65, 0, 20, 20); + + MODULO_TESTCASE(UInt64, Int64, UInt64, 0, 0, 0, 0); + MODULO_TESTCASE(UInt64, UInt64, UInt64, 0, 0, 0, 0); + MODULO_TESTCASE(UInt64, Float64, Float64, 0, 0, 0, 0); + MODULO_TESTCASE(UInt64, DecimalField32, DecimalField128, 20, 0, 3, 3); + MODULO_TESTCASE(UInt64, DecimalField64, DecimalField128, 20, 0, 6, 6); + MODULO_TESTCASE(UInt64, DecimalField128, DecimalField128, 38, 0, 10, 10); + MODULO_TESTCASE(UInt64, DecimalField256, DecimalField256, 65, 0, 20, 20); + + MODULO_TESTCASE(Float64, Int64, Float64, 0, 0, 0, 0); + MODULO_TESTCASE(Float64, UInt64, Float64, 0, 0, 0, 0); + MODULO_TESTCASE(Float64, Float64, Float64, 0, 0, 0, 0); + MODULO_TESTCASE(Float64, DecimalField32, Float64, 0, 0, 3, 0); + MODULO_TESTCASE(Float64, DecimalField64, Float64, 0, 0, 6, 0); + MODULO_TESTCASE(Float64, DecimalField128, Float64, 0, 0, 10, 0); + MODULO_TESTCASE(Float64, DecimalField256, Float64, 0, 0, 20, 0); + + MODULO_TESTCASE(DecimalField32, Int64, DecimalField128, 20, 3, 0, 3); + MODULO_TESTCASE(DecimalField32, UInt64, DecimalField128, 20, 3, 0, 3); + MODULO_TESTCASE(DecimalField32, Float64, Float64, 0, 3, 0, 0); + MODULO_TESTCASE(DecimalField32, DecimalField32, DecimalField32, 9, 3, 3, 3); + MODULO_TESTCASE(DecimalField32, DecimalField64, DecimalField64, 18, 3, 6, 6); + MODULO_TESTCASE(DecimalField32, DecimalField128, DecimalField128, 38, 3, 10, 10); + MODULO_TESTCASE(DecimalField32, DecimalField256, DecimalField256, 65, 3, 20, 20); + + MODULO_TESTCASE(DecimalField64, Int64, DecimalField128, 20, 6, 0, 6); + MODULO_TESTCASE(DecimalField64, UInt64, DecimalField128, 20, 6, 0, 6); + MODULO_TESTCASE(DecimalField64, Float64, Float64, 0, 6, 0, 0); + MODULO_TESTCASE(DecimalField64, DecimalField32, DecimalField64, 18, 6, 3, 6); + MODULO_TESTCASE(DecimalField64, DecimalField64, DecimalField64, 18, 6, 6, 6); + MODULO_TESTCASE(DecimalField64, DecimalField128, DecimalField128, 38, 6, 10, 10); + MODULO_TESTCASE(DecimalField64, DecimalField256, DecimalField256, 65, 6, 20, 20); + + MODULO_TESTCASE(DecimalField128, Int64, DecimalField128, 38, 10, 0, 10); + MODULO_TESTCASE(DecimalField128, UInt64, DecimalField128, 38, 10, 0, 10); + MODULO_TESTCASE(DecimalField128, Float64, Float64, 0, 10, 0, 0); + MODULO_TESTCASE(DecimalField128, DecimalField32, DecimalField128, 38, 10, 3, 10); + MODULO_TESTCASE(DecimalField128, DecimalField64, DecimalField128, 38, 10, 6, 10); + MODULO_TESTCASE(DecimalField128, DecimalField128, DecimalField128, 38, 10, 10, 10); + MODULO_TESTCASE(DecimalField128, DecimalField256, DecimalField256, 65, 10, 20, 20); + + MODULO_TESTCASE(DecimalField256, Int64, DecimalField256, 65, 20, 0, 20); + MODULO_TESTCASE(DecimalField256, UInt64, DecimalField256, 65, 20, 0, 20); + MODULO_TESTCASE(DecimalField256, Float64, Float64, 0, 20, 0, 0); + MODULO_TESTCASE(DecimalField256, DecimalField32, DecimalField256, 65, 20, 3, 20); + MODULO_TESTCASE(DecimalField256, DecimalField64, DecimalField256, 65, 20, 6, 20); + MODULO_TESTCASE(DecimalField256, DecimalField128, DecimalField256, 65, 20, 10, 20); + MODULO_TESTCASE(DecimalField256, DecimalField256, DecimalField256, 65, 20, 20, 20); + +#undef MODULO_TESTCASE +} +CATCH + + +>>>>>>> c407b167c... Fix function `mod` decimal scale overflow (#2462) } // namespace tests } // namespace DB