Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix function mod decimal scale overflow (#2462) #2608

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion dbms/src/Functions/FunctionsArithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,42 @@ struct ModuloImpl<A,B,false>
{
using ResultType = typename NumberTraits::ResultOfModulo<A, B>::Type;

<<<<<<< HEAD
=======
template <typename To, typename From>
static make_unsigned_t<To> to_unsigned(const From & value)
{
using ReturnType = make_unsigned_t<To>;

if constexpr (is_signed_v<From>)
{
if constexpr (is_boost_number_v<ReturnType>)
{
// assert that negation of std::numeric_limits<From>::min() will not result in overflow.
// TODO: find credible source that describes numeric limits of boost multiprecision *checked* integers.
static_assert(-std::numeric_limits<From>::max() == std::numeric_limits<From>::min());
return static_cast<ReturnType>(boost::multiprecision::abs(value));
}
else
{
if (value < 0)
{
// both signed to unsigned conversion [1] and negation of unsigned integers [2] are well defined in C++.
//
// see:
// [1]: https://en.cppreference.com/w/c/language/conversion#Integer_conversions
// [2]: https://en.cppreference.com/w/cpp/language/operator_arithmetic#Unary_arithmetic_operators
return -static_cast<ReturnType>(value);
}
else
return static_cast<ReturnType>(value);
}
}
else
return static_cast<ReturnType>(value);
}

>>>>>>> c407b167c... Fix function `mod` decimal scale overflow (#2462)
template <typename Result = ResultType>
static inline Result apply(A a, B b)
{
Expand Down Expand Up @@ -1286,7 +1322,7 @@ struct DecimalBinaryOperation
static constexpr bool is_plus_minus_compare = is_plus_minus || is_compare;
static constexpr bool can_overflow = is_plus_minus || is_multiply;

static constexpr bool need_promote_type = (std::is_same_v<ResultType_, A> || std::is_same_v<ResultType_, B>) && (is_plus_minus_compare || is_division || is_multiply) ; // And is multiple / division
static constexpr bool need_promote_type = (std::is_same_v<ResultType_, A> || std::is_same_v<ResultType_, B>) && (is_plus_minus_compare || is_division || is_multiply || is_modulo) ; // And is multiple / division / modulo
static constexpr bool check_overflow = need_promote_type && std::is_same_v<ResultType_, Decimal256>; // Check if exceeds 10 * 66;

using ResultType = ResultType_;
Expand Down
237 changes: 237 additions & 0 deletions dbms/src/Functions/tests/gtest_arithmetic_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,5 +698,242 @@ try
}
CATCH

<<<<<<< HEAD
=======
TEST_F(TestBinaryArithmeticFunctions, Modulo)
try
{
const String func_name = "modulo";

using uint64_limits = std::numeric_limits<UInt64>;
using int64_limits = std::numeric_limits<Int64>;

// "{}" is similar to std::nullopt.

// integer modulo

executeFunctionWithData<UInt64, UInt64, UInt64>(__LINE__, func_name,
makeDataType<DataTypeUInt64>(), makeDataType<DataTypeUInt64>(),
{5, 3, uint64_limits::max(), 1, 0, 0, {}, 0, {}},
{3, 5, uint64_limits::max() - 1, 0, 1, 0, 0, {}, {}},
{2, 3, 1, {}, 0, {}, {}, {}, {}});
executeFunctionWithData<UInt64, Int64, UInt64>(__LINE__, func_name,
makeDataType<DataTypeUInt64>(), makeDataType<DataTypeInt64>(),
{5, 5, uint64_limits::max(), uint64_limits::max(), uint64_limits::max(), 1, 0, 0, {}, 0, {}},
{3, -3, int64_limits::max(), int64_limits::max() - 1, int64_limits::min(), 0, 1, 0, 0, {}, {}},
{2, 2, 1, 3, int64_limits::max(), {}, 0, {}, {}, {}, {}});
executeFunctionWithData<Int64, UInt64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeUInt64>(),
{5, -5, int64_limits::max(), int64_limits::min(), 1, 0, 0, {}, 0, {}},
{3, 3, 998244353, 998244353, 0, 1, 0, 0, {}, {}},
{2, -2, 466025954, -466025955, {}, 0, {}, {}, {}, {}});
executeFunctionWithData<Int64, Int64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeInt64>(),
{5, -5, 5, -5, int64_limits::max(), int64_limits::min(), 1, 0, 0, {}, 0, {}},
{3, 3, -3, -3, int64_limits::min(), int64_limits::max(), 0, 1, 0, 0, {}, {}},
{2, -2, 2, -2, int64_limits::max(), -1, {}, 0, {}, {}, {}, {}});

// decimal modulo

executeFunctionWithData<DecimalField32, DecimalField32, DecimalField32>(__LINE__, func_name,
makeDataType<Decimal32>(7, 3), makeDataType<Decimal32>(7, 3),
{
DecimalField32(3300, 3), DecimalField32(-3300, 3), DecimalField32(3300, 3),
DecimalField32(-3300, 3), DecimalField32(1000, 3), {}, DecimalField32(0,3), {}
},
{
DecimalField32(1300, 3), DecimalField32(1300, 3), DecimalField32(-1300, 3),
DecimalField32(-1300, 3), DecimalField32(0, 3), DecimalField32(0, 3), {}, {}
},
{
DecimalField32(700, 3), DecimalField32(-700, 3), DecimalField32(700, 3),
DecimalField32(-700, 3), {}, {}, {}, {}
}, 7, 3);

// decimal overflow test.

// for example, 999'999'999 % 1.0000'0000 can check whether Decimal32 is doing arithmetic on Int64.
// scaling 999'999'999 (Decimal(9, 0)) to Decimal(9, 8) needs to multiply it with 1'0000'0000.
// if it uses Int32, it will overflow and get wrong result (something like 0.69325056).

#define MODULO_OVERFLOW_TESTCASE(Decimal, DecimalField, precision) \
{ \
auto & builder = DecimalMaxValue::instance(); \
auto max_scale = std::min(decimal_max_scale, static_cast<ScaleType>(precision) - 1); \
auto exp10_x = static_cast<Decimal::NativeType>(builder.Get(max_scale)) + 1; /* exp10_x: 10^x */ \
auto decimal_max = exp10_x * 10 - 1; \
auto zero = static_cast<Decimal::NativeType>(0); /* for Int256 */ \
executeFunctionWithData<DecimalField, DecimalField, DecimalField>(__LINE__, func_name, \
makeDataType<Decimal>((precision), 0), makeDataType<Decimal>((precision), max_scale), \
{DecimalField(decimal_max, 0)}, {DecimalField(exp10_x, max_scale)}, {DecimalField(zero, max_scale)}, (precision), max_scale); \
executeFunctionWithData<DecimalField, DecimalField, DecimalField>(__LINE__, func_name, \
makeDataType<Decimal>((precision), max_scale), makeDataType<Decimal>((precision), 0), \
{DecimalField(exp10_x, max_scale)}, {DecimalField(decimal_max, 0)}, {DecimalField(exp10_x, max_scale)}, (precision), max_scale); \
}

MODULO_OVERFLOW_TESTCASE(Decimal32, DecimalField32, 9);
MODULO_OVERFLOW_TESTCASE(Decimal64, DecimalField64, 18);
MODULO_OVERFLOW_TESTCASE(Decimal128, DecimalField128, 38);
MODULO_OVERFLOW_TESTCASE(Decimal256, DecimalField256, 65);

#undef MODULO_OVERFLOW_TESTCASE

Int128 large_number_1 = static_cast<Int128>(std::numeric_limits<UInt64>::max()) * 100000;
executeFunctionWithData<DecimalField128, DecimalField128, DecimalField128>(__LINE__, func_name,
makeDataType<Decimal128>(38, 5), makeDataType<Decimal128>(38, 5),
{DecimalField128(large_number_1, 5), DecimalField128(large_number_1, 5), DecimalField128(large_number_1, 5)},
{DecimalField128(100000, 5), DecimalField128(large_number_1 - 1, 5), DecimalField128(large_number_1 / 2 + 1, 5)},
{DecimalField128(large_number_1 % 100000, 5), DecimalField128(1, 5), DecimalField128(large_number_1 / 2 - 1, 5)}, 38, 5);

Int256 large_number_2 = static_cast<Int256>(large_number_1) * large_number_1;
executeFunctionWithData<DecimalField256, DecimalField256, DecimalField256>(__LINE__, func_name,
makeDataType<Decimal256>(65, 5), makeDataType<Decimal256>(65, 5),
{DecimalField256(large_number_2, 5), DecimalField256(large_number_2, 5), DecimalField256(large_number_2, 5)},
{DecimalField256(static_cast<Int256>(100000), 5), DecimalField256(large_number_2 - 1, 5), DecimalField256(large_number_2 / 2 + 1, 5)},
{DecimalField256(large_number_2 % 100000, 5), DecimalField256(static_cast<Int256>(1), 5), DecimalField256(large_number_2 / 2 - 1, 5)}, 65, 5);

// Int64 has a precision of 20, which is larger than Decimal64.
executeFunctionWithData<DecimalField32, Int64, DecimalField128>(__LINE__, func_name,
makeDataType<Decimal32>(7, 3), makeDataType<DataTypeInt64>(),
{DecimalField32(3300, 3), DecimalField32(3300, 3), {}}, {1, 0, {}},
{DecimalField128(300, 3), {}, {}}, 20, 3);

executeFunctionWithData<DecimalField32, DecimalField64, DecimalField64>(__LINE__, func_name,
makeDataType<Decimal32>(7, 5), makeDataType<Decimal64>(15, 3),
{DecimalField32(3223456, 5)}, {DecimalField64(9244, 3)}, {DecimalField64(450256, 5)}, 15, 5);

// real modulo

executeFunctionWithData<Float64, Float64, Float64>(__LINE__, func_name,
makeDataType<DataTypeFloat64>(), makeDataType<DataTypeFloat64>(),
{1.3, -1.3, 1.3, -1.3, 3.3, -3.3, 3.3, -3.3, 12.34, 0.0, 0.0, 0.0, {}, {}},
{1.1, 1.1, -1.1, -1.1, 1.1, 1.1, -1.1, -1.1, 0.0, 12.34, 0.0, {}, 0.0, {}},
{
0.19999999999999996, -0.19999999999999996, 0.19999999999999996, -0.19999999999999996,
1.0999999999999996, -1.0999999999999996, 1.0999999999999996, -1.0999999999999996,
{}, 0.0, {}, {}, {}, {}
});
executeFunctionWithData<Float64, Int64, Float64>(__LINE__, func_name,
makeDataType<DataTypeFloat64>(), makeDataType<DataTypeInt64>(),
{1.55, 1.55, {}, 0.0, {}}, {-1, 0, 0, {}, {}}, {0.55, {}, {}, {}, {}});
executeFunctionWithData<DecimalField32, Float64, Float64>(__LINE__, func_name,
makeDataType<Decimal32>(7, 3), makeDataType<DataTypeFloat64>(),
{DecimalField32(1250, 3), DecimalField32(1250, 3), {}, DecimalField32(0, 3), {}},
{1.0, 0.0, 0.0, {}, {}}, {0.25, {}, {}, {}, {}});

// const-vector modulo

executeFunctionWithData<Int64, Int64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeInt64>(),
{3}, {0, 1, 2, 3, 4, 5, 6}, {{}, 0, 1, 0, 3, 3, 3});

// vector-const modulo

executeFunctionWithData<Int64, Int64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeInt64>(),
{0, 1, 2, 3}, {0}, {{}, {}, {}, {}});
executeFunctionWithData<Int64, Int64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeInt64>(),
{0, 1, 2, 3, 4, 5, 6}, {3}, {0, 1, 2, 0, 1, 2, 0});

