diff --git a/dbms/src/Columns/ColumnDecimal.h b/dbms/src/Columns/ColumnDecimal.h index 7e2377ae8b8..7b9bfef0964 100644 --- a/dbms/src/Columns/ColumnDecimal.h +++ b/dbms/src/Columns/ColumnDecimal.h @@ -177,6 +177,8 @@ class ColumnDecimal final : public COWPtrHelper::prec is 4. +// This is a little confusing because we will add 1 when return result to client. +// Here we make sure TiFlash code is clean and will fix TiDB later. template struct IntPrec { @@ -27,7 +33,7 @@ struct IntPrec template <> struct IntPrec { - static constexpr PrecType prec = 4; + static constexpr PrecType prec = 3; }; template <> struct IntPrec @@ -37,7 +43,7 @@ struct IntPrec template <> struct IntPrec { - static constexpr PrecType prec = 6; + static constexpr PrecType prec = 5; }; template <> struct IntPrec @@ -47,7 +53,7 @@ struct IntPrec template <> struct IntPrec { - static constexpr PrecType prec = 11; + static constexpr PrecType prec = 10; }; template <> struct IntPrec @@ -57,7 +63,7 @@ struct IntPrec template <> struct IntPrec { - static constexpr PrecType prec = 20; + static constexpr PrecType prec = 19; }; template <> struct IntPrec @@ -65,6 +71,18 @@ struct IntPrec static constexpr PrecType prec = 20; }; +template <> +struct IntPrec +{ + static constexpr PrecType prec = 39; +}; + +template <> +struct IntPrec +{ + static constexpr PrecType prec = 78; +}; + // 1) If the declared type of both operands of a dyadic arithmetic operator is exact numeric, then the declared // type of the result is an implementation-defined exact numeric type, with precision and scale determined as // follows: @@ -359,6 +377,8 @@ class DecimalMaxValue final : public ext::Singleton public: static Int256 get(PrecType idx) { + // In case DecimalMaxValue::get(IntPrec::prec), where IntPrec::prec > 65. 
+ assert(idx <= decimal_max_prec); return instance().getInternal(idx); } @@ -384,12 +404,18 @@ class DecimalMaxValue final : public ext::Singleton // In some case, getScaleMultiplier and its callee may not be auto inline by the compiler. // This may hurt performance. __attribute__((flatten)) tells compliler to inline the callee of this function. -template +template > * = nullptr> __attribute__((flatten)) inline typename T::NativeType getScaleMultiplier(ScaleType scale) { return static_cast(DecimalMaxValue::get(scale) + 1); } +template > * = nullptr> +__attribute__((flatten)) inline T getScaleMultiplier(ScaleType scale) +{ + return static_cast(DecimalMaxValue::get(scale) + 1); +} + template inline void checkDecimalOverflow(Decimal v, PrecType prec) { diff --git a/dbms/src/Functions/FunctionsTiDBConversion.h b/dbms/src/Functions/FunctionsTiDBConversion.h index bf31705b4a1..fce22b58675 100644 --- a/dbms/src/Functions/FunctionsTiDBConversion.h +++ b/dbms/src/Functions/FunctionsTiDBConversion.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -58,6 +59,11 @@ enum CastError OVERFLOW_ERR, }; +namespace +{ +constexpr static Int64 pow10[] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000}; +} + /// cast int/real/decimal/time as string template struct TiDBConvertToString @@ -810,66 +816,68 @@ struct TiDBConvertToFloat /// cast int/real/decimal/enum/string/time/string as decimal // todo TiKV does not check unsigned flag but TiDB checks, currently follow TiKV's code, maybe changed latter -template +// There are two optimizations in TiDBConvertToDecimal: +// 1. Skip overflow check if possible, such as cast(tiny_int_val as decimal(10, 0)), +// we can skip check overflow, because max_tiny_int(127) < to_max_val(10^9). +// 2. Use appropriate type for multiplication of from_int_val and scale_mul(which is 10^abs(scale_diff)). +// The original implementation always uses Int256, which is very slow. 
+// The general idea is: +// 1. If from_type_prec + scale_diff <= to_type_prec, we can skip overflow check. +// Because the max value of from type is less than the max value of to type, so no overflow will happen. +// 2. CastInternalType is the int type with minimum prec which satisfies: from_type_prec + scale_diff <= IntPrec::prec - 1. +// So on the one hand CastInternalType can hold both from_int_value and the result of multiplication of from_int_val and scale_mul, +// on the other hand the multiplication is as fast as possible. +// NOTE: scale_diff = to_type_scale - from_type_scale. +// NOTE: The above two optimizations only take effect when from type is int/decimal/date/datetime. +// The logic of cast doesn't care about CastInternalType(Int512) and can_skip_check_overflow(false) at all when from_type is real or string. +template struct TiDBConvertToDecimal { using FromFieldType = typename FromDataType::FieldType; template - static U toTiDBDecimalInternal(T int_value, PrecType prec, ScaleType scale, const Context & context) + static U toTiDBDecimalInternal(T int_value, const CastInternalType & max_value, const CastInternalType & scale_mul, const Context & context) { // int_value is the value that exposes to user. Such as cast(val to decimal), val is the int_value which used by user. // And val * scale_mul is the scaled_value, which is stored in ColumnDecimal internally. 
static_assert(std::is_integral_v); using UType = typename U::NativeType; - UType scale_mul = getScaleMultiplier(scale); - Int256 scaled_value = static_cast(int_value) * static_cast(scale_mul); - Int256 scaled_max_value = DecimalMaxValue::get(prec); - - if (scaled_value > scaled_max_value || scaled_value < -scaled_max_value) - { - context.getDAGContext()->handleOverflowError("cast to decimal", Errors::Types::Truncated); - if (int_value > 0) - return static_cast(scaled_max_value); - else - return static_cast(-scaled_max_value); - } - - return static_cast(scaled_value); + CastInternalType scaled_value = static_cast(int_value) * scale_mul; + return handleOverflowErrorForIntAndDecimal(context, scaled_value, max_value, "cast to decimal"); } template - static U toTiDBDecimal(MyDateTime & date_time, PrecType prec, ScaleType scale, int fsp, const Context & context) + static U toTiDBDecimal(MyDateTime & date_time, const CastInternalType & max_value, ScaleType from_scale, ScaleType to_scale, const CastInternalType & scale_mul, int fsp, const Context & context) { UInt64 value_without_fsp = date_time.year * 10000000000ULL + date_time.month * 100000000ULL + date_time.day * 1000000ULL - + date_time.hour * 10000ULL + date_time.minute * 100 + date_time.second; + + date_time.hour * 10000ULL + date_time.minute * 100ULL + date_time.second; if (fsp > 0) { Int128 value = static_cast(value_without_fsp) * 1000000 + date_time.micro_second; Decimal128 decimal(value); - return toTiDBDecimal(decimal, 6, prec, scale, context); + return toTiDBDecimal(decimal, from_scale, max_value, to_scale, scale_mul, context); } else { - return toTiDBDecimalInternal(value_without_fsp, prec, scale, context); + return toTiDBDecimalInternal(value_without_fsp, max_value, scale_mul, context); } } template - static U toTiDBDecimal(MyDate & date, PrecType prec, ScaleType scale, const Context & context) + static U toTiDBDecimal(MyDate & date, const CastInternalType & max_value, const CastInternalType & scale_mul, 
const Context & context) { UInt64 value = date.year * 10000 + date.month * 100 + date.day; - return toTiDBDecimalInternal(value, prec, scale, context); + return toTiDBDecimalInternal(value, max_value, scale_mul, context); } template - static std::enable_if_t, U> toTiDBDecimal(T value, PrecType prec, ScaleType scale, const Context & context) + static std::enable_if_t, U> toTiDBDecimal(T value, const CastInternalType & max_value, const CastInternalType & scale_mul, const Context & context) { if constexpr (std::is_signed_v) - return toTiDBDecimalInternal(value, prec, scale, context); + return toTiDBDecimalInternal(value, max_value, scale_mul, context); else - return toTiDBDecimalInternal(static_cast(value), prec, scale, context); + return toTiDBDecimalInternal(static_cast(value), max_value, scale_mul, context); } template @@ -916,27 +924,23 @@ struct TiDBConvertToDecimal static std::enable_if_t, U> toTiDBDecimal( const T & v, ScaleType v_scale, - PrecType prec, + const CastInternalType & max_value, ScaleType scale, + const CastInternalType & scale_mul, const Context & context) { using UType = typename U::NativeType; - auto value = Int256(v.value); + CastInternalType value = static_cast(v.value); - if (v_scale <= scale) + if (v_scale < scale) { - for (ScaleType i = v_scale; i < scale; i++) - value *= 10; + value *= scale_mul; } - else + else if (v_scale > scale) { context.getDAGContext()->handleTruncateError("cast decimal as decimal"); - bool need_to_round = false; - for (ScaleType i = scale; i < v_scale; i++) - { - need_to_round = (value < 0 ? -value : value) % 10 >= 5; - value /= 10; - } + const bool need_to_round = ((value < 0 ? 
-value : value) % scale_mul) >= (scale_mul / 2); + value /= scale_mul; if (need_to_round) { if (value < 0) @@ -945,17 +949,13 @@ struct TiDBConvertToDecimal ++value; } } - - auto max_value = DecimalMaxValue::get(prec); - if (value > max_value || value < -max_value) + else { - context.getDAGContext()->handleOverflowError("cast decimal as decimal", Errors::Types::Truncated); - if (value > 0) - return static_cast(max_value); - else - return static_cast(-max_value); + // If v_scale == scale, then scale_mul must be 1, no need to touch value. + assert(scale_mul == 1); } - return static_cast(value); + + return handleOverflowErrorForIntAndDecimal(context, value, max_value, "cast decimal to decimal"); } struct DecimalParts @@ -1147,8 +1147,10 @@ struct TiDBConvertToDecimal const auto * col_from = checkAndGetColumn>(block.getByPosition(arguments[0]).column.get()); const typename ColumnDecimal::Container & vec_from = col_from->getData(); + const CastInternalType max_value = getMaxValueIfNecessary(prec); + const CastInternalType scale_mul = getScaleMulForDecimalToDecimal(col_from->getScale(), scale); for (size_t i = 0; i < size; ++i) - vec_to[i] = toTiDBDecimal(vec_from[i], vec_from.getScale(), prec, scale, context); + vec_to[i] = toTiDBDecimal(vec_from[i], vec_from.getScale(), max_value, scale, scale_mul, context); } else if constexpr (std::is_same_v || std::is_same_v) { @@ -1160,17 +1162,25 @@ struct TiDBConvertToDecimal = checkAndGetColumn>(col_with_type_and_name.column.get()); const typename ColumnVector::Container & vec_from = col_from->getData(); - for (size_t i = 0; i < size; ++i) + const CastInternalType max_value = getMaxValueIfNecessary(prec); + if constexpr (std::is_same_v) { - if constexpr (std::is_same_v) + const CastInternalType scale_mul = getScaleMultiplier(scale); + for (size_t i = 0; i < size; ++i) { MyDate date(vec_from[i]); - vec_to[i] = toTiDBDecimal(date, prec, scale, context); + vec_to[i] = toTiDBDecimal(date, max_value, scale_mul, context); } - else + } + else 
+ { + // Check 
getMinPrecForHoldingDatetime() to see why from_scale is 6. + static constexpr ScaleType from_scale = 6; + const CastInternalType scale_mul = getScaleMulForDecimalToDecimal(from_scale, scale); + for (size_t i = 0; i < size; ++i) { MyDateTime date_time(vec_from[i]); - vec_to[i] = toTiDBDecimal(date_time, prec, scale, type.getFraction(), context); + vec_to[i] = toTiDBDecimal(date_time, max_value, from_scale, scale, scale_mul, type.getFraction(), context); } } } @@ -1194,11 +1204,23 @@ struct TiDBConvertToDecimal else if (const ColumnVector * col_from = checkAndGetColumn>(block.getByPosition(arguments[0]).column.get())) { - /// cast enum/int/real as decimal const typename ColumnVector::Container & vec_from = col_from->getData(); - for (size_t i = 0; i < size; ++i) - vec_to[i] = toTiDBDecimal(vec_from[i], prec, scale, context); + if constexpr (std::is_integral_v) + { + /// cast enum/int as decimal + const CastInternalType max_value = getMaxValueIfNecessary(prec); + const CastInternalType scale_mul = getScaleMultiplier(scale); + for (size_t i = 0; i < size; ++i) + vec_to[i] = toTiDBDecimal(vec_from[i], max_value, scale_mul, context); + } + else + { + static_assert(std::is_floating_point_v); + /// cast real as decimal + for (size_t i = 0; i < size; ++i) + vec_to[i] = toTiDBDecimal(vec_from[i], prec, scale, context); + } } else { @@ -1211,6 +1233,46 @@ struct TiDBConvertToDecimal else block.getByPosition(result).column = std::move(col_to); } + + template + static ReturnType handleOverflowErrorForIntAndDecimal(const Context & context, + const CastInternalType & to_value, + const CastInternalType & max_value [[maybe_unused]], + const String & msg) + { + if constexpr (!can_skip_check_overflow) + { + if (to_value > max_value || to_value < -max_value) + { + context.getDAGContext()->handleOverflowError(msg, Errors::Types::Truncated); + if (to_value > 0) + return static_cast(max_value); + else + return static_cast(-max_value); + } + } + return static_cast(to_value); + } + + // 
max_value is useless if can_skip_check_overflow is true. + static CastInternalType getMaxValueIfNecessary(PrecType prec [[maybe_unused]]) + { + if constexpr (!can_skip_check_overflow) + { + return static_cast(DecimalMaxValue::get(prec)); + } + else + { + return 0; + } + } + + // Only used for cast decimal to decimal. + static CastInternalType getScaleMulForDecimalToDecimal(ScaleType from_scale, ScaleType to_scale) + { + const ScaleType scale_diff = ((from_scale > to_scale) ? (from_scale - to_scale) : (to_scale - from_scale)); + return getScaleMultiplier(scale_diff); + } }; /// cast int/real/decimal/time/string as Date/DateTime @@ -1544,7 +1606,6 @@ struct TiDBConvertToDuration block.getByPosition(result).column = ColumnNullable::create(std::move(block.getByPosition(result).column), std::move(col_null_map_to)); } - constexpr static Int64 pow10[] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000}; static Int64 round(Int64 x, int fsp) { Int64 scale = pow10[fsp]; @@ -1750,6 +1811,15 @@ class FunctionTiDBCast final : public IFunctionBase return monotonicity_for_range(type, left, right); } + // rule: from_scale_prec <= to_decimal_prec + template + static bool canSkipCheckOverflowForDecimal(DataTypePtr from_type, PrecType to_decimal_prec, ScaleType to_decimal_scale) + { + constexpr bool avoid_truncate_from_value = false; + const PrecType from_scaled_prec = getMinPrecForHoldingFromValue(from_type, to_decimal_scale, avoid_truncate_from_value); + return from_scaled_prec <= to_decimal_prec; + } + private: const Context & context; const char * name; @@ -1761,80 +1831,260 @@ class FunctionTiDBCast final : public IFunctionBase bool in_union; const tipb::FieldType & tidb_tp; - /// createWrapper creates lambda functions that do the real type conversion job - template - WrapperType createWrapper(const DataTypePtr & to_type) const + template + static bool getMinPrecForHoldingInteger(DataTypePtr from_type, ScaleType to_decimal_scale, PrecType & 
from_scaled_prec) { - /// cast as int - if (checkDataType(to_type.get())) - return [](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { - TiDBConvertToInteger::execute( + const auto f = [&from_scaled_prec, to_decimal_scale](const auto &, bool) -> bool { + using FromFieldType = typename FromDataType::FieldType; + // This is required because other types(like Float32) don't have template specialization for IntPrec. + if constexpr (std::is_integral_v) + { + from_scaled_prec = IntPrec::prec + to_decimal_scale; + } + else + { + // Cannot reach here. castTypeToEither will return false directly. + __builtin_unreachable(); + (void)from_scaled_prec; + (void)to_decimal_scale; + } + return true; + }; + + return castTypeToEither< + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64, + DataTypeEnum8, + DataTypeEnum16>(from_type.get(), f); + } + + static PrecType getMinPrecForHoldingDecimalInternal(PrecType from_prec, ScaleType from_scale, ScaleType to_scale, bool avoid_truncate_from_value) + { + Int64 scale_diff = static_cast(to_scale) - static_cast(from_scale); + if (scale_diff < 0) + { + if (avoid_truncate_from_value) + scale_diff = 0; + else + ++scale_diff; + } + return from_prec + scale_diff; + } + + static bool getMinPrecForHoldingDatetime(DataTypePtr from_type, ScaleType to_decimal_scale, bool avoid_truncate_from_value, PrecType & from_scaled_prec) + { + if (const auto * datetime_type = checkAndGetDataType(from_type.get())) + { + const auto fsp = datetime_type->getFraction(); + if (fsp > 0) + { + // Treat datetime(fsp) as decimal(20, 6). + // Max value of datetime time is '9999-12-31 23:59:59.999999', + // which will be treated as 99991231235959.999999 when doing cast, + // so here we use 20 as its precision and 6 as its scale. 
+ from_scaled_prec = getMinPrecForHoldingDecimalInternal(20, 6, to_decimal_scale, avoid_truncate_from_value); + } + else + { + // Max value of datetime time is '9999-12-31 23:59:59', which will be treated as 99991231235959. + // So treat it as a int value whose precision is 14. + assert(fsp == 0); + from_scaled_prec = 14 + to_decimal_scale; + } + return true; + } + return false; + } + + static bool getMinPrecForHoldingDecimal(DataTypePtr from_type, ScaleType to_decimal_scale, bool avoid_truncate_from_value, PrecType & from_scaled_prec) + { + return castTypeToEither< + DataTypeDecimal32, + DataTypeDecimal64, + DataTypeDecimal128, + DataTypeDecimal256>(from_type.get(), [to_decimal_scale, avoid_truncate_from_value, &from_scaled_prec](const auto & from_type_ptr, bool) { + from_scaled_prec = getMinPrecForHoldingDecimalInternal(from_type_ptr.getPrec(), from_type_ptr.getScale(), to_decimal_scale, avoid_truncate_from_value); + return true; + }); + } + + // Cast optimization doesn't handle float/string/duration for now, so return a max prec value. + static bool getMinPrecForHoldingOtherTypes(DataTypePtr from_type, PrecType & from_scaled_prec) + { + return castTypeToEither< + DataTypeFloat32, + DataTypeFloat64, + DataTypeString, + DataTypeMyDuration>(from_type.get(), [&from_scaled_prec](const auto &, bool) { + from_scaled_prec = std::numeric_limits::max(); + return true; + }); + } + + // The core function of the optimizations of TiDBConvertToDecimal. The basic idea has already described above. + // avoid_truncate_from_value: + // 1. True when determining CastInternalType to avoid truncating from_int_val when static_cast it to CastInternalType. + // Because it needs to hold both from_int_value and the result of multiplication(division when scale_diff is negative) of from_int_val and scale_mul. + // So if scale_diff is negative, we should reset scale_diff as zero to avoid from_int_value is truncated unexpectedly. + // 2. False when determining if can_skip_overflow_check. 
Also the scale_diff should plus 1 if it's negative, + // Such as cast(99.9999 as decimal(4, 2)), after division, the internal int value of cast result is 10000 instead of 9999, + // because we need to round up during cast. In this case, overflow happens. + // So we plus 1 to scale_diff(check getMinPrecForHoldingDecimalInternal). + // Then the rule will be from_prec + scale_diff + 1 <= to_prec when scale_diff < 0. + template + static PrecType getMinPrecForHoldingFromValue(DataTypePtr from_type, ScaleType to_decimal_scale, bool avoid_truncate_from_value) + { + PrecType from_scaled_prec = std::numeric_limits::max(); + + if (getMinPrecForHoldingInteger(from_type, to_decimal_scale, from_scaled_prec)) + { + // cast(int/enum as decimal) + } + else if (checkDataType(from_type.get())) + { + // cast(date as decimal) + // The max value of date type is '9999-12-31', which will be treated as 99991231 when cast it as decimal, + // so we use 8 as its precision. + from_scaled_prec = 8 + to_decimal_scale; + } + else if (getMinPrecForHoldingDatetime(from_type, to_decimal_scale, avoid_truncate_from_value, from_scaled_prec)) + { + // cast(datetime as decimal) + } + else if (getMinPrecForHoldingDecimal(from_type, to_decimal_scale, avoid_truncate_from_value, from_scaled_prec)) + { + // cast(decimal as decimal) + } + else if (getMinPrecForHoldingOtherTypes(from_type, from_scaled_prec)) + { + // cast(float/string as decimal); cast duration to decimal not pushed down for now. + } + else + { + __builtin_unreachable(); + } + return from_scaled_prec; + } + + // Determine CastInternalType template argument for TiDBConvertToDecimal, + // which is used as the type in the multiplication/division of from_int_val and scale_mul. 
+ template + WrapperType createWrapperForDecimal(const DataTypePtr & from_type, const ToDataType * decimal_type) const + { + const bool avoid_truncate_from_value = true; + PrecType from_scaled_prec = getMinPrecForHoldingFromValue(from_type, decimal_type->getScale(), avoid_truncate_from_value); + const bool can_skip = canSkipCheckOverflowForDecimal(from_type, decimal_type->getPrec(), decimal_type->getScale()); + if (!can_skip) + { + // If cannot skip overflow check, we should use int type that can hold both scaled_val and max_val. + from_scaled_prec = std::max(from_scaled_prec, decimal_type->getPrec()); + } + + // Here we minus 1 to IntPrec::prec to avoid potential overflow. + // IntPrec denotes the minimum precision of decimals to hold a specific integer type. + // However, not all decimals with such precision could be held by this integer type. + // This could happen when calculating the max_value and the multiplication of from_int_val and scale_mul. + if (from_scaled_prec <= IntPrec::prec - 1) + { + return createWrapperForDecimal(decimal_type, can_skip); + } + else if (from_scaled_prec <= IntPrec::prec - 1) + { + return createWrapperForDecimal(decimal_type, can_skip); + } + else if (from_scaled_prec <= IntPrec::prec - 1) + { + return createWrapperForDecimal(decimal_type, can_skip); + } + else if (from_scaled_prec <= IntPrec::prec - 1) + { + return createWrapperForDecimal(decimal_type, can_skip); + } + else + { + return createWrapperForDecimal(decimal_type, can_skip); + } + } + + // Determine can_skip_overflow_check template argument for TiDBConvertToDecimal. 
+ template + WrapperType createWrapperForDecimal(const ToDataType * decimal_type, bool can_skip) const + { + using ToFieldType = typename ToDataType::FieldType; + PrecType prec = decimal_type->getPrec(); + ScaleType scale = decimal_type->getScale(); + if (can_skip) + { + return [prec, scale](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { + TiDBConvertToDecimal::execute( block, arguments, result, + prec, + scale, in_union_, tidb_tp_, context_); }; - if (checkDataType(to_type.get())) - return [](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { - TiDBConvertToInteger::execute( + } + else + { + return [prec, scale](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { + TiDBConvertToDecimal::execute( block, arguments, result, + prec, + scale, in_union_, tidb_tp_, context_); }; - /// cast as decimal - if (const auto * decimal_type = checkAndGetDataType(to_type.get())) - return [decimal_type](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { - TiDBConvertToDecimal::execute( + } + } + + /// createWrapper creates lambda functions that do the real type conversion job + template + WrapperType createWrapper(const DataTypePtr & from_type, const DataTypePtr & to_type) const + { + /// cast as int + if (checkDataType(to_type.get())) + return [](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { + TiDBConvertToInteger::execute( block, arguments, result, - decimal_type->getPrec(), - decimal_type->getScale(), in_union_, tidb_tp_, context_); }; - if (const auto * decimal_type = 
checkAndGetDataType(to_type.get())) - return [decimal_type](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { - TiDBConvertToDecimal::execute( + if (checkDataType(to_type.get())) + return [](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { + TiDBConvertToInteger::execute( block, arguments, result, - decimal_type->getPrec(), - decimal_type->getScale(), in_union_, tidb_tp_, context_); }; + + /// cast as decimal + if (const auto * decimal_type = checkAndGetDataType(to_type.get())) + return createWrapperForDecimal(from_type, decimal_type); + if (const auto * decimal_type = checkAndGetDataType(to_type.get())) + return createWrapperForDecimal(from_type, decimal_type); if (const auto * decimal_type = checkAndGetDataType(to_type.get())) - return [decimal_type](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { - TiDBConvertToDecimal::execute( - block, - arguments, - result, - decimal_type->getPrec(), - decimal_type->getScale(), - in_union_, - tidb_tp_, - context_); - }; + return createWrapperForDecimal(from_type, decimal_type); if (const auto * decimal_type = checkAndGetDataType(to_type.get())) - return [decimal_type](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { - TiDBConvertToDecimal::execute( - block, - arguments, - result, - decimal_type->getPrec(), - decimal_type->getScale(), - in_union_, - tidb_tp_, - context_); - }; + return createWrapperForDecimal(from_type, decimal_type); + /// cast as real if (checkDataType(to_type.get())) return [](Block & block, const ColumnNumbers & arguments, const size_t result, bool in_union_, const tipb::FieldType & tidb_tp_, const Context & context_) { 
@@ -1898,7 +2148,7 @@ class FunctionTiDBCast final : public IFunctionBase context_); }; - // todo support convert to duration/json type + // todo support convert to json type throw Exception{"tidb_cast to " + to_type->getName() + " is not supported", ErrorCodes::CANNOT_CONVERT_TYPE}; } @@ -1915,45 +2165,45 @@ class FunctionTiDBCast final : public IFunctionBase if (isIdentityCast(from_type, to_type)) return createIdentityWrapper(from_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if 
(checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); if (checkAndGetDataType(from_type.get())) - return createWrapper(to_type); + return createWrapper(from_type, to_type); // todo support convert to duration/json type throw Exception{ diff --git a/dbms/src/Functions/tests/gtest_arithmetic_functions.cpp b/dbms/src/Functions/tests/gtest_arithmetic_functions.cpp index 80aa231f237..e5db30b28e8 100644 --- a/dbms/src/Functions/tests/gtest_arithmetic_functions.cpp +++ b/dbms/src/Functions/tests/gtest_arithmetic_functions.cpp @@ -818,7 +818,7 @@ try // Int64 has a precision of 20, which is larger than Decimal64. 
ASSERT_COLUMN_EQ( createColumn>( - std::make_tuple(20, 3), + std::make_tuple(19, 3), {DecimalField128(300, 3), {}, {}}), executeFunction( func_name, @@ -980,8 +980,8 @@ try MODULO_TESTCASE(Int64, Int64, Int64, 0, 0, 0, 0); MODULO_TESTCASE(Int64, UInt64, Int64, 0, 0, 0, 0); MODULO_TESTCASE(Int64, Float64, Float64, 0, 0, 0, 0); - MODULO_TESTCASE(Int64, Decimal32, Decimal128, 20, 0, 3, 3); - MODULO_TESTCASE(Int64, Decimal64, Decimal128, 20, 0, 6, 6); + MODULO_TESTCASE(Int64, Decimal32, Decimal128, 19, 0, 3, 3); + MODULO_TESTCASE(Int64, Decimal64, Decimal128, 19, 0, 6, 6); MODULO_TESTCASE(Int64, Decimal128, Decimal128, 38, 0, 10, 10); MODULO_TESTCASE(Int64, Decimal256, Decimal256, 65, 0, 20, 20); @@ -1001,7 +1001,7 @@ try MODULO_TESTCASE(Float64, Decimal128, Float64, 0, 0, 10, 0); MODULO_TESTCASE(Float64, Decimal256, Float64, 0, 0, 20, 0); - MODULO_TESTCASE(Decimal32, Int64, Decimal128, 20, 3, 0, 3); + MODULO_TESTCASE(Decimal32, Int64, Decimal128, 19, 3, 0, 3); MODULO_TESTCASE(Decimal32, UInt64, Decimal128, 20, 3, 0, 3); MODULO_TESTCASE(Decimal32, Float64, Float64, 0, 3, 0, 0); MODULO_TESTCASE(Decimal32, Decimal32, Decimal32, 9, 3, 3, 3); @@ -1009,7 +1009,7 @@ try MODULO_TESTCASE(Decimal32, Decimal128, Decimal128, 38, 3, 10, 10); MODULO_TESTCASE(Decimal32, Decimal256, Decimal256, 65, 3, 20, 20); - MODULO_TESTCASE(Decimal64, Int64, Decimal128, 20, 6, 0, 6); + MODULO_TESTCASE(Decimal64, Int64, Decimal128, 19, 6, 0, 6); MODULO_TESTCASE(Decimal64, UInt64, Decimal128, 20, 6, 0, 6); MODULO_TESTCASE(Decimal64, Float64, Float64, 0, 6, 0, 0); MODULO_TESTCASE(Decimal64, Decimal32, Decimal64, 18, 6, 3, 6); diff --git a/dbms/src/Functions/tests/gtest_tidb_conversion.cpp b/dbms/src/Functions/tests/gtest_tidb_conversion.cpp index e9690fd2191..2b21778b8a5 100644 --- a/dbms/src/Functions/tests/gtest_tidb_conversion.cpp +++ b/dbms/src/Functions/tests/gtest_tidb_conversion.cpp @@ -1,15 +1,16 @@ -#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include -#include "Columns/ColumnsNumber.h" -#include "Core/ColumnWithTypeAndName.h" -#include "DataTypes/DataTypeMyDateTime.h" -#include "DataTypes/DataTypeMyDuration.h" -#include "DataTypes/DataTypeNullable.h" -#include "DataTypes/DataTypesNumber.h" -#include "Functions/FunctionHelpers.h" -#include "TestUtils/FunctionTestUtils.h" -#include "common/types.h" -#include "gtest/gtest.h" +#include namespace DB::tests { @@ -1426,6 +1427,268 @@ try } CATCH +TEST_F(TestTidbConversion, skipCheckOverflowIntToDeciaml) +{ + DataTypePtr int8_ptr = makeDataType(); + DataTypePtr int16_ptr = makeDataType(); + DataTypePtr int32_ptr = makeDataType(); + DataTypePtr int64_ptr = makeDataType(); + DataTypePtr uint8_ptr = makeDataType(); + DataTypePtr uint16_ptr = makeDataType(); + DataTypePtr uint32_ptr = makeDataType(); + DataTypePtr uint64_ptr = makeDataType(); + + const PrecType prec_decimal32 = 8; + const PrecType prec_decimal64 = 17; + const PrecType prec_decimal128 = 37; + const PrecType prec_decimal256 = 65; + const ScaleType scale = 0; + + // int8(max_prec: 3) -> decimal32(max_prec: 9) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal32, scale)); + // int16(max_prec: 5) -> decimal32(max_prec: 9) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal32, scale)); + // int32(max_prec: 10) -> decimal32(max_prec: 9) + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal32, scale)); + // int64(max_prec: 19) -> decimal32(max_prec: 9) + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal32, scale)); + + // uint8(max_prec: 3) -> decimal32(max_prec: 9) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal32, scale)); + // uint16(max_prec: 5) -> decimal32(max_prec: 9) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal32, scale)); + // uint32(max_prec: 10) ->
decimal32(max_prec: 9) + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal32, scale)); + // uint64(max_prec: 20) -> decimal32(max_prec: 9) + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal32, scale)); + + // int8(max_prec: 3) -> decimal64(max_prec: 18) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal64, scale)); + // int16(max_prec: 5) -> decimal64(max_prec: 18) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal64, scale)); + // int32(max_prec: 10) -> decimal64(max_prec: 18) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal64, scale)); + // int64(max_prec: 19) -> decimal64(max_prec: 18) + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal64, scale)); + + // uint8(max_prec: 3) -> decimal64(max_prec: 18) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal64, scale)); + // uint16(max_prec: 5) -> decimal64(max_prec: 18) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal64, scale)); + // uint32(max_prec: 10) -> decimal64(max_prec: 18) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal64, scale)); + // uint64(max_prec: 20) -> decimal64(max_prec: 18) + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal64, scale)); + + // int8(max_prec: 3) -> decimal128(max_prec: 38) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal128, scale)); + // int16(max_prec: 5) -> decimal128(max_prec: 38) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal128, scale)); + // int32(max_prec: 10) -> decimal128(max_prec: 38) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal128, scale)); + // int64(max_prec: 19) -> decimal128(max_prec:
38) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal128, scale)); + + // uint8(max_prec: 3) -> decimal128(max_prec: 38) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal128, scale)); + // uint16(max_prec: 5) -> decimal128(max_prec: 38) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal128, scale)); + // uint32(max_prec: 10) -> decimal128(max_prec: 38) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal128, scale)); + // uint64(max_prec: 20) -> decimal128(max_prec: 38) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal128, scale)); + + // int8(max_prec: 3) -> decimal256(max_prec: 65) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal256, scale)); + // int16(max_prec: 5) -> decimal256(max_prec: 65) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal256, scale)); + // int32(max_prec: 10) -> decimal256(max_prec: 65) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal256, scale)); + // int64(max_prec: 19) -> decimal256(max_prec: 65) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal256, scale)); + + // uint8(max_prec: 3) -> decimal256(max_prec: 65) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal256, scale)); + // uint16(max_prec: 5) -> decimal256(max_prec: 65) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal256, scale)); + // uint32(max_prec: 10) -> decimal256(max_prec: 65) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal256, scale)); + // uint64(max_prec: 20) -> decimal256(max_prec: 65) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal256, scale)); +} + +TEST_F(TestTidbConversion,
skipCheckOverflowDecimalToDeciaml) +{ + DataTypePtr decimal32_ptr_8_3 = createDecimal(8, 3); + DataTypePtr decimal32_ptr_8_2 = createDecimal(8, 2); + + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 8, 3)); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_3, 8, 2)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 7, 5)); + + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 9, 3)); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 9, 1)); + + DataTypePtr decimal32_ptr_6_4 = createDecimal(6, 4); + // decimal(6, 4) -> decimal(5, 3) + // because select cast(99.9999 as decimal(5, 3)); -> 100.000 is greater than 99.999. + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 5, 3)); + // decimal(6, 4) -> decimal(7, 5) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 7, 5)); + + // decimal(6, 4) -> decimal(6, 5) + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 6, 5)); + // decimal(6, 4) -> decimal(8, 5) + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 8, 5)); +} + +TEST_F(TestTidbConversion, skipCheckOverflowEnumToDecimal) +{ + DataTypeEnum8::Values enum8_values; + enum8_values.push_back({"a", 10}); + enum8_values.push_back({"b", 20}); + DataTypePtr enum8_ptr = std::make_shared(enum8_values); + + DataTypeEnum16::Values enum16_values; + enum16_values.push_back({"a1", 1000}); + enum16_values.push_back({"b1", 2000}); + DataTypePtr enum16_ptr = std::make_shared(enum16_values); + + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum8_ptr, 3, 0)); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum8_ptr, 4, 1)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum8_ptr, 2, 0)); + 
ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum8_ptr, 4, 2)); + + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum16_ptr, 5, 0)); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum16_ptr, 6, 1)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum16_ptr, 4, 0)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum16_ptr, 6, 2)); +} + +TEST_F(TestTidbConversion, skipCheckOverflowMyDateTimeToDeciaml) +{ + DataTypePtr datetime_ptr_no_fsp = std::make_shared(); + DataTypePtr datetime_ptr_fsp_5 = std::make_shared(5); + + // rule for no fsp: 14 + to_scale <= to_prec. + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 5, 3)); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 18, 3)); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 17, 3)); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 18, 4)); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 14, 0)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 14, 1)); + + // rule for fsp: 20 + scale_diff <= to_prec. + // 20 + (3 - 6 + 1) = 18 + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_fsp_5, 19, 3)); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_fsp_5, 18, 3)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_fsp_5, 17, 3)); +} + +TEST_F(TestTidbConversion, skipCheckOverflowMyDateToDeciaml) +{ + DataTypePtr date_ptr = std::make_shared(); + + // rule: 8 + to_scale <= to_prec. 
+ ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(date_ptr, 11, 3)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(date_ptr, 11, 4)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(date_ptr, 10, 3)); +} + +TEST_F(TestTidbConversion, skipCheckOverflowOtherToDecimal) +{ + // float and string not support skip overflow check. + DataTypePtr string_ptr = std::make_shared(); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(string_ptr, 1, 0)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(string_ptr, 60, 1)); + + DataTypePtr float32_ptr = std::make_shared(); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(float32_ptr, 1, 0)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(float32_ptr, 60, 1)); + + DataTypePtr float64_ptr = std::make_shared(); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(float64_ptr, 1, 0)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(float64_ptr, 60, 1)); + + // cast duration to decimal is not supported to push down to tiflash for now. + DataTypePtr duration_ptr = std::make_shared(); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(duration_ptr, 1, 0)); + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(duration_ptr, 60, 1)); +} + +// check if template argument of CastInternalType is correct or not. +TEST_F(TestTidbConversion, checkCastInternalType) +try +{ + // case1: cast(tinyint as decimal(7, 3)) + PrecType to_prec = 7; + ScaleType to_scale = 3; + DataTypePtr int8_ptr = std::make_shared(); + // from_prec(3) + to_scale(3) <= Decimal32::prec(9), so we **CAN** skip check overflow. + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, to_prec, to_scale)); + + // from_prec(3) + to_scale(3) <= Int32::real_prec(10) - 1, so CastInternalType should be **Int32**. 
+ ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(to_prec, to_scale), + {DecimalField32(MAX_INT8 * 1000, to_scale), DecimalField32(MIN_INT8 * 1000, to_scale), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT8, MIN_INT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(7,3))")})); + + // case2: cast(tinyint as decimal(9, 7)) + to_prec = 9; + to_scale = 7; + // from_prec(3) + to_scale(7) > Decimal32::prec(9), so we **CANNOT** skip check overflow. + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, to_prec, to_scale)); + + // from_prec(3) + to_scale(7) > Int32::real_prec(10) - 1, so CastInternalType should be **Int64**. + DAGContext * dag_context = context.getDAGContext(); + UInt64 ori_flags = dag_context->getFlags(); + dag_context->addFlag(TiDBSQLFlags::OVERFLOW_AS_WARNING); + dag_context->clearWarnings(); + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(to_prec, to_scale), + {DecimalField32(999999999, to_scale), DecimalField32(-999999999, to_scale), {}}), + executeFunction(func_name, + {createColumn>({MAX_INT8, MIN_INT8, {}}), + createCastTypeConstColumn("Nullable(Decimal(9,7))")})); + dag_context->setFlags(ori_flags); + + // case3: cast(bigint as decimal(40, 20)) + // from_prec(19) + to_scale(20) <= Decimal256::prec(40), so we **CAN** skip check overflow. + to_prec = 40; + to_scale = 20; + DataTypePtr int64_ptr = std::make_shared(); + ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, to_prec, to_scale)); + + // from_prec(19) + to_scale(20) > Int128::real_prec(39) - 1, so CastInternalType should be **Int256**. 
+ ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(to_prec, to_scale), + {DecimalField256(1024 * static_cast(pow(10, to_scale)), to_scale), DecimalField256(-1024 * static_cast(pow(10, to_scale)), to_scale), {}}), + executeFunction(func_name, + {createColumn>({1024, -1024, {}}), + createCastTypeConstColumn("Nullable(Decimal(40,20))")})); + + // case4: cast(bigint as decimal(38, 20)) + // from_prec(19) + to_scale(20) > Decimal128::prec(38), so we **CANNOT** skip check overflow. + to_prec = 38; + to_scale = 20; + ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, to_prec, to_scale)); + + // from_prec(19) + to_scale(20) > Int128::real_prec(39) - 1, so CastInternalType should be **Int256**. + ASSERT_COLUMN_EQ( + createColumn>( + std::make_tuple(to_prec, to_scale), + {DecimalField128(1024 * static_cast(pow(10, to_scale)), to_scale), DecimalField128(-1024 * static_cast(pow(10, to_scale)), to_scale), {}}), + executeFunction(func_name, + {createColumn>({1024, -1024, {}}), + createCastTypeConstColumn("Nullable(Decimal(38,20))")})); +} +CATCH + // for https://github.com/pingcap/tics/issues/4036 TEST_F(TestTidbConversion, castStringAsDateTime) try diff --git a/tests/fullstack-test/expr/cast_decimal_overflow.test b/tests/fullstack-test/expr/cast_decimal_overflow.test new file mode 100644 index 00000000000..968bf0db22a --- /dev/null +++ b/tests/fullstack-test/expr/cast_decimal_overflow.test @@ -0,0 +1,195 @@ +mysql> drop table if exists test.t1; +mysql> create table test.t1(c1 decimal(6, 4)); +mysql> insert into test.t1 values(99.9999); +mysql> alter table test.t1 set tiflash replica 1; +func> wait_table test t1 +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(60, 3)) from test.t1; +cast(test.t1.c1 as decimal(60, 3)) +100.000 +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(5, 3)) from test.t1; +cast(test.t1.c1 as decimal(5,
3)) +99.999 +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(7, 5)) from test.t1; +cast(test.t1.c1 as decimal(7, 5)) +99.99990 +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(6, 5)) from test.t1; +cast(test.t1.c1 as decimal(6, 5)) +9.99999 + +mysql> drop table if exists test.t1; +mysql> create table test.t1(tiny_col tinyint, tiny_unsigned_col tinyint unsigned, small_col smallint, small_unsigned_col smallint unsigned, med_col mediumint, med_unsigned_col mediumint unsigned, int_col int, int_unsigned_col int unsigned, big_col bigint, big_unsigned_col bigint unsigned); +mysql> insert into test.t1 values(127, 255, 32767, 65535, 8388607, 16777215, 2147483647, 4294967295, 9223372036854775807, 18446744073709551615); +mysql> alter table test.t1 set tiflash replica 1; +func> wait_table test t1 + +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.tiny_col as decimal(3, 0)) from test.t1; ++-----------------------------------------+ +| cast(test.t1.tiny_col as decimal(3, 0)) | ++-----------------------------------------+ +| 127 | ++-----------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.tiny_col as decimal(3, 1)) from test.t1; ++-----------------------------------------+ +| cast(test.t1.tiny_col as decimal(3, 1)) | ++-----------------------------------------+ +| 99.9 | ++-----------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.tiny_unsigned_col as decimal(3, 0)) from test.t1; ++--------------------------------------------------+ +| cast(test.t1.tiny_unsigned_col as decimal(3, 0)) | ++--------------------------------------------------+ +| 255 | ++--------------------------------------------------+ +mysql> set 
@@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.tiny_unsigned_col as decimal(3, 1)) from test.t1; ++--------------------------------------------------+ +| cast(test.t1.tiny_unsigned_col as decimal(3, 1)) | ++--------------------------------------------------+ +| 99.9 | ++--------------------------------------------------+ + +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.small_col as decimal(5, 0)) from test.t1; ++------------------------------------------+ +| cast(test.t1.small_col as decimal(5, 0)) | ++------------------------------------------+ +| 32767 | ++------------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.small_col as decimal(5, 1)) from test.t1; ++------------------------------------------+ +| cast(test.t1.small_col as decimal(5, 1)) | ++------------------------------------------+ +| 9999.9 | ++------------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.small_unsigned_col as decimal(5, 0)) from test.t1; ++---------------------------------------------------+ +| cast(test.t1.small_unsigned_col as decimal(5, 0)) | ++---------------------------------------------------+ +| 65535 | ++---------------------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.small_unsigned_col as decimal(5, 1)) from test.t1; ++---------------------------------------------------+ +| cast(test.t1.small_unsigned_col as decimal(5, 1)) | ++---------------------------------------------------+ +| 9999.9 | ++---------------------------------------------------+ + +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.med_col as decimal(10, 0)) from test.t1; ++-----------------------------------------+ +| 
cast(test.t1.med_col as decimal(10, 0)) | ++-----------------------------------------+ +| 8388607 | ++-----------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.med_col as decimal(10, 1)) from test.t1; ++-----------------------------------------+ +| cast(test.t1.med_col as decimal(10, 1)) | ++-----------------------------------------+ +| 8388607.0 | ++-----------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.med_unsigned_col as decimal(10, 0)) from test.t1; ++--------------------------------------------------+ +| cast(test.t1.med_unsigned_col as decimal(10, 0)) | ++--------------------------------------------------+ +| 16777215 | ++--------------------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.med_unsigned_col as decimal(10, 1)) from test.t1; ++--------------------------------------------------+ +| cast(test.t1.med_unsigned_col as decimal(10, 1)) | ++--------------------------------------------------+ +| 16777215.0 | ++--------------------------------------------------+ + +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.big_col as decimal(20, 0)) from test.t1; ++-----------------------------------------+ +| cast(test.t1.big_col as decimal(20, 0)) | ++-----------------------------------------+ +| 9223372036854775807 | ++-----------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.big_col as decimal(20, 1)) from test.t1; ++-----------------------------------------+ +| cast(test.t1.big_col as decimal(20, 1)) | ++-----------------------------------------+ +| 9223372036854775807.0 | ++-----------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set 
tidb_enforce_mpp = 1; select cast(test.t1.big_unsigned_col as decimal(20, 0)) from test.t1; ++--------------------------------------------------+ +| cast(test.t1.big_unsigned_col as decimal(20, 0)) | ++--------------------------------------------------+ +| 18446744073709551615 | ++--------------------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set tidb_enforce_mpp = 1; select cast(test.t1.big_unsigned_col as decimal(20, 1)) from test.t1; ++--------------------------------------------------+ +| cast(test.t1.big_unsigned_col as decimal(20, 1)) | ++--------------------------------------------------+ +| 9999999999999999999.9 | ++--------------------------------------------------+ + +mysql> drop table if exists test.t1; +mysql> create table test.t1(c1 decimal(50, 0)); +mysql> insert into test.t1 values(12345678901234567890123456789012345678901234567890); +mysql> alter table test.t1 set tiflash replica 1; +func> wait_table test t1 +mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; select cast(c1 as decimal(35, 30)) from test.t1; ++--------------------------------------+ +| cast(c1 as decimal(35, 30)) | ++--------------------------------------+ +| 99999.999999999999999999999999999999 | ++--------------------------------------+ + +mysql> drop table if exists test.t1; +mysql> create table test.t1(c1 datetime(5)); +mysql> insert into test.t1 values('2022-10-10 10:10:10.12345'); +mysql> alter table test.t1 set tiflash replica 1; +func> wait_table test t1 +mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(16, 3)) from test.t1; ++------------------------------------+ +| cast(test.t1.c1 as decimal(16, 3)) | ++------------------------------------+ +| 9999999999999.999 | ++------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(17, 3)) from test.t1; 
++------------------------------------+ +| cast(test.t1.c1 as decimal(17, 3)) | ++------------------------------------+ +| 20221010101010.123 | ++------------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(18, 3)) from test.t1; ++------------------------------------+ +| cast(test.t1.c1 as decimal(18, 3)) | ++------------------------------------+ +| 20221010101010.123 | ++------------------------------------+ + +mysql> drop table if exists test.t1; +mysql> create table test.t1(c1 date); +mysql> insert into test.t1 values('2020-10-10'); +mysql> alter table test.t1 set tiflash replica 1; +func> wait_table test t1 +mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(7, 0)) from test.t1; ++-----------------------------------+ +| cast(test.t1.c1 as decimal(7, 0)) | ++-----------------------------------+ +| 9999999 | ++-----------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(8, 0)) from test.t1; ++-----------------------------------+ +| cast(test.t1.c1 as decimal(8, 0)) | ++-----------------------------------+ +| 20201010 | ++-----------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(9, 0)) from test.t1; ++-----------------------------------+ +| cast(test.t1.c1 as decimal(9, 0)) | ++-----------------------------------+ +| 20201010 | ++-----------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(9, 1)) from test.t1; ++-----------------------------------+ +| cast(test.t1.c1 as decimal(9, 1)) | ++-----------------------------------+ +| 20201010.0 | ++-----------------------------------+ +mysql> set @@tidb_isolation_read_engines='tiflash'; set 
@@tidb_enforce_mpp = 1; select cast(test.t1.c1 as decimal(9, 2)) from test.t1; ++-----------------------------------+ +| cast(test.t1.c1 as decimal(9, 2)) | ++-----------------------------------+ +| 9999999.99 | ++-----------------------------------+