From 1467ca5610ade5fd9d4acc5a2a4efc3fb324c7e4 Mon Sep 17 00:00:00 2001 From: chertus Date: Thu, 23 Aug 2018 22:11:31 +0300 Subject: [PATCH] decimal field visitors CLICKHOUSE-3765 --- dbms/src/Common/FieldVisitors.cpp | 36 +++++ dbms/src/Common/FieldVisitors.h | 174 ++++++++++++++++------- dbms/src/Core/Field.cpp | 15 ++ dbms/src/Core/Field.h | 26 +++- dbms/src/DataTypes/DataTypesDecimal.cpp | 29 +--- dbms/src/DataTypes/DataTypesDecimal.h | 2 - dbms/src/DataTypes/FieldToDataType.cpp | 19 +++ dbms/src/DataTypes/FieldToDataType.h | 3 + dbms/src/Functions/FunctionsConversion.h | 2 +- dbms/src/IO/WriteHelpers.h | 31 ++++ 10 files changed, 253 insertions(+), 84 deletions(-) diff --git a/dbms/src/Common/FieldVisitors.cpp b/dbms/src/Common/FieldVisitors.cpp index 62b7667d936a..ea82f751fb19 100644 --- a/dbms/src/Common/FieldVisitors.cpp +++ b/dbms/src/Common/FieldVisitors.cpp @@ -30,11 +30,22 @@ static inline String formatQuotedWithPrefix(T x, const char * prefix) return wb.str(); } +template +static inline void writeQuoted(const DecimalField & x, WriteBuffer & buf) +{ + writeChar('\'', buf); + writeText(x.getValue(), x.getScale(), buf); + writeChar('\'', buf); +} + String FieldVisitorDump::operator() (const Null &) const { return "NULL"; } String FieldVisitorDump::operator() (const UInt64 & x) const { return formatQuotedWithPrefix(x, "UInt64_"); } String FieldVisitorDump::operator() (const Int64 & x) const { return formatQuotedWithPrefix(x, "Int64_"); } String FieldVisitorDump::operator() (const Float64 & x) const { return formatQuotedWithPrefix(x, "Float64_"); } +String FieldVisitorDump::operator() (const DecimalField & x) const { return formatQuotedWithPrefix(x, "Decimal32_"); } +String FieldVisitorDump::operator() (const DecimalField & x) const { return formatQuotedWithPrefix(x, "Decimal64_"); } +String FieldVisitorDump::operator() (const DecimalField & x) const { return formatQuotedWithPrefix(x, "Decimal128_"); } String FieldVisitorDump::operator() (const UInt128 & x) const { @@ -112,6 +123,9 @@ String FieldVisitorToString::operator() (const UInt64 & x) const { return format String FieldVisitorToString::operator() (const Int64 & x) const { return formatQuoted(x); } String FieldVisitorToString::operator() (const Float64 & x) const { return formatFloat(x); } String FieldVisitorToString::operator() (const String & x) const { return formatQuoted(x); } +String FieldVisitorToString::operator() (const DecimalField & x) const { return formatQuoted(x); } +String FieldVisitorToString::operator() (const DecimalField & x) const { return formatQuoted(x); } +String FieldVisitorToString::operator() (const DecimalField & x) const { return formatQuoted(x); } String FieldVisitorToString::operator() (const UInt128 & x) const { @@ -207,4 +221,26 @@ void FieldVisitorHash::operator() (const Array & x) const applyVisitor(*this, elem); } +void FieldVisitorHash::operator() (const DecimalField & x) const +{ + UInt8 type = Field::Types::Decimal32; + hash.update(type); + hash.update(x); +} + +void FieldVisitorHash::operator() (const DecimalField & x) const +{ + UInt8 type = Field::Types::Decimal64; + hash.update(type); + hash.update(x); +} + +void FieldVisitorHash::operator() (const DecimalField & x) const +{ + UInt8 type = Field::Types::Decimal128; + hash.update(type); + hash.update(x); +} + + } diff --git a/dbms/src/Common/FieldVisitors.h b/dbms/src/Common/FieldVisitors.h index 8abf75dbc64b..fd79c2bcd84d 100644 --- a/dbms/src/Common/FieldVisitors.h +++ b/dbms/src/Common/FieldVisitors.h @@ -44,6 +44,9 @@ typename std::decay_t::ResultType applyVisitor(Visitor && visitor, F && case Field::Types::String: return visitor(field.template get()); case Field::Types::Array: return visitor(field.template get()); case Field::Types::Tuple: return visitor(field.template get()); + case Field::Types::Decimal32: return visitor(field.template get>()); + case Field::Types::Decimal64: return visitor(field.template get>()); + case Field::Types::Decimal128: return visitor(field.template get>()); default: throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD); @@ -64,6 +67,9 @@ static typename std::decay_t::ResultType applyBinaryVisitorImpl(Visitor case Field::Types::String: return visitor(field1, field2.template get()); case Field::Types::Array: return visitor(field1, field2.template get()); case Field::Types::Tuple: return visitor(field1, field2.template get()); + case Field::Types::Decimal32: return visitor(field1, field2.template get>()); + case Field::Types::Decimal64: return visitor(field1, field2.template get>()); + case Field::Types::Decimal128: return visitor(field1, field2.template get>()); default: throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD); @@ -99,6 +105,15 @@ typename std::decay_t::ResultType applyVisitor(Visitor && visitor, F1 & case Field::Types::Tuple: return applyBinaryVisitorImpl( std::forward(visitor), field1.template get(), std::forward(field2)); + case Field::Types::Decimal32: + return applyBinaryVisitorImpl( + std::forward(visitor), field1.template get>(), std::forward(field2)); + case Field::Types::Decimal64: + return applyBinaryVisitorImpl( + std::forward(visitor), field1.template get>(), std::forward(field2)); + case Field::Types::Decimal128: + return applyBinaryVisitorImpl( + std::forward(visitor), field1.template get>(), std::forward(field2)); default: throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD); @@ -118,6 +133,9 @@ class FieldVisitorToString : public StaticVisitor String operator() (const String & x) const; String operator() (const Array & x) const; String operator() (const Tuple & x) const; + String operator() (const DecimalField & x) const; + String operator() (const DecimalField & x) const; + String operator() (const DecimalField & x) const; }; @@ -133,6 +151,9 @@ class FieldVisitorDump : public StaticVisitor String operator() (const String & x) const; String operator() (const Array & x) const; String operator() (const Tuple & x) const; + String operator() (const DecimalField & x) const; + String operator() (const DecimalField & x) const; + String operator() (const DecimalField & x) const; }; @@ -169,6 +190,15 @@ class FieldVisitorConvertToNumber : public StaticVisitor { throw Exception("Cannot convert UInt128 to " + demangle(typeid(T).name()), ErrorCodes::CANNOT_CONVERT_TYPE); } + + template + T operator() (const DecimalField & x) const + { + if constexpr (std::is_floating_point_v) + return static_cast(x.getValue()) / x.getScaleMultiplier(); + else + return x.getValue() / x.getScaleMultiplier(); + } }; @@ -187,9 +217,18 @@ class FieldVisitorHash : public StaticVisitor<> void operator() (const Float64 & x) const; void operator() (const String & x) const; void operator() (const Array & x) const; + void operator() (const DecimalField & x) const; + void operator() (const DecimalField & x) const; + void operator() (const DecimalField & x) const; }; +template constexpr bool isDecimalField() { return false; } +template <> constexpr bool isDecimalField>() { return true; } +template <> constexpr bool isDecimalField>() { return true; } +template <> constexpr bool isDecimalField>() { return true; } + + /** More precise comparison, used for index. * Differs from Field::operator< and Field::operator== in that it also compares values of different types. * Comparison rules are same as in FunctionsComparison (to be consistent with expression evaluation in query). @@ -199,15 +238,6 @@ class FieldVisitorHash : public StaticVisitor<> class FieldVisitorAccurateEquals : public StaticVisitor { public: - bool operator() (const Null &, const Null &) const { return true; } - bool operator() (const Null &, const UInt64 &) const { return false; } - bool operator() (const Null &, const UInt128 &) const { return false; } - bool operator() (const Null &, const Int64 &) const { return false; } - bool operator() (const Null &, const Float64 &) const { return false; } - bool operator() (const Null &, const String &) const { return false; } - bool operator() (const Null &, const Array &) const { return false; } - bool operator() (const Null &, const Tuple &) const { return false; } - bool operator() (const UInt64 &, const Null &) const { return false; } bool operator() (const UInt64 & l, const UInt64 & r) const { return l == r; } bool operator() (const UInt64 &, const UInt128) const { return true; } @@ -253,37 +283,49 @@ class FieldVisitorAccurateEquals : public StaticVisitor bool operator() (const String &, const Array &) const { return false; } bool operator() (const String &, const Tuple &) const { return false; } - bool operator() (const Array &, const Null &) const { return false; } - bool operator() (const Array &, const UInt64 &) const { return false; } - bool operator() (const Array &, const UInt128 &) const { return false; } - bool operator() (const Array &, const Int64 &) const { return false; } - bool operator() (const Array &, const Float64 &) const { return false; } - bool operator() (const Array &, const String &) const { return false; } - bool operator() (const Array & l, const Array & r) const { return l == r; } - bool operator() (const Array &, const Tuple &) const { return false; } - - bool operator() (const Tuple &, const Null &) const { return false; } - bool operator() (const Tuple &, const UInt64 &) const { return false; } - bool operator() (const Tuple &, const UInt128 &) const { return false; } - bool operator() (const Tuple &, const Int64 &) const { return false; } - bool operator() (const Tuple &, const Float64 &) const { return false; } - bool operator() (const Tuple &, const String &) const { return false; } - bool operator() (const Tuple &, const Array &) const { return false; } - bool operator() (const Tuple & l, const Tuple & r) const { return l == r; } + template + bool operator() (const Null &, const T &) const + { + return std::is_same_v; + } + + template + bool operator() (const Array & l, const T & r) const + { + if constexpr (std::is_same_v) + return l == r; + return false; + } + + template + bool operator() (const Tuple & l, const T & r) const + { + if constexpr (std::is_same_v) + return l == r; + return false; + } + + template + bool operator() (const DecimalField & l, const U & r) const + { + if constexpr (isDecimalField()) + return l == r; + else if constexpr (std::is_same_v || std::is_same_v) + return l == DecimalField(r, 0); + return false; + } + + template bool operator() (const UInt64 & l, const DecimalField & r) const { return DecimalField(l, 0) == r; } + template bool operator() (const UInt128 &, const DecimalField &) const { return false; } + template bool operator() (const Int64 & l, const DecimalField & r) const { return DecimalField(l, 0) == r; } + template bool operator() (const Float64 &, const DecimalField &) const { return false; } + template bool operator() (const String &, const DecimalField &) const { return false; } }; + class FieldVisitorAccurateLess : public StaticVisitor { public: - bool operator() (const Null &, const Null &) const { return false; } - bool operator() (const Null &, const UInt64 &) const { return true; } - bool operator() (const Null &, const Int64 &) const { return true; } - bool operator() (const Null &, const UInt128 &) const { return true; } - bool operator() (const Null &, const Float64 &) const { return true; } - bool operator() (const Null &, const String &) const { return true; } - bool operator() (const Null &, const Array &) const { return true; } - bool operator() (const Null &, const Tuple &) const { return true; } - bool operator() (const UInt64 &, const Null &) const { return false; } bool operator() (const UInt64 & l, const UInt64 & r) const { return l < r; } bool operator() (const UInt64 &, const UInt128 &) const { return true; } @@ -329,25 +371,46 @@ class FieldVisitorAccurateLess : public StaticVisitor bool operator() (const String &, const Array &) const { return true; } bool operator() (const String &, const Tuple &) const { return true; } - bool operator() (const Array &, const Null &) const { return false; } - bool operator() (const Array &, const UInt64 &) const { return false; } - bool operator() (const Array &, const UInt128 &) const { return false; } - bool operator() (const Array &, const Int64 &) const { return false; } - bool operator() (const Array &, const Float64 &) const { return false; } - bool operator() (const Array &, const String &) const { return false; } - bool operator() (const Array & l, const Array & r) const { return l < r; } - bool operator() (const Array &, const Tuple &) const { return false; } - - bool operator() (const Tuple &, const Null &) const { return false; } - bool operator() (const Tuple &, const UInt64 &) const { return false; } - bool operator() (const Tuple &, const UInt128 &) const { return false; } - bool operator() (const Tuple &, const Int64 &) const { return false; } - bool operator() (const Tuple &, const Float64 &) const { return false; } - bool operator() (const Tuple &, const String &) const { return false; } - bool operator() (const Tuple &, const Array &) const { return false; } - bool operator() (const Tuple & l, const Tuple & r) const { return l < r; } + template + bool operator() (const Null &, const T &) const + { + return !std::is_same_v; + } + + template + bool operator() (const Array & l, const T & r) const + { + if constexpr (std::is_same_v) + return l < r; + return false; + } + + template + bool operator() (const Tuple & l, const T & r) const + { + if constexpr (std::is_same_v) + return l < r; + return false; + } + + template + bool operator() (const DecimalField & l, const U & r) const + { + if constexpr (isDecimalField()) + return l < r; + else if constexpr (std::is_same_v || std::is_same_v) + return l < DecimalField(r, 0); + return false; + } + + template bool operator() (const UInt64 & l, const DecimalField & r) const { return DecimalField(l, 0) < r; } + template bool operator() (const UInt128 &, const DecimalField &) const { return false; } + template bool operator() (const Int64 & l, const DecimalField & r) const { return DecimalField(l, 0) < r; } + template bool operator() (const Float64 &, const DecimalField &) const { return false; } + template bool operator() (const String &, const DecimalField &) const { return false; } }; + /** Implements `+=` operation. * Returns false if the result is zero. */ @@ -366,6 +429,13 @@ class FieldVisitorSum : public StaticVisitor bool operator() (String &) const { throw Exception("Cannot sum Strings", ErrorCodes::LOGICAL_ERROR); } bool operator() (Array &) const { throw Exception("Cannot sum Arrays", ErrorCodes::LOGICAL_ERROR); } bool operator() (UInt128 &) const { throw Exception("Cannot sum UUIDs", ErrorCodes::LOGICAL_ERROR); } + + template + bool operator() (DecimalField & x) const + { + x += get>(rhs); + return x.getValue() != 0; + } }; } diff --git a/dbms/src/Core/Field.cpp b/dbms/src/Core/Field.cpp index eb6278adc956..9d20a89eb5c9 100644 --- a/dbms/src/Core/Field.cpp +++ b/dbms/src/Core/Field.cpp @@ -273,6 +273,21 @@ namespace DB } + template <> Decimal32 DecimalField::getScaleMultiplier() const + { + return DataTypeDecimal::getScaleMultiplier(scale); + } + + template <> Decimal64 DecimalField::getScaleMultiplier() const + { + return DataTypeDecimal::getScaleMultiplier(scale); + } + + template <> Decimal128 DecimalField::getScaleMultiplier() const + { + return DataTypeDecimal::getScaleMultiplier(scale); + } + template static bool decEqual(T x, T y, UInt32 x_scale, UInt32 y_scale) { diff --git a/dbms/src/Core/Field.h b/dbms/src/Core/Field.h index 31062e7d66bb..bbe79e6285c2 100644 --- a/dbms/src/Core/Field.h +++ b/dbms/src/Core/Field.h @@ -20,6 +20,7 @@ namespace ErrorCodes extern const int BAD_TYPE_OF_FIELD; extern const int BAD_GET; extern const int NOT_IMPLEMENTED; + extern const int LOGICAL_ERROR; } class Field; @@ -43,34 +44,51 @@ class DecimalField {} operator T() const { return dec; } - + T getValue() const { return dec; } + T getScaleMultiplier() const; UInt32 getScale() const { return scale; } template bool operator < (const DecimalField & r) const { using MaxType = std::conditional_t<(sizeof(T) > sizeof(U)), T, U>; - return decimalLess(dec, r, scale, r.getScale()); + return decimalLess(dec, r.getValue(), scale, r.getScale()); } template bool operator <= (const DecimalField & r) const { using MaxType = std::conditional_t<(sizeof(T) > sizeof(U)), T, U>; - return decimalLessOrEqual(dec, r, scale, r.getScale()); + return decimalLessOrEqual(dec, r.getValue(), scale, r.getScale()); } template bool operator == (const DecimalField & r) const { using MaxType = std::conditional_t<(sizeof(T) > sizeof(U)), T, U>; - return decimalEqual(dec, r, scale, r.getScale()); + return decimalEqual(dec, r.getValue(), scale, r.getScale()); } template bool operator > (const DecimalField & r) const { return r < *this; } template bool operator >= (const DecimalField & r) const { return r <= * this; } template bool operator != (const DecimalField & r) const { return !(*this == r); } + const DecimalField & operator += (const DecimalField & r) + { + if (scale != r.getScale()) + throw Exception("Add different decimal fields", ErrorCodes::LOGICAL_ERROR); + dec += r.getValue(); + return *this; + } + + const DecimalField & operator -= (const DecimalField & r) + { + if (scale != r.getScale()) + throw Exception("Sub different decimal fields", ErrorCodes::LOGICAL_ERROR); + dec -= r.getValue(); + return *this; + } + private: T dec; UInt32 scale; diff --git a/dbms/src/DataTypes/DataTypesDecimal.cpp b/dbms/src/DataTypes/DataTypesDecimal.cpp index fb64e7446426..55282e4e9325 100644 --- a/dbms/src/DataTypes/DataTypesDecimal.cpp +++ b/dbms/src/DataTypes/DataTypesDecimal.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -43,31 +42,11 @@ bool DataTypeDecimal::equals(const IDataType & rhs) const return false; } -template -void DataTypeDecimal::writeText(T value, WriteBuffer & ostr) const -{ - if (value < T(0)) - { - value *= T(-1); - writeChar('-', ostr); /// avoid crop leading minus when whole part is zero - } - - writeIntText(static_cast(wholePart(value)), ostr); - if (scale) - { - writeChar('.', ostr); - String str_fractional(scale, '0'); - for (Int32 pos = scale - 1; pos >= 0; --pos, value /= T(10)) - str_fractional[pos] += value % T(10); - ostr.write(str_fractional.data(), scale); - } -} - template void DataTypeDecimal::serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const { T value = static_cast(column).getData()[row_num]; - writeText(value, ostr); + writeText(value, scale, ostr); } template @@ -238,19 +217,19 @@ void registerDataTypeDecimal(DataTypeFactory & factory) template <> Decimal32 DataTypeDecimal::getScaleMultiplier(UInt32 scale_) { - return common::exp10_i32(scale_); + return decimalScaleMultiplier(scale_); } template <> Decimal64 DataTypeDecimal::getScaleMultiplier(UInt32 scale_) { - return common::exp10_i64(scale_); + return decimalScaleMultiplier(scale_); } template <> Decimal128 DataTypeDecimal::getScaleMultiplier(UInt32 scale_) { - return common::exp10_i128(scale_); + return decimalScaleMultiplier(scale_); } diff --git a/dbms/src/DataTypes/DataTypesDecimal.h b/dbms/src/DataTypes/DataTypesDecimal.h index c74eead894ef..1ff83dd2dc0f 100644 --- a/dbms/src/DataTypes/DataTypesDecimal.h +++ b/dbms/src/DataTypes/DataTypesDecimal.h @@ -203,8 +203,6 @@ class DataTypeDecimal final : public DataTypeSimpleSerialization T parseFromString(const String & str) const; void readText(T & x, ReadBuffer & istr) const { readText(x, istr, precision, scale); } - void writeText(T value, WriteBuffer & ostr) const; - static void readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale); static T getScaleMultiplier(UInt32 scale); diff --git a/dbms/src/DataTypes/FieldToDataType.cpp b/dbms/src/DataTypes/FieldToDataType.cpp index 3c2e78b42951..77980c264181 100644 --- a/dbms/src/DataTypes/FieldToDataType.cpp +++ b/dbms/src/DataTypes/FieldToDataType.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -58,6 +59,24 @@ DataTypePtr FieldToDataType::operator() (const String &) const return std::make_shared(); } +DataTypePtr FieldToDataType::operator() (const DecimalField & x) const +{ + using Type = DataTypeDecimal; + return std::make_shared(Type::maxPrecision(), x.getScale()); +} + +DataTypePtr FieldToDataType::operator() (const DecimalField & x) const +{ + using Type = DataTypeDecimal; + return std::make_shared(Type::maxPrecision(), x.getScale()); +} + +DataTypePtr FieldToDataType::operator() (const DecimalField & x) const +{ + using Type = DataTypeDecimal; + return std::make_shared(Type::maxPrecision(), x.getScale()); +} + DataTypePtr FieldToDataType::operator() (const Array & x) const { diff --git a/dbms/src/DataTypes/FieldToDataType.h b/dbms/src/DataTypes/FieldToDataType.h index a60c6a725d86..dc103e246413 100644 --- a/dbms/src/DataTypes/FieldToDataType.h +++ b/dbms/src/DataTypes/FieldToDataType.h @@ -25,6 +25,9 @@ class FieldToDataType : public StaticVisitor DataTypePtr operator() (const String & x) const; DataTypePtr operator() (const Array & x) const; DataTypePtr operator() (const Tuple & x) const; + DataTypePtr operator() (const DecimalField & x) const; + DataTypePtr operator() (const DecimalField & x) const; + DataTypePtr operator() (const DecimalField & x) const; }; } diff --git a/dbms/src/Functions/FunctionsConversion.h b/dbms/src/Functions/FunctionsConversion.h index cda5808f3314..8c7575c9aa4a 100644 --- a/dbms/src/Functions/FunctionsConversion.h +++ b/dbms/src/Functions/FunctionsConversion.h @@ -236,7 +236,7 @@ struct FormatImpl> { static void execute(const FieldType x, WriteBuffer & wb, const DataTypeDecimal * type, const DateLUTImpl *) { - type->writeText(x, wb); + writeText(x, type->getScale(), wb); } }; diff --git a/dbms/src/IO/WriteHelpers.h b/dbms/src/IO/WriteHelpers.h index ff00de6d6ca6..4076d6e2cade 100644 --- a/dbms/src/IO/WriteHelpers.h +++ b/dbms/src/IO/WriteHelpers.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -711,6 +712,36 @@ inline void writeText(const UInt128 &, WriteBuffer &) throw Exception("UInt128 cannot be write as a text", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } +template inline T decimalScaleMultiplier(UInt32 scale); +template <> inline Int32 decimalScaleMultiplier(UInt32 scale) { return common::exp10_i32(scale); } +template <> inline Int64 decimalScaleMultiplier(UInt32 scale) { return common::exp10_i64(scale); } +template <> inline Int128 decimalScaleMultiplier(UInt32 scale) { return common::exp10_i128(scale); } + + +template +void writeText(Decimal value, UInt32 scale, WriteBuffer & ostr) +{ + if (value < Decimal(0)) + { + value *= Decimal(-1); + writeChar('-', ostr); /// avoid crop leading minus when whole part is zero + } + + T whole_part = value; + if (scale) + whole_part = value / decimalScaleMultiplier(scale); + + writeIntText(whole_part, ostr); + if (scale) + { + writeChar('.', ostr); + String str_fractional(scale, '0'); + for (Int32 pos = scale - 1; pos >= 0; --pos, value /= Decimal(10)) + str_fractional[pos] += value % Decimal(10); + ostr.write(str_fractional.data(), scale); + } +} + /// String, date, datetime are in single quotes with C-style escaping. Numbers - without. template inline std::enable_if_t, void>