// const-const modulo

executeFunctionWithData<Int64, Int64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeInt64>(), {5}, {-3}, {2});
executeFunctionWithData<Int64, Int64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeInt64>(), {0}, {0}, {{}});
executeFunctionWithData<Int64, Int64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeInt64>(), {{}}, {0}, {{}});
executeFunctionWithData<Int64, Int64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeInt64>(), {0}, {{}}, {{}});
executeFunctionWithData<Int64, Int64, Int64>(__LINE__, func_name,
makeDataType<DataTypeInt64>(), makeDataType<DataTypeInt64>(), {{}}, {{}}, {{}});
}
CATCH

TEST_F(TestBinaryArithmeticFunctions, ModuloExtra)
try
{
std::unordered_map<String, DataTypePtr> data_type_map =
{
{"Int64", makeDataType<DataTypeInt64>()},
{"UInt64", makeDataType<DataTypeUInt64>()},
{"Float64", makeDataType<DataTypeFloat64>()},
{"DecimalField32", makeDataType<DataTypeDecimal32>(9, 3)},
{"DecimalField64", makeDataType<DataTypeDecimal64>(18, 6)},
{"DecimalField128", makeDataType<DataTypeDecimal128>(38, 10)},
{"DecimalField256", makeDataType<DataTypeDecimal256>(65, 20)},
};

#define MODULO_TESTCASE(Left, Right, Result, precision, left_scale, right_scale, result_scale) \
executeFunctionWithData<Left, Right, Result>(__LINE__, "modulo", \
data_type_map[#Left], data_type_map[#Right], \
{GetValue<Left, left_scale>::Max(), GetValue<Left, left_scale>::Zero(),GetValue<Left, left_scale>::One(), GetValue<Left, left_scale>::Zero(), {}, {}, {}}, \
{GetValue<Right, right_scale>::Zero(), GetValue<Right, right_scale>::Max(), GetValue<Right, right_scale>::One(), {}, GetValue<Right, right_scale>::Zero(), GetValue<Right, right_scale>::Max(), {}}, \
{{}, GetValue<Result, result_scale>::Zero(), GetValue<Result, result_scale>::Zero(), {}, {}, {}, {}}, \
(precision), (result_scale));

MODULO_TESTCASE(Int64, Int64, Int64, 0, 0, 0, 0);
MODULO_TESTCASE(Int64, UInt64, Int64, 0, 0, 0, 0);
MODULO_TESTCASE(Int64, Float64, Float64, 0, 0, 0, 0);
MODULO_TESTCASE(Int64, DecimalField32, DecimalField128, 20, 0, 3, 3);
MODULO_TESTCASE(Int64, DecimalField64, DecimalField128, 20, 0, 6, 6);
MODULO_TESTCASE(Int64, DecimalField128, DecimalField128, 38, 0, 10, 10);
MODULO_TESTCASE(Int64, DecimalField256, DecimalField256, 65, 0, 20, 20);

MODULO_TESTCASE(UInt64, Int64, UInt64, 0, 0, 0, 0);
MODULO_TESTCASE(UInt64, UInt64, UInt64, 0, 0, 0, 0);
MODULO_TESTCASE(UInt64, Float64, Float64, 0, 0, 0, 0);
MODULO_TESTCASE(UInt64, DecimalField32, DecimalField128, 20, 0, 3, 3);
MODULO_TESTCASE(UInt64, DecimalField64, DecimalField128, 20, 0, 6, 6);
MODULO_TESTCASE(UInt64, DecimalField128, DecimalField128, 38, 0, 10, 10);
MODULO_TESTCASE(UInt64, DecimalField256, DecimalField256, 65, 0, 20, 20);

MODULO_TESTCASE(Float64, Int64, Float64, 0, 0, 0, 0);
MODULO_TESTCASE(Float64, UInt64, Float64, 0, 0, 0, 0);
MODULO_TESTCASE(Float64, Float64, Float64, 0, 0, 0, 0);
MODULO_TESTCASE(Float64, DecimalField32, Float64, 0, 0, 3, 0);
MODULO_TESTCASE(Float64, DecimalField64, Float64, 0, 0, 6, 0);
MODULO_TESTCASE(Float64, DecimalField128, Float64, 0, 0, 10, 0);
MODULO_TESTCASE(Float64, DecimalField256, Float64, 0, 0, 20, 0);

MODULO_TESTCASE(DecimalField32, Int64, DecimalField128, 20, 3, 0, 3);
MODULO_TESTCASE(DecimalField32, UInt64, DecimalField128, 20, 3, 0, 3);
MODULO_TESTCASE(DecimalField32, Float64, Float64, 0, 3, 0, 0);
MODULO_TESTCASE(DecimalField32, DecimalField32, DecimalField32, 9, 3, 3, 3);
MODULO_TESTCASE(DecimalField32, DecimalField64, DecimalField64, 18, 3, 6, 6);
MODULO_TESTCASE(DecimalField32, DecimalField128, DecimalField128, 38, 3, 10, 10);
MODULO_TESTCASE(DecimalField32, DecimalField256, DecimalField256, 65, 3, 20, 20);

MODULO_TESTCASE(DecimalField64, Int64, DecimalField128, 20, 6, 0, 6);
MODULO_TESTCASE(DecimalField64, UInt64, DecimalField128, 20, 6, 0, 6);
MODULO_TESTCASE(DecimalField64, Float64, Float64, 0, 6, 0, 0);
MODULO_TESTCASE(DecimalField64, DecimalField32, DecimalField64, 18, 6, 3, 6);
MODULO_TESTCASE(DecimalField64, DecimalField64, DecimalField64, 18, 6, 6, 6);
MODULO_TESTCASE(DecimalField64, DecimalField128, DecimalField128, 38, 6, 10, 10);
MODULO_TESTCASE(DecimalField64, DecimalField256, DecimalField256, 65, 6, 20, 20);

MODULO_TESTCASE(DecimalField128, Int64, DecimalField128, 38, 10, 0, 10);
MODULO_TESTCASE(DecimalField128, UInt64, DecimalField128, 38, 10, 0, 10);
MODULO_TESTCASE(DecimalField128, Float64, Float64, 0, 10, 0, 0);
MODULO_TESTCASE(DecimalField128, DecimalField32, DecimalField128, 38, 10, 3, 10);
MODULO_TESTCASE(DecimalField128, DecimalField64, DecimalField128, 38, 10, 6, 10);
MODULO_TESTCASE(DecimalField128, DecimalField128, DecimalField128, 38, 10, 10, 10);
MODULO_TESTCASE(DecimalField128, DecimalField256, DecimalField256, 65, 10, 20, 20);

MODULO_TESTCASE(DecimalField256, Int64, DecimalField256, 65, 20, 0, 20);
MODULO_TESTCASE(DecimalField256, UInt64, DecimalField256, 65, 20, 0, 20);
MODULO_TESTCASE(DecimalField256, Float64, Float64, 0, 20, 0, 0);
MODULO_TESTCASE(DecimalField256, DecimalField32, DecimalField256, 65, 20, 3, 20);
MODULO_TESTCASE(DecimalField256, DecimalField64, DecimalField256, 65, 20, 6, 20);
MODULO_TESTCASE(DecimalField256, DecimalField128, DecimalField256, 65, 20, 10, 20);
MODULO_TESTCASE(DecimalField256, DecimalField256, DecimalField256, 65, 20, 20, 20);

#undef MODULO_TESTCASE
}
CATCH


>>>>>>> c407b167c... Fix function `mod` decimal scale overflow (#2462)
} // namespace tests
} // namespace DB