diff --git a/dbms/src/Common/toSafeUnsigned.h b/dbms/src/Common/toSafeUnsigned.h new file mode 100644 index 00000000000..8dd4517daed --- /dev/null +++ b/dbms/src/Common/toSafeUnsigned.h @@ -0,0 +1,48 @@ +#pragma once + +#include + +namespace DB +{ + +// toSafeUnsigned evaluates absolute value of argument `value` and cast it to unsigned type of `To`. +// it guarantees that no undefined behavior will occur and exact result can be represented by unsigned `To`. +template +constexpr make_unsigned_t toSafeUnsigned(const From & value) +{ + static_assert(is_integer_v, "type From must be integral"); + static_assert(is_integer_v, "type To must be integral"); + + using ReturnType = make_unsigned_t; + + static_assert(actual_size_v >= actual_size_v, "type unsigned To can't hold all values of type From"); + + if constexpr (is_signed_v) + { + if constexpr (is_boost_number_v) + { + // assert that negation of std::numeric_limits::min() will not result in overflow. + // TODO: find credible source that describes numeric limits of boost multiprecision *checked* integers. + static_assert(-std::numeric_limits::max() == std::numeric_limits::min()); + return static_cast(boost::multiprecision::abs(value)); + } + else + { + if (value < 0) + { + // both signed to unsigned conversion [1] and negation of unsigned integers [2] are well defined in C++. 
+ // + // see: + // [1]: https://en.cppreference.com/w/c/language/conversion#Integer_conversions + // [2]: https://en.cppreference.com/w/cpp/language/operator_arithmetic#Unary_arithmetic_operators + return -static_cast(value); + } + else + return static_cast(value); + } + } + else + return static_cast(value); +} + +} // namespace DB diff --git a/dbms/src/DataTypes/DataTypesNumber.h b/dbms/src/DataTypes/DataTypesNumber.h index 7fbbe1e96dc..cfa32bea159 100644 --- a/dbms/src/DataTypes/DataTypesNumber.h +++ b/dbms/src/DataTypes/DataTypesNumber.h @@ -19,6 +19,7 @@ class DataTypeNumber final : public DataTypeNumberBase bool canBeUsedInBooleanContext() const override { return true; } bool isNumber() const override { return true; } bool isInteger() const override { return std::is_integral_v; } + bool isFloatingPoint() const override { return std::is_floating_point_v; } bool canBeInsideNullable() const override { return true; } public: diff --git a/dbms/src/DataTypes/IDataType.h b/dbms/src/DataTypes/IDataType.h index 3caf65a07cf..e1deb90cf7e 100644 --- a/dbms/src/DataTypes/IDataType.h +++ b/dbms/src/DataTypes/IDataType.h @@ -388,6 +388,10 @@ class IDataType : private boost::noncopyable virtual bool isInteger() const { return false; }; virtual bool isUnsignedInteger() const { return false; }; + /** Floating point values. Not Nullable. Not Enums. Not Date/DateTime. + */ + virtual bool isFloatingPoint() const { return false; } + /** Date, DateTime, MyDate, MyDateTime. Not Nullable. 
*/ virtual bool isDateOrDateTime() const { return false; }; @@ -475,4 +479,3 @@ class IDataType : private boost::noncopyable } - diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 0a1d344c919..ddf4a50f4bb 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -336,6 +336,27 @@ static String buildBitwiseFunction(DAGExpressionAnalyzer * analyzer, const tipb: return analyzer->applyFunction(func_name, argument_names, actions, nullptr); } +static String buildRoundFunction(DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, ExpressionActionsPtr & actions) +{ + // ROUND(x) -> ROUND(x, 0) + + if (expr.children_size() != 1) + throw TiFlashException("Invalid arguments of ROUND function", Errors::Coprocessor::BadRequest); + + + auto input_arg_name = analyzer->getActions(expr.children(0), actions); + + auto const_zero = tipb::Expr(); + constructInt64LiteralTiExpr(const_zero, 0); + auto const_zero_arg_name = analyzer->getActions(const_zero, actions); + + Names argument_names; + argument_names.push_back(std::move(input_arg_name)); + argument_names.push_back(std::move(const_zero_arg_name)); + + return analyzer->applyFunction("tidbRoundWithFrac", argument_names, actions, getCollatorFromExpr(expr)); +} + static String buildFunction(DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, ExpressionActionsPtr & actions) { const String & func_name = getFunctionName(expr); @@ -354,7 +375,8 @@ static std::unordered_map}, {"date_sub", buildDateAddOrSubFunction}}); + {"date_add", buildDateAddOrSubFunction}, {"date_sub", buildDateAddOrSubFunction}, + {"tidbRound", buildRoundFunction}}); DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector && source_columns_, const Context & context_) : source_columns(std::move(source_columns_)), context(context_), after_agg(false), implicit_cast_count(0) diff --git 
a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 57e378bc2e9..165af60c1a7 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -640,10 +640,12 @@ std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::FloorIntToDec, "floor"}, {tipb::ScalarFuncSig::FloorIntToInt, "floor"}, {tipb::ScalarFuncSig::FloorDecToInt, "floorDecimalToInt"}, {tipb::ScalarFuncSig::FloorDecToDec, "floor"}, {tipb::ScalarFuncSig::FloorReal, "floor"}, - {tipb::ScalarFuncSig::RoundReal, "round"}, {tipb::ScalarFuncSig::RoundInt, "round"}, {tipb::ScalarFuncSig::RoundDec, "round"}, - //{tipb::ScalarFuncSig::RoundWithFracReal, "cast"}, - //{tipb::ScalarFuncSig::RoundWithFracInt, "cast"}, - //{tipb::ScalarFuncSig::RoundWithFracDec, "cast"}, + {tipb::ScalarFuncSig::RoundReal, "tidbRound"}, + {tipb::ScalarFuncSig::RoundInt, "tidbRound"}, + {tipb::ScalarFuncSig::RoundDec, "tidbRound"}, + // {tipb::ScalarFuncSig::RoundWithFracReal, "tidbRoundWithFrac"}, + // {tipb::ScalarFuncSig::RoundWithFracInt, "tidbRoundWithFrac"}, + // {tipb::ScalarFuncSig::RoundWithFracDec, "tidbRoundWithFrac"}, {tipb::ScalarFuncSig::Log1Arg, "log"}, {tipb::ScalarFuncSig::Log2Args, "log2args"}, diff --git a/dbms/src/Functions/FunctionsArithmetic.h b/dbms/src/Functions/FunctionsArithmetic.h index 446ae01db50..cd41d709a51 100644 --- a/dbms/src/Functions/FunctionsArithmetic.h +++ b/dbms/src/Functions/FunctionsArithmetic.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -525,39 +526,6 @@ struct ModuloImpl { using ResultType = typename NumberTraits::ResultOfModulo::Type; - template - static make_unsigned_t to_unsigned(const From & value) - { - using ReturnType = make_unsigned_t; - - if constexpr (is_signed_v) - { - if constexpr (is_boost_number_v) - { - // assert that negation of std::numeric_limits::min() will not result in overflow. 
- // TODO: find credible source that describes numeric limits of boost multiprecision *checked* integers. - static_assert(-std::numeric_limits::max() == std::numeric_limits::min()); - return static_cast(boost::multiprecision::abs(value)); - } - else - { - if (value < 0) - { - // both signed to unsigned conversion [1] and negation of unsigned integers [2] are well defined in C++. - // - // see: - // [1]: https://en.cppreference.com/w/c/language/conversion#Integer_conversions - // [2]: https://en.cppreference.com/w/cpp/language/operator_arithmetic#Unary_arithmetic_operators - return -static_cast(value); - } - else - return static_cast(value); - } - } - else - return static_cast(value); - } - template static inline Result apply(A a, B b) { @@ -581,8 +549,8 @@ struct ModuloImpl // convert to unsigned before computing. // we have to prevent wrong result like UInt64(5) = UInt64(5) % Int64(-3). // in MySQL, UInt64(5) % Int64(-3) evaluates to UInt64(2). - auto x = to_unsigned(a); - auto y = to_unsigned(b); + auto x = toSafeUnsigned(a); + auto y = toSafeUnsigned(b); auto result = static_cast(x % y); diff --git a/dbms/src/Functions/FunctionsRound.cpp b/dbms/src/Functions/FunctionsRound.cpp index ce3f5b50156..ba858b63483 100644 --- a/dbms/src/Functions/FunctionsRound.cpp +++ b/dbms/src/Functions/FunctionsRound.cpp @@ -23,6 +23,8 @@ void registerFunctionsRound(FunctionFactory & factory) /// Compatibility aliases. 
factory.registerFunction("ceiling", FunctionFactory::CaseInsensitive); factory.registerFunction("truncate", FunctionFactory::CaseInsensitive); + + factory.registerFunction(); } } diff --git a/dbms/src/Functions/FunctionsRound.h b/dbms/src/Functions/FunctionsRound.h index 5016ce323eb..aa51154aa73 100644 --- a/dbms/src/Functions/FunctionsRound.h +++ b/dbms/src/Functions/FunctionsRound.h @@ -1,19 +1,23 @@ #pragma once -#include +#include +#include +#include #include +#include #include -#include -#include -#include #include +#include +#include #include +#include #if __SSE4_1__ #include #endif +#include namespace DB { @@ -221,7 +225,7 @@ struct DecimalRoundingComputation { val = trunc(val); } - + if constexpr(scale_mode == ScaleMode::Positive) { @@ -633,12 +637,12 @@ struct Dispatcher typename ColumnDecimal::Container & vec_res = col_res->getData(); applyInternal(col, vec_res, col_res, block, scale_arg, result); } - else + else { auto col_res = ColumnVector::create(); typename ColumnVector::Container & vec_res = col_res->getData(); applyInternal(col, vec_res, col_res, block, scale_arg, result); } - + } private: template @@ -696,7 +700,7 @@ class FunctionRounding : public IFunction } return false; } - else + else { if (auto col = checkAndGetColumn>(block.getByPosition(arguments[0]).column.get())) { @@ -704,7 +708,7 @@ class FunctionRounding : public IFunction return true; } return false; - } + } } public: @@ -770,7 +774,7 @@ class FunctionRounding : public IFunction }; /** FunctionRounding can only cast typeA to typeA - * but TiDB may push down RoundDecimalToInt + * but TiDB may push down RoundDecimalToInt * (and this is the only round function that return type is different from arg type) * so we specialize RoundDecimalToInt and not use template */ @@ -792,7 +796,7 @@ class FunctionRoundingDecimalToInt : public IFunction Dispatcher::apply(block, col, arguments, result); return true; } - return false; + return false; } public: @@ -845,6 +849,379 @@ class 
FunctionRoundingDecimalToInt : public IFunction } }; +/** + * differences between prec/scale/frac: + * - prec/precision: number of decimal digits, including digits before and after decimal point. + * - scale: number of decimal digits after decimal point. + * - frac: the second argument of ROUND. + * - optional, and default to zero. + * - in MySQL, frac <= 30, which is decimal_max_scale. + * + * both prec and scale are non-negative, but frac can be negative. + */ +using FracType = Int64; + +template +struct TiDBIntegerRound +{ + static_assert(is_integer_v); + + static constexpr InputType eval(const InputType & input, FracType frac [[maybe_unused]]) + { + // TODO: RoundWithFrac. + assert(frac == 0); + + return input; + } +}; + +template +struct TiDBFloatingRound +{ + static_assert(std::is_floating_point_v); + + // EvalType is the type used in evaluations. + // in MySQL, floating round always returns Float64. + using EvalType = Float64; + + static constexpr EvalType eval(const InputType & input, FracType frac [[maybe_unused]]) + { + // TODO: RoundWithFrac. + assert(frac == 0); + + auto value = static_cast(input); + + // floating-point environment is thread-local, so `fesetround` is thread-safe. + std::fesetround(FE_TONEAREST); + return std::nearbyint(value); + } +}; + +// build constant table of up to Nth power of 10 at compile time. +template +struct ConstPowOf10 +{ + using ArrayType = std::array; + + static constexpr ArrayType build() + { + ArrayType result = {1}; + for (size_t i = 1; i <= N; ++i) + result[i] = result[i - 1] * static_cast(10); + return result; + } + + static constexpr ArrayType result = build(); +}; + +template +struct TiDBDecimalRound +{ + static_assert(IsDecimal); + + using UnsignedNativeType = make_unsigned_t; + using Pow = ConstPowOf10()>; + + static constexpr OutputType eval(const InputType & input, FracType frac [[maybe_unused]], ScaleType input_scale) + { + // TODO: RoundWithFrac. 
+ assert(frac == 0); + + auto divider = Pow::result[input_scale]; + auto absolute_value = toSafeUnsigned(input.value); + + // "round half away from zero" + // examples: + // - input.value = 149, input_scale = 2, divider = 100, result = (149 + 100 / 2) / 100 = 1 + // - input.value = 150, input_scale = 2, divider = 100, result = (150 + 100 / 2) / 100 = 2 + auto absolute_result = static_cast((absolute_value + divider / 2) / divider); + + if (input.value < 0) + return -static_cast(absolute_result); + else + return static_cast(absolute_result); + } +}; + +static FracType getFracFromConstColumn(const ColumnConst * column) +{ + FracType result; + auto frac_field = column->getField(); + + if (!frac_field.tryGet(result)) + { + // maybe field is unsigned. + static_assert(is_signed_v); + make_unsigned_t unsigned_frac; + + if (!frac_field.tryGet(unsigned_frac)) + { + throw Exception( + fmt::format("Illegal frac column with type {}, expected to const Int64/UInt64", column->getField().getTypeName()), + ErrorCodes::ILLEGAL_COLUMN); + } + + result = static_cast(unsigned_frac); + } + + // in MySQL, frac is clamped to 30, which is identical to decimal_max_scale. + // frac is signed but decimal_max_scale is unsigned, so we have to cast before + // comparison. + if (result > static_cast(decimal_max_scale)) + result = decimal_max_scale; + + return result; +} + +struct TiDBRoundPrecisionInferer +{ + static std::tuple infer( + PrecType prec, ScaleType scale, FracType frac [[maybe_unused]], bool is_const_frac [[maybe_unused]]) + { + // TODO: RoundWithFrac. + assert(is_const_frac); + assert(frac == 0); + + assert(prec >= scale); + PrecType new_prec = prec - scale; + + // +1 for possible overflow, e.g. 
round(99999.9) => 100000 + if (scale > 0) + new_prec += 1; + + return std::make_tuple(new_prec, 0); + } +}; + +template +struct TiDBRound +{ + static void apply(const ColumnPtr & input_column_, const ColumnPtr & frac_column_, MutableColumnPtr & output_column_, + ScaleType input_scale [[maybe_unused]], ScaleType output_scale [[maybe_unused]]) + { + auto input_column = checkAndGetColumn(input_column_.get()); + auto frac_column = checkAndGetColumn(frac_column_.get()); + + if (input_column == nullptr) + throw Exception(fmt::format("Illegal column {} for the first argument of function round", input_column_->getName()), + ErrorCodes::ILLEGAL_COLUMN); + if (frac_column == nullptr) + throw Exception(fmt::format("Illegal column {} for the second argument of function round", frac_column_->getName()), + ErrorCodes::ILLEGAL_COLUMN); + + // TODO: RoundWithFrac. + assert(frac_column->isColumnConst()); + auto frac_value = getFracFromConstColumn(frac_column); + assert(frac_value == 0); + + // TODO: const input column. + assert(!input_column->isColumnConst()); + + auto & input_data = input_column->getData(); + size_t size = input_data.size(); + + auto output_column = typeid_cast(output_column_.get()); + assert(output_column != nullptr); + + auto & output_data = output_column->getData(); + output_data.resize(size); + + for (size_t i = 0; i < size; ++i) + { + // TODO: RoundWithFrac. + if constexpr (std::is_floating_point_v) + output_data[i] = TiDBFloatingRound::eval(input_data[i], 0); + else if constexpr (IsDecimal) + output_data[i] = TiDBDecimalRound::eval(input_data[i], 0, input_scale); + else + output_data[i] = TiDBIntegerRound::eval(input_data[i], 0); + } + } +}; + +/** + * round(x, d) for TiDB. 
+ */ +class FunctionTiDBRoundWithFrac : public IFunction +{ +public: + static constexpr auto name = "tidbRoundWithFrac"; + + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + size_t getNumberOfArguments() const override { return 2; } + bool useDefaultImplementationForNulls() const override { return true; } + bool useDefaultImplementationForConstants() const override { return true; } + bool hasInformationAboutMonotonicity() const override { return true; } + + Monotonicity getMonotonicityForRange(const IDataType &, const Field &, const Field &) const override { return {true, true, true}; } + +private: + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + { + checkArguments(arguments); + + auto input_type = arguments[0].type; + + if (input_type->isInteger()) + return input_type; + else if (input_type->isFloatingPoint()) + { + // in MySQL, floating round always returns Float64. + return std::make_shared(); + } + else + { + assert(input_type->isDecimal()); + + auto prec = getDecimalPrecision(*input_type, std::numeric_limits::max()); + auto scale = getDecimalScale(*input_type, std::numeric_limits::max()); + assert(prec != std::numeric_limits::max()); + assert(scale != std::numeric_limits::max()); + + // if is_const_frac is false, the value of frac will be ignored. 
+ FracType frac = 0; + bool is_const_frac = true; + + auto frac_column = arguments[1].column.get(); + if (frac_column->isColumnConst()) + { + auto column = typeid_cast(frac_column); + assert(column != nullptr); + + is_const_frac = true; + frac = getFracFromConstColumn(column); + } + else + is_const_frac = false; + + auto [new_prec, new_scale] = TiDBRoundPrecisionInferer::infer(prec, scale, frac, is_const_frac); + return createDecimal(new_prec, new_scale); + } + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override + { + ColumnsWithTypeAndName columns; + for (const auto & position : arguments) + columns.push_back(block.getByPosition(position)); + + auto return_type = getReturnTypeImpl(columns); + auto result_column = return_type->createColumn(); + + auto input_type = columns[0].type; + auto input_column = columns[0].column; + auto frac_type = columns[1].type; + auto frac_column = columns[1].column; + + auto input_scale = getDecimalScale(*input_type, 0); + auto result_scale = getDecimalScale(*return_type, 0); + + checkInputTypeAndApply(input_type, return_type, frac_type, input_column, frac_column, result_column, input_scale, result_scale); + + block.getByPosition(result).column = std::move(result_column); + } + + void checkInputTypeAndApply(const DataTypePtr & input_type, const DataTypePtr & return_type, const DataTypePtr & frac_type, + const ColumnPtr & input_column, const ColumnPtr & frac_column, MutableColumnPtr & result_column, ScaleType input_scale, + ScaleType result_scale) + { + if (!castTypeToEither( + input_type.get(), [&](const auto & input_type, bool) { + using InputDataType = std::decay_t; + + checkReturnTypeAndApply( + return_type, frac_type, input_column, frac_column, result_column, input_scale, result_scale); + + return true; + })) + { + throw Exception(fmt::format("Illegal column type {} for the first argument of function {}", input_type->getName(), getName()), + ErrorCodes::ILLEGAL_COLUMN); + } + } + + 
template + void checkReturnTypeAndApply(const DataTypePtr & return_type, const DataTypePtr & frac_type, const ColumnPtr & input_column, + const ColumnPtr & frac_column, MutableColumnPtr & result_column, ScaleType input_scale, ScaleType result_scale) + { + if (!castTypeToEither( + return_type.get(), [&](const auto & return_type, bool) { + using ReturnDataType = std::decay_t; + + return checkFracTypeAndApply( + frac_type, input_column, frac_column, result_column, input_scale, result_scale); + })) + { + throw TiFlashException(fmt::format("Unexpected return type for function {}", getName()), Errors::Coprocessor::Internal); + } + } + + template + bool checkFracTypeAndApply(const DataTypePtr & frac_type, const ColumnPtr & input_column, const ColumnPtr & frac_column, + MutableColumnPtr & result_column, ScaleType input_scale [[maybe_unused]], ScaleType result_scale [[maybe_unused]]) + { + if constexpr ((std::is_floating_point_v && !std::is_same_v) + || (IsDecimal && !IsDecimal) || (is_integer_v && !std::is_same_v)) + return false; + else + { + if (!castTypeToEither(frac_type.get(), [&](const auto & frac_type, bool) { + using FracDataType = std::decay_t; + + checkColumnsAndApply( + input_column, frac_column, result_column, input_scale, result_scale); + + return true; + })) + { + throw Exception( + fmt::format("Illegal column type {} for the second argument of function {}", frac_type->getName(), getName()), + ErrorCodes::ILLEGAL_COLUMN); + } + + return true; + } + } + + template + void checkColumnsAndApply(const ColumnPtr & input_column, const ColumnPtr & frac_column, MutableColumnPtr & result_column, + ScaleType input_scale, ScaleType result_scale) + { + using InputColumn = std::conditional_t, ColumnDecimal, ColumnVector>; + using ResultColumn = std::conditional_t, ColumnDecimal, ColumnVector>; + + // TODO: RoundWithFrac + assert(!input_column->isColumnConst()); + assert(frac_column->isColumnConst()); + + TiDBRound::apply( + input_column, frac_column, result_column, 
input_scale, result_scale); + } + + void checkArguments(const ColumnsWithTypeAndName & arguments) const + { + if (arguments.size() != getNumberOfArguments()) + throw Exception(fmt::format("Number of arguments for function {} doesn't match: passed {}, should be {}", getName(), + arguments.size(), getNumberOfArguments()), + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + auto input_type = arguments[0].type; + if (!input_type->isNumber() && !input_type->isDecimal()) + throw Exception(fmt::format("Illegal type {} of first argument of function {}", input_type->getName(), getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + // the second argument frac must be integers. + auto frac_type = arguments[1].type; + if (!frac_type->isInteger()) + throw Exception(fmt::format("Illegal type {} of second argument of function {}", frac_type->getName(), getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } +}; + struct NameRoundToExp2 { static constexpr auto name = "roundToExp2"; }; struct NameRoundDuration { static constexpr auto name = "roundDuration"; }; struct NameRoundAge { static constexpr auto name = "roundAge"; }; diff --git a/dbms/src/Functions/tests/gtest_functions_round.cpp b/dbms/src/Functions/tests/gtest_functions_round.cpp new file mode 100644 index 00000000000..e49b738d227 --- /dev/null +++ b/dbms/src/Functions/tests/gtest_functions_round.cpp @@ -0,0 +1,376 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace tests +{ + +namespace +{ + +template +using Limits = std::numeric_limits; + +using DecimalField32 = DecimalField; +using DecimalField64 = DecimalField; +using DecimalField128 = DecimalField; +using DecimalField256 = DecimalField; + +template +struct ToDecimalType; + +template +struct ToDecimalType> +{ + using type = T; +}; + +// parse a string into decimal. You should the format of `literal` is valid. 
+template +DecimalField parseDecimal(const std::string & literal) +{ + using DecimalType = typename ToDecimalType::type; + using NativeType = typename DecimalType::NativeType; + + static_assert(is_signed_v); + + // the position of decimal point "." in `literal. + // `literal.size() - dot_index - 1` will be decimal scale. + assert(literal.size() > 0); + size_t dot_index = literal.size() - 1; + + bool negative = false; + NativeType value = 0; + for (size_t i = 0; i < literal.size(); ++i) + { + switch (literal[i]) + { + case '+': + // ignore plus sign. e.g. "+10000" = "10000". + [[fallthrough]]; + case '\'': + // use "'" as separator. e.g. 1'000'000'000. + continue; + case '-': + negative = true; + break; + case '.': + dot_index = i; + break; + + default: + assert(isdigit(literal[i])); + value = value * 10 + (literal[i] - '0'); + } + } + + if (negative) + value = -value; + + ScaleType scale = literal.size() - dot_index - 1; + return DecimalField(value, scale); +} + +// parse an array of strings into array of decimals. +template +auto parseDecimalArray(const std::vector> & literals) +{ + std::vector> result; + + result.reserve(literals.size()); + for (const auto & literal : literals) + { + if (literal.has_value()) + result.push_back(parseDecimal(literal.value())); + else + result.push_back(std::nullopt); + } + + return result; +} + +template +struct TestData +{ + using InputType = Input; + using OutputType = Output; + + PrecType input_prec; + ScaleType input_scale; + PrecType output_prec; + ScaleType output_scale; + + std::vector> input; + std::vector> output; +}; + +// `id` is used to distinguish between different test cases with same input and output types. 
+template +struct TestCase +{ + using InputType = Input; + using OutputType = Output; + static constexpr size_t id = id_; +}; + +using TestCases = ::testing::Types, TestCase, TestCase, + TestCase, TestCase, TestCase, + TestCase, TestCase, TestCase, + TestCase, TestCase>; + +template +auto getTestData(); + +template <> +auto getTestData() +{ + return TestData{0, 0, 0, 0, {0, 1, -1, Limits::max(), Limits::min(), std::nullopt}, + {0, 1, -1, Limits::max(), Limits::min(), std::nullopt}}; +}; + +template <> +auto getTestData() +{ + return TestData{0, 0, 0, 0, {0, 1, Limits::max(), std::nullopt}, {0, 1, Limits::max(), std::nullopt}}; +} + +template <> +auto getTestData() +{ + double large_value = std::pow(10.0, 100.0); + return TestData{0, 0, 0, 0, + {-5.5, -4.5, -3.5, -2.5, -1.5, -0.6, -0.5, -0.4, 0.0, 0.4, 0.5, 0.6, 1.5, 2.5, 3.5, 4.5, 5.5, large_value, std::nullopt}, + {-6.0, -4.0, -4.0, -2.0, -2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 4.0, 4.0, 6.0, large_value, std::nullopt}}; +} + +// Decimal(9, 0) -> Decimal(9, 0) +template <> +auto getTestData() +{ + return TestData{9, 0, 9, 0, + parseDecimalArray({"0", "1", "-1", "9'9999'9999", "-9'9999'9999", std::nullopt}), + parseDecimalArray({"0", "1", "-1", "9'9999'9999", "-9'9999'9999", std::nullopt})}; +} + +// Decimal(9, 1) -> Decimal(9, 0) +template <> +auto getTestData() +{ + return TestData{9, 1, 9, 0, + parseDecimalArray({"-2.5", "-1.5", "-0.6", "-0.5", "-0.4", "0.0", "0.4", "0.5", "0.6", "1.5", "2.5", "9999'9999.9", + "-9999'9999.9", std::nullopt}), + parseDecimalArray( + {"-3", "-2", "-1", "-1", "0", "0", "0", "1", "1", "2", "3", "1'0000'0000", "-1'0000'0000", std::nullopt})}; +} + +// Decimal(9, 9) -> Decimal(1, 0) +template <> +auto getTestData() +{ + return TestData{9, 9, 1, 0, + parseDecimalArray({"0.0000'0000'0", "0.0000'0000'1", "-0.0000'0000'1", "0.50000'0000", "-0.50000'0000", + "0.9999'9999'9", "-0.9999'9999'9", std::nullopt}), + parseDecimalArray({"0", "0", "0", "1", "-1", "1", "-1", 
std::nullopt})}; +} + +// Decimal(65, 0) -> Decimal(65, 0) +template <> +auto getTestData() +{ + std::string extreme(65, '9'); + + return TestData{65, 0, 65, 0, + parseDecimalArray({"0", "1", "-1", extreme, "-" + extreme, std::nullopt}), + parseDecimalArray({"0", "1", "-1", extreme, "-" + extreme, std::nullopt})}; +} + +// Decimal(38, 1) -> Decimal(38, 0) +template <> +auto getTestData() +{ + auto extreme = std::string(37, '9') + ".9"; + auto rounded_extreme = '1' + std::string(37, '0'); + + return TestData{38, 1, 38, 0, + parseDecimalArray( + {"-2.5", "-1.5", "-0.6", "-0.5", "-0.4", "0.0", "0.4", "0.5", "0.6", "1.5", "2.5", extreme, "-" + extreme, std::nullopt}), + parseDecimalArray( + {"-3", "-2", "-1", "-1", "0", "0", "0", "1", "1", "2", "3", rounded_extreme, "-" + rounded_extreme, std::nullopt})}; +} + +// Decimal(18, 10) -> Decimal(9, 0) +template <> +auto getTestData() +{ + auto zeros = '.' + std::string(10, '0'); + auto half = ".5" + std::string(9, '5'); + auto extreme = std::string(8, '9') + '.' + std::string(10, '9'); + auto rounded_extreme = '1' + std::string(8, '0'); + + return TestData{18, 10, 9, 0, + parseDecimalArray({"0" + zeros, "0" + half, "-0" + half, extreme, "-" + extreme, std::nullopt}), + parseDecimalArray({"0", "1", "-1", rounded_extreme, "-" + rounded_extreme, std::nullopt})}; +} + +// Decimal(25, 10) -> Decimal(16, 0) +template <> +auto getTestData() +{ + auto zeros = '.' + std::string(10, '0'); + auto half = ".5" + std::string(9, '5'); + auto extreme = std::string(15, '9') + '.' + std::string(10, '9'); + auto rounded_extreme = '1' + std::string(15, '0'); + + return TestData{25, 10, 16, 0, + parseDecimalArray({"0" + zeros, "0" + half, "-0" + half, extreme, "-" + extreme, std::nullopt}), + parseDecimalArray({"0", "1", "-1", rounded_extreme, "-" + rounded_extreme, std::nullopt})}; +} + +// Decimal(40, 30) -> Decimal(11, 0) +template <> +auto getTestData() +{ + auto zeros = '.' 
+ std::string(30, '0'); + auto half = ".5" + std::string(29, '5'); + auto extreme = std::string(10, '9') + '.' + std::string(30, '9'); + auto rounded_extreme = '1' + std::string(10, '0'); + + return TestData{40, 30, 11, 0, + parseDecimalArray({"0" + zeros, "0" + half, "-0" + half, extreme, "-" + extreme, std::nullopt}), + parseDecimalArray({"0", "1", "-1", rounded_extreme, "-" + rounded_extreme, std::nullopt})}; +} + +} // namespace + +template +class TestFunctionsRound : public ::testing::Test +{ +public: + static void SetUpTestCase() + { + try + { + registerFunctions(); + } + catch (Exception &) + { + // maybe other tests have already registered. + } + } +}; + +TYPED_TEST_CASE(TestFunctionsRound, TestCases); + +TYPED_TEST(TestFunctionsRound, TiDBRound) +try +{ + // prepare test data. + + using InputType = typename TypeParam::InputType; + using OutputType = typename TypeParam::OutputType; + + auto data = getTestData(); + size_t size = data.input.size(); + + // determine data type. + + DataTypePtr data_type; + if constexpr (isDecimalField()) + data_type = std::make_shared::type>>(data.input_prec, data.input_scale); + else + data_type = std::make_shared>(); + data_type = makeNullable(data_type); + + // construct argument columns: `input` and `frac`. + + auto input_column = data_type->createColumn(); + + for (const auto & value : data.input) + { + if (value.has_value()) + input_column->insert(Field(value.value())); + else + input_column->insert(Null()); + } + + auto input = ColumnWithTypeAndName{std::move(input_column), data_type, "input"}; + + auto frac_type = std::make_shared(); + auto frac_column = frac_type->createColumnConst(size, Field(static_cast(0))); + auto frac = ColumnWithTypeAndName{std::move(frac_column), frac_type, "frac"}; + + // build function. 
+ + const auto context = TiFlashTestEnv::getContext(); + auto & factory = FunctionFactory::instance(); + + auto builder = factory.tryGet("tidbRoundWithFrac", context); + ASSERT_NE(builder, nullptr); + + auto function = builder->build({input, frac}); + ASSERT_NE(function, nullptr); + + // prepare block. + + Block block; + block.insert(input); + block.insert(frac); + block.insert({nullptr, function->getReturnType(), "result"}); + + // execute function. + + function->execute(block, {block.getPositionByName("input"), block.getPositionByName("frac")}, block.getPositionByName("result")); + + // check result. + + auto result = block.getByName("result"); + ASSERT_NE(result.column, nullptr); + + if constexpr (isDecimalField()) + { + auto result_type = result.type; + if (auto actual = checkAndGetDataType(result_type.get())) + result_type = actual->getNestedType(); + + ASSERT_EQ(getDecimalPrecision(*result_type, 0), data.output_prec); + ASSERT_EQ(getDecimalScale(*result_type, 0), data.output_scale); + } + + ASSERT_EQ(block.rows(), size); + + Field result_field; + for (size_t i = 0; i < size; ++i) + { + result.column->get(i, result_field); + + if (data.output[i].has_value()) + { + ASSERT_FALSE(result_field.isNull()) << "index = " << i; + + auto got = result_field.safeGet(); + auto expected = data.output[i].value(); + + if constexpr (isDecimalField()) + { + ASSERT_EQ(got.getScale(), expected.getScale()) << "index = " << i; + ASSERT_EQ(got.getValue(), expected.getValue()) << "index = " << i; + } + else + ASSERT_EQ(got, expected) << "index = " << i; + } + else + ASSERT_TRUE(result_field.isNull()) << "index = " << i; + } +} +CATCH + +} // namespace tests + +} // namespace DB diff --git a/tests/fullstack-test/expr/round.test b/tests/fullstack-test/expr/round.test new file mode 100644 index 00000000000..cebbec7b85b --- /dev/null +++ b/tests/fullstack-test/expr/round.test @@ -0,0 +1,292 @@ +mysql> drop table if exists test.int8 +mysql> create table test.int8 (id int, a tinyint) 
+mysql> alter table test.int8 set tiflash replica 1 +mysql> insert into test.int8 values (1, 0), (2, 1), (3, -1), (4, 127), (5, -128), (6, null) +func> wait_table test int8 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.int8 group by id order by id ++------+---------------+ +| id | sum(round(a)) | ++------+---------------+ +| 1 | 0 | +| 2 | 1 | +| 3 | -1 | +| 4 | 127 | +| 5 | -128 | +| 6 | NULL | ++------+---------------+ + +mysql> drop table if exists test.uint8 +mysql> create table test.uint8 (id int, a tinyint unsigned) +mysql> alter table test.uint8 set tiflash replica 1 +mysql> insert into test.uint8 values (1, 0), (2, 1), (3, 255), (4, null) +func> wait_table test uint8 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.uint8 group by id order by id ++------+---------------+ +| id | sum(round(a)) | ++------+---------------+ +| 1 | 0 | +| 2 | 1 | +| 3 | 255 | +| 4 | NULL | ++------+---------------+ + +mysql> drop table if exists test.int64 +mysql> create table test.int64 (id int, a bigint) +mysql> alter table test.int64 set tiflash replica 1 +mysql> insert into test.int64 values (1, 0), (2, 1), (3, -1), (4, 9223372036854775807), (5, -9223372036854775808), (6, null) +func> wait_table test int64 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.int64 group by id order by id ++------+----------------------+ +| id | sum(round(a)) | ++------+----------------------+ +| 1 | 0 | +| 2 | 1 | +| 3 | -1 | +| 4 | 9223372036854775807 | +| 5 | -9223372036854775808 | +| 6 | NULL | ++------+----------------------+ + +mysql> drop table if exists test.uint64 +mysql> create table test.uint64 (id int, a bigint unsigned) +mysql> alter table test.uint64 set tiflash replica 1 +mysql> insert into test.uint64 values (1, 0), (2, 1), (3, 18446744073709551615), (4, null) +func> wait_table test uint64 +mysql> set 
@@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.uint64 group by id order by id ++------+----------------------+ +| id | sum(round(a)) | ++------+----------------------+ +| 1 | 0 | +| 2 | 1 | +| 3 | 18446744073709551615 | +| 4 | NULL | ++------+----------------------+ + +mysql> drop table if exists test.float32 +mysql> create table test.float32 (id int, a float) +mysql> alter table test.float32 set tiflash replica 1 +mysql> insert into test.float32 values (1, -5.5), (2, -4.5), (3, -3.5), (4, -2.5), (5, -1.5), (6, -0.6), (7, -0.5), (8, -0.4), (9, 0), (10, 0.4), (11, 0.5), (12, 0.6), (13, 1.5), (14, 2.5), (15, 3.5), (16, 4.5), (17, 5.5), (18, 1e30), (19, -1e30), (20, null) +func> wait_table test float32 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.float32 group by id order by id ++------+------------------------+ +| id | sum(round(a)) | ++------+------------------------+ +| 1 | -6 | +| 2 | -4 | +| 3 | -4 | +| 4 | -2 | +| 5 | -2 | +| 6 | -1 | +| 7 | 0 | +| 8 | 0 | +| 9 | 0 | +| 10 | 0 | +| 11 | 0 | +| 12 | 1 | +| 13 | 2 | +| 14 | 2 | +| 15 | 4 | +| 16 | 4 | +| 17 | 6 | +| 18 | 1.0000000150474662e30 | +| 19 | -1.0000000150474662e30 | +| 20 | NULL | ++------+------------------------+ + +mysql> drop table if exists test.float64 +mysql> create table test.float64 (id int, a double) +mysql> alter table test.float64 set tiflash replica 1 +mysql> insert into test.float64 values (1, -5.5), (2, -4.5), (3, -3.5), (4, -2.5), (5, -1.5), (6, -0.6), (7, -0.5), (8, -0.4), (9, 0), (10, 0.4), (11, 0.5), (12, 0.6), (13, 1.5), (14, 2.5), (15, 3.5), (16, 4.5), (17, 5.5), (18, 1e100), (19, -1e100), (20, null) +func> wait_table test float64 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.float64 group by id order by id ++------+---------------+ +| id | sum(round(a)) | ++------+---------------+ +| 1 | -6 | +| 2 | -4 | +| 3 | -4 | +| 4 | -2 | +| 5 | -2 | +| 6 | 
-1 | +| 7 | 0 | +| 8 | 0 | +| 9 | 0 | +| 10 | 0 | +| 11 | 0 | +| 12 | 1 | +| 13 | 2 | +| 14 | 2 | +| 15 | 4 | +| 16 | 4 | +| 17 | 6 | +| 18 | 1e100 | +| 19 | -1e100 | +| 20 | NULL | ++------+---------------+ + +mysql> drop table if exists test.decimal32 +mysql> create table test.decimal32 (id int, a decimal(9, 0)) +mysql> alter table test.decimal32 set tiflash replica 1 +mysql> insert into test.decimal32 values (1, 0), (2, 1), (3, -1), (4, 999999999), (5, -999999999), (6, null) +func> wait_table test decimal32 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.decimal32 group by id order by id ++------+---------------+ +| id | sum(round(a)) | ++------+---------------+ +| 1 | 0 | +| 2 | 1 | +| 3 | -1 | +| 4 | 999999999 | +| 5 | -999999999 | +| 6 | NULL | ++------+---------------+ + +mysql> drop table if exists test.decimal32 +mysql> create table test.decimal32 (id int, a decimal(9, 1)) +mysql> alter table test.decimal32 set tiflash replica 1 +mysql> insert into test.decimal32 values (1, -5.5), (2, -4.5), (3, -3.5), (4, -2.5), (5, -1.5), (6, -0.6), (7, -0.5), (8, -0.4), (9, 0), (10, 0.4), (11, 0.5), (12, 0.6), (13, 1.5), (14, 2.5), (15, 3.5), (16, 4.5), (17, 5.5), (18, 99999999.9), (19, -99999999.9), (20, null) +func> wait_table test decimal32 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.decimal32 group by id order by id ++------+---------------+ +| id | sum(round(a)) | ++------+---------------+ +| 1 | -6 | +| 2 | -5 | +| 3 | -4 | +| 4 | -3 | +| 5 | -2 | +| 6 | -1 | +| 7 | -1 | +| 8 | 0 | +| 9 | 0 | +| 10 | 0 | +| 11 | 1 | +| 12 | 1 | +| 13 | 2 | +| 14 | 3 | +| 15 | 4 | +| 16 | 5 | +| 17 | 6 | +| 18 | 100000000 | +| 19 | -100000000 | +| 20 | NULL | ++------+---------------+ + +mysql> drop table if exists test.decimal32 +mysql> create table test.decimal32 (id int, a decimal(9, 9)) +mysql> alter table test.decimal32 set tiflash replica 1 +mysql> insert into test.decimal32 
values (1, 0), (2, -0.000000001), (3, 0.000000001), (4, 0.5), (5, -0.5), (6, 0.999999999), (7, -0.999999999), (8, null) +func> wait_table test decimal32 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.decimal32 group by id order by id ++------+---------------+ +| id | sum(round(a)) | ++------+---------------+ +| 1 | 0 | +| 2 | 0 | +| 3 | 0 | +| 4 | 1 | +| 5 | -1 | +| 6 | 1 | +| 7 | -1 | +| 8 | NULL | ++------+---------------+ + +mysql> drop table if exists test.decimal256 +mysql> create table test.decimal256 (id int, a decimal(65, 0)) +mysql> alter table test.decimal256 set tiflash replica 1 +mysql> insert into test.decimal256 values (1, 0), (2, 1), (3, -1), (4, 99999999999999999999999999999999999999999999999999999999999999999), (5, -99999999999999999999999999999999999999999999999999999999999999999), (6, null) +func> wait_table test decimal256 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.decimal256 group by id order by id ++------+--------------------------------------------------------------------+ +| id | sum(round(a)) | ++------+--------------------------------------------------------------------+ +| 1 | 0 | +| 2 | 1 | +| 3 | -1 | +| 4 | 99999999999999999999999999999999999999999999999999999999999999999 | +| 5 | -99999999999999999999999999999999999999999999999999999999999999999 | +| 6 | NULL | ++------+--------------------------------------------------------------------+ + +mysql> drop table if exists test.decimal256 +mysql> create table test.decimal256 (id int, a decimal(65, 1)) +mysql> alter table test.decimal256 set tiflash replica 1 +mysql> insert into test.decimal256 values (1, -5.5), (2, -4.5), (3, -3.5), (4, -2.5), (5, -1.5), (6, -0.6), (7, -0.5), (8, -0.4), (9, 0), (10, 0.4), (11, 0.5), (12, 0.6), (13, 1.5), (14, 2.5), (15, 3.5), (16, 4.5), (17, 5.5), (18, 9999999999999999999999999999999999999999999999999999999999999999.9), (19, 
-9999999999999999999999999999999999999999999999999999999999999999.9), (20, null) +func> wait_table test decimal256 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.decimal256 group by id order by id ++------+--------------------------------------------------------------------+ +| id | sum(round(a)) | ++------+--------------------------------------------------------------------+ +| 1 | -6 | +| 2 | -5 | +| 3 | -4 | +| 4 | -3 | +| 5 | -2 | +| 6 | -1 | +| 7 | -1 | +| 8 | 0 | +| 9 | 0 | +| 10 | 0 | +| 11 | 1 | +| 12 | 1 | +| 13 | 2 | +| 14 | 3 | +| 15 | 4 | +| 16 | 5 | +| 17 | 6 | +| 18 | 10000000000000000000000000000000000000000000000000000000000000000 | +| 19 | -10000000000000000000000000000000000000000000000000000000000000000 | +| 20 | NULL | ++------+--------------------------------------------------------------------+ + +mysql> drop table if exists test.decimal256 +mysql> create table test.decimal256 (id int, a decimal(40, 30)) +mysql> alter table test.decimal256 set tiflash replica 1 +mysql> insert into test.decimal256 values (1, -5.5), (2, -4.5), (3, -3.5), (4, -2.5), (5, -1.5), (6, -0.6), (7, -0.5), (8, -0.4), (9, 0), (10, 0.4), (11, 0.5), (12, 0.6), (13, 1.5), (14, 2.5), (15, 3.5), (16, 4.5), (17, 5.5), (18, 9999999999.999999999999999999999999999999), (19, -9999999999.999999999999999999999999999999), (20, 1.000000000000000000000000000001), (21, -1.000000000000000000000000000001), (22, null) +func> wait_table test decimal256 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.decimal256 group by id order by id ++------+---------------+ +| id | sum(round(a)) | ++------+---------------+ +| 1 | -6 | +| 2 | -5 | +| 3 | -4 | +| 4 | -3 | +| 5 | -2 | +| 6 | -1 | +| 7 | -1 | +| 8 | 0 | +| 9 | 0 | +| 10 | 0 | +| 11 | 1 | +| 12 | 1 | +| 13 | 2 | +| 14 | 3 | +| 15 | 4 | +| 16 | 5 | +| 17 | 6 | +| 18 | 10000000000 | +| 19 | -10000000000 | +| 20 | 1 | +| 21 | -1 | +| 22 | NULL |
++------+---------------+ + +mysql> drop table if exists test.decimal128 +mysql> create table test.decimal128 (id int, a decimal(30, 30)) +mysql> alter table test.decimal128 set tiflash replica 1 +mysql> insert into test.decimal128 values (1, 0), (2, -0.000000000000000000000000000001), (3, 0.000000000000000000000000000001), (4, 0.5), (5, -0.5), (6, 0.999999999999999999999999999999), (7, -0.999999999999999999999999999999), (8, null) +func> wait_table test decimal128 +mysql> set @@session.tidb_isolation_read_engines='tiflash'; select id, sum(round(a)) from test.decimal128 group by id order by id ++------+---------------+ +| id | sum(round(a)) | ++------+---------------+ +| 1 | 0 | +| 2 | 0 | +| 3 | 0 | +| 4 | 1 | +| 5 | -1 | +| 6 | 1 | +| 7 | -1 | +| 8 | NULL | ++------+---------------+