From e2bd4856a8e50e0ea64365ee2d3d0c206ad0ccf0 Mon Sep 17 00:00:00 2001 From: Zhigao Tong Date: Sat, 29 Oct 2022 10:55:59 +0800 Subject: [PATCH] Optimize Aggregator/Join/Set keys (#6135) ref pingcap/tiflash#5294 Signed-off-by: CalvinNeo --- dbms/src/Columns/ColumnString.h | 59 +- dbms/src/Columns/ColumnVector.cpp | 7 - dbms/src/Columns/ColumnVector.h | 6 +- dbms/src/Common/ColumnsHashing.h | 154 +++- dbms/src/Common/HashTable/StringHashTable.h | 8 +- .../HashTable/TwoLevelStringHashTable.h | 8 +- .../Coprocessor/DAGExpressionAnalyzer.cpp | 8 +- .../tests/gtest_aggregation_executor.cpp | 125 +++- dbms/src/Interpreters/Aggregator.cpp | 673 ++++++++++++++---- dbms/src/Interpreters/Aggregator.h | 450 ++++++++---- dbms/src/Interpreters/Join.cpp | 50 +- dbms/src/Interpreters/Join.h | 7 +- dbms/src/Interpreters/Set.cpp | 16 +- dbms/src/Interpreters/Set.h | 9 +- dbms/src/Interpreters/SetVariants.cpp | 41 +- dbms/src/Interpreters/SetVariants.h | 24 +- dbms/src/Interpreters/sortBlock.cpp | 8 +- 17 files changed, 1283 insertions(+), 370 deletions(-) diff --git a/dbms/src/Columns/ColumnString.h b/dbms/src/Columns/ColumnString.h index ffea54625e9..c8820d9c6ed 100644 --- a/dbms/src/Columns/ColumnString.h +++ b/dbms/src/Columns/ColumnString.h @@ -19,7 +19,7 @@ #include #include #include -#include +#include class ICollator; @@ -119,13 +119,7 @@ class ColumnString final : public COWPtrHelper void insert(const Field & x) override { const auto & s = DB::get(x); - const size_t old_size = chars.size(); - const size_t size_to_append = s.size() + 1; - const size_t new_size = old_size + size_to_append; - - chars.resize(new_size); - memcpy(&chars[old_size], s.c_str(), size_to_append); - offsets.push_back(new_size); + insertData(s.data(), s.size()); } #if !__clang__ @@ -169,17 +163,25 @@ class ColumnString final : public COWPtrHelper } } - void insertData(const char * pos, size_t length) override + template + ALWAYS_INLINE inline void insertDataImpl(const char * pos, size_t length) { const size_t old_size = chars.size(); - const size_t new_size = old_size + length + 1; + const size_t new_size = old_size + length + (add_terminating_zero ? 1 : 0); chars.resize(new_size); - memcpy(&chars[old_size], pos, length); - chars[old_size + length] = 0; + inline_memcpy(&chars[old_size], pos, length); + + if constexpr (add_terminating_zero) + chars[old_size + length] = 0; offsets.push_back(new_size); } + void insertData(const char * pos, size_t length) override + { + return insertDataImpl(pos, length); + } + bool decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool /* force_decode */) override { insertData(raw_value.c_str() + cursor, length); @@ -188,12 +190,7 @@ class ColumnString final : public COWPtrHelper void insertDataWithTerminatingZero(const char * pos, size_t length) override { - const size_t old_size = chars.size(); - const size_t new_size = old_size + length; - - chars.resize(new_size); - memcpy(&chars[old_size], pos, length); - offsets.push_back(new_size); + return insertDataImpl(pos, length); } void popBack(size_t n) override @@ -220,24 +217,30 @@ class ColumnString final : public COWPtrHelper } res.size = sizeof(string_size) + string_size; char * pos = arena.allocContinue(res.size, begin); - memcpy(pos, &string_size, sizeof(string_size)); - memcpy(pos + sizeof(string_size), src, string_size); + std::memcpy(pos, &string_size, sizeof(string_size)); + inline_memcpy(pos + sizeof(string_size), src, string_size); res.data = pos; return res; } - const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override + inline const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr & collator) override { const size_t string_size = *reinterpret_cast(pos); pos += sizeof(string_size); - const size_t old_size = chars.size(); - const size_t new_size = old_size + string_size; - chars.resize(new_size); - memcpy(&chars[old_size], pos, string_size); - - offsets.push_back(new_size); - return pos + string_size; + if (likely(collator)) + { + // https://github.com/pingcap/tiflash/pull/6135 + // - Generate empty string column + // - Make size of `offsets` as previous way for func `ColumnString::size()` + offsets.push_back(0); + return pos + string_size; + } + else + { + insertDataWithTerminatingZero(pos, string_size); + return pos + string_size; + } } void updateHashWithValue(size_t n, SipHash & hash, const TiDB::TiDBCollatorPtr & collator, String & sort_key_container) const override diff --git a/dbms/src/Columns/ColumnVector.cpp b/dbms/src/Columns/ColumnVector.cpp index f65a74a61d7..3ea8af02309 100644 --- a/dbms/src/Columns/ColumnVector.cpp +++ b/dbms/src/Columns/ColumnVector.cpp @@ -47,13 +47,6 @@ StringRef ColumnVector::serializeValueIntoArena(size_t n, Arena & arena, char return StringRef(pos, sizeof(T)); } -template -const char * ColumnVector::deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) -{ - data.push_back(*reinterpret_cast(pos)); - return pos + sizeof(T); -} - template void ColumnVector::updateHashWithValue(size_t n, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const { diff --git a/dbms/src/Columns/ColumnVector.h b/dbms/src/Columns/ColumnVector.h index dac6dbce1f1..d5864f49cb2 100644 --- a/dbms/src/Columns/ColumnVector.h +++ b/dbms/src/Columns/ColumnVector.h @@ -292,7 +292,11 @@ class ColumnVector final : public COWPtrHelper(pos)); + return pos + sizeof(T); + } void updateHashWithValue(size_t n, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const override; void updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr &, String &) const override; diff --git a/dbms/src/Common/ColumnsHashing.h b/dbms/src/Common/ColumnsHashing.h index dbf50175007..cd7ec2dd328 100644 --- a/dbms/src/Common/ColumnsHashing.h +++ b/dbms/src/Common/ColumnsHashing.h @@ -28,7 +28,6 @@ #include #include -#include namespace DB { @@ -114,12 +113,13 @@ struct HashMethodString ALWAYS_INLINE inline auto getKeyHolder(ssize_t row, [[maybe_unused]] Arena * pool, std::vector & sort_key_containers) const { auto last_offset = row == 0 ? 0 : offsets[row - 1]; + // Remove last zero byte. StringRef key(chars + last_offset, offsets[row] - last_offset - 1); if constexpr (place_string_to_arena) { if (likely(collator)) - key = collator->sortKeyFastPath(key.data, key.size, sort_key_containers[0]); + key = collator->sortKey(key.data, key.size, sort_key_containers[0]); return ArenaKeyHolder{key, *pool}; } else @@ -132,6 +132,37 @@ struct HashMethodString friend class columns_hashing_impl::HashMethodBase; }; +template +struct HashMethodStringBin + : public columns_hashing_impl::HashMethodBase, Value, Mapped, false> +{ + using Self = HashMethodStringBin; + using Base = columns_hashing_impl::HashMethodBase; + + const IColumn::Offset * offsets; + const UInt8 * chars; + + HashMethodStringBin(const ColumnRawPtrs & key_columns, const Sizes & /*key_sizes*/, const TiDB::TiDBCollators &) + { + const IColumn & column = *key_columns[0]; + const auto & column_string = assert_cast(column); + offsets = column_string.getOffsets().data(); + chars = column_string.getChars().data(); + } + + ALWAYS_INLINE inline auto getKeyHolder(ssize_t row, Arena * pool, std::vector &) const + { + auto last_offset = row == 0 ? 0 : offsets[row - 1]; + StringRef key(chars + last_offset, offsets[row] - last_offset - 1); + key = BinCollatorSortKey(key.data, key.size); + return ArenaKeyHolder{key, *pool}; + } + +protected: + friend class columns_hashing_impl::HashMethodBase; +}; + +/* /// For the case when there is multi string key. template struct HashMethodMultiString @@ -172,8 +203,6 @@ struct HashMethodMultiString { auto num = offsets.size(); - static_assert(std::is_same_v(0)->size)>); - const char * begin = nullptr; size_t sum_size = 0; @@ -223,6 +252,123 @@ struct HashMethodMultiString protected: friend class columns_hashing_impl::HashMethodBase; }; +*/ + +static_assert(std::is_same_v(0)->size)>); + +struct KeyDescNumber64 +{ + using ColumnType = ColumnUInt64; + using AllocSize = size_t; + static constexpr size_t ElementSize = sizeof(ColumnType::value_type); + + explicit KeyDescNumber64(const IColumn * key_column_) + { + column = static_cast(key_column_); + } + static inline void serializeKey(char *& pos, const StringRef & ref) + { + std::memcpy(pos, ref.data, ElementSize); + pos += ElementSize; + } + ALWAYS_INLINE inline AllocSize getKey(ssize_t row, StringRef & ref) const + { + const auto & element = column->getElement(row); + ref = {reinterpret_cast(&element), ElementSize}; + return ElementSize; + } + const ColumnType * column{}; +}; + +struct KeyDescStringBin +{ + using ColumnType = ColumnString; + using AllocSize = size_t; + + explicit KeyDescStringBin(const IColumn * key_column_) + { + column = static_cast(key_column_); + } + static inline void serializeKey(char *& pos, const StringRef & ref) + { + std::memcpy(pos, &ref.size, sizeof(ref.size)); + pos += sizeof(ref.size); + inline_memcpy(pos, ref.data, ref.size); + pos += ref.size; + } + + template + ALWAYS_INLINE inline AllocSize getKeyImpl(ssize_t row, StringRef & key, F && fn_handle_key) const + { + const auto * offsets = column->getOffsets().data(); + const auto * chars = column->getChars().data(); + + size_t last_offset = 0; + if (likely(row != 0)) + last_offset = offsets[row - 1]; + + key = {chars + last_offset, offsets[row] - last_offset - 1}; + key = fn_handle_key(key); + + return key.size + sizeof(key.size); + } + + ALWAYS_INLINE inline AllocSize getKey(ssize_t row, StringRef & ref) const + { + return getKeyImpl(row, ref, [](StringRef key) { + return key; + }); + } + + const ColumnType * column{}; +}; + +struct KeyDescStringBinPadding : KeyDescStringBin +{ + explicit KeyDescStringBinPadding(const IColumn * key_column_) + : KeyDescStringBin(key_column_) + {} + + ALWAYS_INLINE inline AllocSize getKey(ssize_t row, StringRef & ref) const + { + return getKeyImpl(row, ref, [](StringRef key) { + return DB::BinCollatorSortKey(key.data, key.size); + }); + } +}; + +/// For the case when there are 2 keys. +template +struct HashMethodFastPathTwoKeysSerialized + : public columns_hashing_impl::HashMethodBase, Value, Mapped, false> +{ + using Self = HashMethodFastPathTwoKeysSerialized; + using Base = columns_hashing_impl::HashMethodBase; + + Key1Desc key_1_desc; + Key2Desc key_2_desc; + + HashMethodFastPathTwoKeysSerialized(const ColumnRawPtrs & key_columns, const Sizes &, const TiDB::TiDBCollators &) + : key_1_desc(key_columns[0]) + , key_2_desc(key_columns[1]) + { + } + + ALWAYS_INLINE inline auto getKeyHolder(ssize_t row, Arena * pool, std::vector &) const + { + StringRef key1; + StringRef key2; + size_t alloc_size = key_1_desc.getKey(row, key1) + key_2_desc.getKey(row, key2); + char * start = pool->alloc(alloc_size); + SerializedKeyHolder ret{{start, alloc_size}, *pool}; + Key1Desc::serializeKey(start, key1); + Key2Desc::serializeKey(start, key2); + return ret; + } + +protected: + friend class columns_hashing_impl::HashMethodBase; +}; /// For the case when there is one fixed-length string key. diff --git a/dbms/src/Common/HashTable/StringHashTable.h b/dbms/src/Common/HashTable/StringHashTable.h index 95fcc0ceaf5..998d90f2c91 100644 --- a/dbms/src/Common/HashTable/StringHashTable.h +++ b/dbms/src/Common/HashTable/StringHashTable.h @@ -254,7 +254,13 @@ class StringHashTable : private boost::noncopyable // 3. Funcs are named callables that can be force_inlined // NOTE: It relies on Little Endianness template - static auto ALWAYS_INLINE dispatch(Self & self, KeyHolder && key_holder, Func && func) + static auto +#if defined(ADDRESS_SANITIZER) + NO_INLINE NO_SANITIZE_ADDRESS +#else + ALWAYS_INLINE +#endif + dispatch(Self & self, KeyHolder && key_holder, Func && func) { StringHashTableHash hash; const StringRef & x = keyHolderGetKey(key_holder); diff --git a/dbms/src/Common/HashTable/TwoLevelStringHashTable.h b/dbms/src/Common/HashTable/TwoLevelStringHashTable.h index 0a135053086..d1560db2d72 100644 --- a/dbms/src/Common/HashTable/TwoLevelStringHashTable.h +++ b/dbms/src/Common/HashTable/TwoLevelStringHashTable.h @@ -88,7 +88,13 @@ class TwoLevelStringHashTable : private boost::noncopyable // This function is mostly the same as StringHashTable::dispatch, but with // added bucket computation. See the comments there. template - static auto ALWAYS_INLINE dispatch(Self & self, KeyHolder && key_holder, Func && func) + static auto +#if defined(ADDRESS_SANITIZER) + NO_INLINE NO_SANITIZE_ADDRESS +#else + ALWAYS_INLINE +#endif + dispatch(Self & self, KeyHolder && key_holder, Func && func) { StringHashTableHash hash; const StringRef & x = keyHolderGetKey(key_holder); diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 61897b229f3..ebc87049c22 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -1371,10 +1371,10 @@ void DAGExpressionAnalyzer::makeExplicitSet( set_element_types.push_back(sample_block.getByName(left_arg_name).type); // todo if this is a single value in, then convert it to equal expr - SetPtr set = std::make_shared(SizeLimits(settings.max_rows_in_set, settings.max_bytes_in_set, settings.set_overflow_mode)); - TiDB::TiDBCollators collators; - collators.push_back(getCollatorFromExpr(expr)); - set->setCollators(collators); + SetPtr set = std::make_shared( + SizeLimits(settings.max_rows_in_set, settings.max_bytes_in_set, settings.set_overflow_mode), + TiDB::TiDBCollators{getCollatorFromExpr(expr)}); + auto remaining_exprs = set->createFromDAGExpr(set_element_types, expr, create_ordered_set); prepared_sets[&expr] = std::make_shared(std::move(set), std::move(remaining_exprs)); } diff --git a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp index 3fdc4d059e2..4f033e6afe5 100644 --- a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp +++ b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp @@ -44,6 +44,8 @@ class ExecutorAggTestRunner : public ExecutorTest using ColMyDateTimeNullableType = std::optional::FieldType>; using ColDecimalNullableType = std::optional::FieldType>; using ColUInt64Type = typename TypeTraits::FieldType; + using ColFloat64Type = typename TypeTraits::FieldType; + using ColStringType = typename TypeTraits::FieldType; using ColumnWithNullableString = std::vector; using ColumnWithNullableInt8 = std::vector; @@ -56,8 +58,10 @@ class ExecutorAggTestRunner : public ExecutorTest using ColumnWithNullableMyDateTime = std::vector; using ColumnWithNullableDecimal = std::vector; using ColumnWithUInt64 = std::vector; + using ColumnWithFloat64 = std::vector; + using ColumnWithString = std::vector; - virtual ~ExecutorAggTestRunner() = default; + ~ExecutorAggTestRunner() override = default; void initializeContext() override { @@ -114,6 +118,22 @@ class ExecutorAggTestRunner : public ExecutorTest {{"s1", TiDB::TP::TypeLongLong}, {"s2", TiDB::TP::TypeLongLong}}, {toVec("s1", {1, 2, 3}), toVec("s2", {1, 2, 3})}); + + context.addMockTable({"test_db", "test_table_not_null"}, + { + {"c1_i64", TiDB::TP::TypeLongLong}, + {"c2_f64", TiDB::TP::TypeDouble}, + {"c3_str", TiDB::TP::TypeString}, + {"c4_str", TiDB::TP::TypeString}, + {"c5_date_time", TiDB::TP::TypeDatetime}, + }, + { + toVec("c1_i64", {1, 2, 2}), + toVec("c2_f64", {1, 3, 3}), + toVec("c3_str", {"1", "4 ", "4 "}), + toVec("c4_str", {"1", "2 ", "2 "}), + toVec("c5_date_time", {2000000, 12000000, 12000000}), + }); } std::shared_ptr buildDAGRequest(std::pair src, MockAstVec agg_funcs, MockAstVec group_by_exprs, MockColumnNameVec proj) @@ -322,6 +342,109 @@ try } CATCH +TEST_F(ExecutorAggTestRunner, AggregationCountGroupByFastPathMultiKeys) +try +{ + /// Prepare some data + std::shared_ptr request; + auto agg_func = Count(lit(Field(static_cast(1)))); /// select count(1) from `test_table_not_null` group by ``; + std::string agg_func_res_name = "count(1)"; + + auto group_by_expr_c1_i64 = col("c1_i64"); + auto group_by_expr_c2_f64 = col("c2_f64"); + auto group_by_expr_c3_str = col("c3_str"); + auto group_by_expr_c4_str = col("c4_str"); + auto group_by_expr_c5_date_time = col("c5_date_time"); + + std::vector group_by_exprs{ + {group_by_expr_c3_str, group_by_expr_c2_f64, group_by_expr_c1_i64}, + {group_by_expr_c1_i64, group_by_expr_c2_f64}, + {group_by_expr_c1_i64, group_by_expr_c3_str}, + {group_by_expr_c3_str, group_by_expr_c2_f64}, + {group_by_expr_c3_str, group_by_expr_c4_str}, + {group_by_expr_c1_i64}, + {group_by_expr_c3_str}, + {group_by_expr_c3_str, group_by_expr_c5_date_time}, + }; + + std::vector expect_cols{ + {toVec(agg_func_res_name, ColumnWithUInt64{1, 2})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 2})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 2})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 2})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 2})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 2})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 2})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 2})}, + }; + + std::vector projections{ + {agg_func_res_name}, + {agg_func_res_name}, + {agg_func_res_name}, + {agg_func_res_name}, + {agg_func_res_name}, + {agg_func_res_name}, + {agg_func_res_name}, + {agg_func_res_name}, + }; + size_t test_num = expect_cols.size(); + + ASSERT_EQ(test_num, projections.size()); + ASSERT_EQ(test_num, group_by_exprs.size()); + + { + context.setCollation(TiDB::ITiDBCollator::UTF8MB4_BIN); + for (size_t i = 0; i < test_num; ++i) + { + request = buildDAGRequest(std::make_pair("test_db", "test_table_not_null"), {agg_func}, group_by_exprs[i], projections[i]); + executeAndAssertColumnsEqual(request, expect_cols[i]); + } + } + { + context.setCollation(TiDB::ITiDBCollator::UTF8_UNICODE_CI); + for (size_t i = 0; i < test_num; ++i) + { + request = buildDAGRequest(std::make_pair("test_db", "test_table_not_null"), {agg_func}, group_by_exprs[i], projections[i]); + executeAndAssertColumnsEqual(request, expect_cols[i]); + } + } + for (auto collation_id : {0, static_cast(TiDB::ITiDBCollator::BINARY)}) + { + // 0: no collation + // binnary collation + context.setCollation(collation_id); + + std::vector group_by_exprs{ + {group_by_expr_c1_i64, group_by_expr_c3_str}, + {group_by_expr_c3_str, group_by_expr_c2_f64}, + {group_by_expr_c3_str, group_by_expr_c4_str}, + {group_by_expr_c3_str}, + }; + std::vector expect_cols{ + {toVec(agg_func_res_name, ColumnWithUInt64{1, 1, 1})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 1, 1})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 1, 1})}, + {toVec(agg_func_res_name, ColumnWithUInt64{1, 1, 1})}, + }; + std::vector projections{ + {agg_func_res_name}, + {agg_func_res_name}, + {agg_func_res_name}, + {agg_func_res_name}, + }; + size_t test_num = expect_cols.size(); + ASSERT_EQ(test_num, projections.size()); + ASSERT_EQ(test_num, group_by_exprs.size()); + for (size_t i = 0; i < test_num; ++i) + { + request = buildDAGRequest(std::make_pair("test_db", "test_table_not_null"), {agg_func}, group_by_exprs[i], projections[i]); + executeAndAssertColumnsEqual(request, expect_cols[i]); + } + } +} +CATCH + TEST_F(ExecutorAggTestRunner, AggNull) try { diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index 47fa98f8417..b7193833031 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -35,6 +35,9 @@ #include #include +#include +#include +#include #include #include #include @@ -56,6 +59,13 @@ extern const char random_aggregate_create_state_failpoint[]; extern const char random_aggregate_merge_failpoint[]; } // namespace FailPoints +#define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME +#define AggregationMethodNameTwoLevel(NAME) AggregatedDataVariants::AggregationMethod_##NAME##_two_level +#define AggregationMethodType(NAME) AggregatedDataVariants::Type::NAME +#define AggregationMethodTypeTwoLevel(NAME) AggregatedDataVariants::Type::NAME##_two_level +#define ToAggregationMethodPtr(NAME, ptr) (reinterpret_cast(ptr)) +#define ToAggregationMethodPtrTwoLevel(NAME, ptr) (reinterpret_cast(ptr)) + AggregatedDataVariants::~AggregatedDataVariants() { if (aggregator && !aggregator->all_aggregates_has_trivial_destructor) @@ -69,22 +79,77 @@ AggregatedDataVariants::~AggregatedDataVariants() tryLogCurrentException(aggregator->log, __PRETTY_FUNCTION__); } } + destroyAggregationMethodImpl(); } - -void AggregatedDataVariants::convertToTwoLevel() +void AggregatedDataVariants::destroyAggregationMethodImpl() { - if (aggregator) - LOG_TRACE(aggregator->log, "Converting aggregation data to two-level."); + if (!aggregation_method_impl) + return; +#define M(NAME, IS_TWO_LEVEL) \ + case AggregationMethodType(NAME): \ + { \ + delete reinterpret_cast(aggregation_method_impl); \ + aggregation_method_impl = nullptr; \ + break; \ + } switch (type) { -#define M(NAME) \ - case Type::NAME: \ - NAME##_two_level = std::make_unique(*(NAME)); \ - (NAME).reset(); \ - type = Type::NAME##_two_level; \ + APPLY_FOR_AGGREGATED_VARIANTS(M) + default: + break; + } +#undef M +} + +void AggregatedDataVariants::init(Type variants_type) +{ + destroyAggregationMethodImpl(); + + switch (variants_type) + { + case Type::EMPTY: break; + case Type::without_key: + break; + +#define M(NAME, IS_TWO_LEVEL) \ + case AggregationMethodType(NAME): \ + { \ + aggregation_method_impl = std::make_unique().release(); \ + break; \ + } + + APPLY_FOR_AGGREGATED_VARIANTS(M) +#undef M + + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); + } + + type = variants_type; +} + +void AggregatedDataVariants::convertToTwoLevel() +{ + switch (type) + { +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + if (aggregator) \ + LOG_TRACE(aggregator->log, \ + "Converting aggregation data type `{}` to `{}`.", \ + getMethodName(AggregationMethodType(NAME)), \ + getMethodName(AggregationMethodTypeTwoLevel(NAME))); \ + auto ori_ptr = ToAggregationMethodPtr(NAME, aggregation_method_impl); \ + auto two_level = std::make_unique(*ori_ptr); \ + delete ori_ptr; \ + aggregation_method_impl = two_level.release(); \ + type = AggregationMethodTypeTwoLevel(NAME); \ + break; \ + } APPLY_FOR_VARIANTS_CONVERTIBLE_TO_TWO_LEVEL(M) @@ -154,7 +219,7 @@ Block Aggregator::Params::getHeader( Aggregator::Aggregator(const Params & params_, const String & req_id) : params(params_) , log(Logger::get(req_id)) - , isCancelled([]() { return false; }) + , is_cancelled([]() { return false; }) { if (current_memory_tracker) memory_usage_before_aggregation = current_memory_tracker->get(); @@ -201,6 +266,124 @@ Aggregator::Aggregator(const Params & params_, const String & req_id) } +inline bool IsTypeNumber64(const DataTypePtr & type) +{ + return type->isNumber() && type->getSizeOfValueInMemory() == sizeof(uint64_t); +} + +#define APPLY_FOR_AGG_FAST_PATH_TYPES(M) \ + M(Number64) \ + M(StringBin) \ + M(StringBinPadding) + +enum class AggFastPathType +{ +#define M(NAME) NAME, + APPLY_FOR_AGG_FAST_PATH_TYPES(M) +#undef M +}; + +AggregatedDataVariants::Type ChooseAggregationMethodTwoKeys(const AggFastPathType * fast_path_types) +{ + auto tp1 = fast_path_types[0]; + auto tp2 = fast_path_types[1]; + switch (tp1) + { + case AggFastPathType::Number64: + { + switch (tp2) + { + case AggFastPathType::Number64: + return AggregatedDataVariants::Type::serialized; // unreachable. keys64 or keys128 will be used before + case AggFastPathType::StringBin: + return AggregatedDataVariants::Type::two_keys_num64_strbin; + case AggFastPathType::StringBinPadding: + return AggregatedDataVariants::Type::two_keys_num64_strbinpadding; + } + } + case AggFastPathType::StringBin: + { + switch (tp2) + { + case AggFastPathType::Number64: + return AggregatedDataVariants::Type::two_keys_strbin_num64; + case AggFastPathType::StringBin: + return AggregatedDataVariants::Type::two_keys_strbin_strbin; + case AggFastPathType::StringBinPadding: + return AggregatedDataVariants::Type::serialized; // rare case + } + } + case AggFastPathType::StringBinPadding: + { + switch (tp2) + { + case AggFastPathType::Number64: + return AggregatedDataVariants::Type::two_keys_strbinpadding_num64; + case AggFastPathType::StringBin: + return AggregatedDataVariants::Type::serialized; // rare case + case AggFastPathType::StringBinPadding: + return AggregatedDataVariants::Type::two_keys_strbinpadding_strbinpadding; + } + } + } +} + +// return AggregatedDataVariants::Type::serialized if can NOT determine fast path. +AggregatedDataVariants::Type ChooseAggregationMethodFastPath(size_t keys_size, const DataTypes & types_not_null, const TiDB::TiDBCollators & collators) +{ + std::array fast_path_types{}; + + if (keys_size == fast_path_types.max_size()) + { + for (size_t i = 0; i < keys_size; ++i) + { + const auto & type = types_not_null[i]; + if (type->isString()) + { + if (collators.empty() || !collators[i]) + { + // use original way + return AggregatedDataVariants::Type::serialized; + } + else + { + switch (collators[i]->getCollatorType()) + { + case TiDB::ITiDBCollator::CollatorType::UTF8MB4_BIN: + case TiDB::ITiDBCollator::CollatorType::UTF8_BIN: + case TiDB::ITiDBCollator::CollatorType::LATIN1_BIN: + case TiDB::ITiDBCollator::CollatorType::ASCII_BIN: + { + fast_path_types[i] = AggFastPathType::StringBinPadding; + break; + } + case TiDB::ITiDBCollator::CollatorType::BINARY: + { + fast_path_types[i] = AggFastPathType::StringBin; + break; + } + default: + { + // for CI COLLATION, use original way + return AggregatedDataVariants::Type::serialized; + } + } + } + } + else if (IsTypeNumber64(type)) + { + fast_path_types[i] = AggFastPathType::Number64; + } + else + { + return AggregatedDataVariants::Type::serialized; + } + } + return ChooseAggregationMethodTwoKeys(fast_path_types.data()); + } + return AggregatedDataVariants::Type::serialized; +} + AggregatedDataVariants::Type Aggregator::chooseAggregationMethod() { /// If no keys. All aggregating to single row. @@ -263,11 +446,13 @@ AggregatedDataVariants::Type Aggregator::chooseAggregationMethod() } /// No key has been found to be nullable. + const DataTypes & types_not_null = types_removed_nullable; + assert(!has_nullable_key); /// Single numeric key. - if (params.keys_size == 1 && types_removed_nullable[0]->isValueRepresentedByNumber()) + if (params.keys_size == 1 && types_not_null[0]->isValueRepresentedByNumber()) { - size_t size_of_field = types_removed_nullable[0]->getSizeOfValueInMemory(); + size_t size_of_field = types_not_null[0]->getSizeOfValueInMemory(); if (size_of_field == 1) return AggregatedDataVariants::Type::key8; if (size_of_field == 2) @@ -301,28 +486,41 @@ AggregatedDataVariants::Type Aggregator::chooseAggregationMethod() } /// If single string key - will use hash table with references to it. Strings itself are stored separately in Arena. - if (params.keys_size == 1 && types_removed_nullable[0]->isString()) - return AggregatedDataVariants::Type::key_string; - - if (params.keys_size > 1 && types_removed_nullable[0]->isString()) + if (params.keys_size == 1 && types_not_null[0]->isString()) { - bool is_all_str = std::all_of(types_removed_nullable.data(), types_removed_nullable.data() + params.keys_size, [](const auto & x) { - return x->isString(); - }); - - if (is_all_str) + if (params.collators.empty() || !params.collators[0]) { - return AggregatedDataVariants::Type::multi_key_string; + // use original way. `Type::one_key_strbin` will generate empty column. + return AggregatedDataVariants::Type::key_string; + } + else + { + switch (params.collators[0]->getCollatorType()) + { + case TiDB::ITiDBCollator::CollatorType::UTF8MB4_BIN: + case TiDB::ITiDBCollator::CollatorType::UTF8_BIN: + case TiDB::ITiDBCollator::CollatorType::LATIN1_BIN: + case TiDB::ITiDBCollator::CollatorType::ASCII_BIN: + { + return AggregatedDataVariants::Type::one_key_strbinpadding; + } + case TiDB::ITiDBCollator::CollatorType::BINARY: + { + return AggregatedDataVariants::Type::one_key_strbin; + } + default: + { + // for CI COLLATION, use original way + return AggregatedDataVariants::Type::key_string; + } + } } } - if (params.keys_size == 1 && types_removed_nullable[0]->isFixedString()) + if (params.keys_size == 1 && types_not_null[0]->isFixedString()) return AggregatedDataVariants::Type::key_fixed_string; - /// Fallback case. - return AggregatedDataVariants::Type::serialized; - - /// NOTE AggregatedDataVariants::Type::hashed is not used. It's proven to be less efficient than 'serialized' in most cases. + return ChooseAggregationMethodFastPath(params.keys_size, types_not_null, params.collators); } @@ -399,7 +597,7 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( } /// Optimization for special case when aggregating by 8bit key. - if constexpr (!no_more_keys && std::is_same_v) + if constexpr (!no_more_keys && std::is_same_v) { for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst) { @@ -538,7 +736,7 @@ bool Aggregator::executeOnBlock( Int64 & local_delta_memory, bool & no_more_keys) { - if (isCancelled()) + if (is_cancelled()) return true; /// `result` will destroy the states of aggregate functions in the destructor @@ -550,8 +748,7 @@ bool Aggregator::executeOnBlock( result.init(method_chosen); result.keys_size = params.keys_size; result.key_sizes = key_sizes; - result.collators = params.collators; - LOG_TRACE(log, "Aggregation method: {}", result.getMethodName()); + LOG_TRACE(log, "Aggregation method: `{}`", result.getMethodName()); } /** Constant columns are not supported directly during aggregation. @@ -576,7 +773,7 @@ bool Aggregator::executeOnBlock( AggregateFunctionInstructions aggregate_functions_instructions; prepareAggregateInstructions(columns, aggregate_columns, materialized_columns, aggregate_functions_instructions); - if (isCancelled()) + if (is_cancelled()) return true; size_t num_rows = block.rows(); @@ -600,14 +797,20 @@ bool Aggregator::executeOnBlock( /// This is where data is written that does not fit in `max_rows_to_group_by` with `group_by_overflow_mode = any`. AggregateDataPtr overflow_row_ptr = params.overflow_row ? result.without_key : nullptr; -#define M(NAME, IS_TWO_LEVEL) \ - else if (result.type == AggregatedDataVariants::Type::NAME) \ - executeImpl(*result.NAME, result.aggregates_pool, num_rows, key_columns, result.collators, aggregate_functions_instructions.data(), no_more_keys, overflow_row_ptr); +#define M(NAME, IS_TWO_LEVEL) \ + case AggregationMethodType(NAME): \ + { \ + executeImpl(*ToAggregationMethodPtr(NAME, result.aggregation_method_impl), result.aggregates_pool, num_rows, key_columns, params.collators, aggregate_functions_instructions.data(), no_more_keys, overflow_row_ptr); \ + break; \ + } - if (false) // NOLINT + switch (result.type) { + APPLY_FOR_AGGREGATED_VARIANTS(M) + default: + break; } - APPLY_FOR_AGGREGATED_VARIANTS(M) + #undef M } @@ -668,16 +871,24 @@ void Aggregator::writeToTemporaryFile(AggregatedDataVariants & data_variants, co /// Flush only two-level data and possibly overflow data. -#define M(NAME) \ - else if (data_variants.type == AggregatedDataVariants::Type::NAME) \ - writeToTemporaryFileImpl(data_variants, *data_variants.NAME, block_out); +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + writeToTemporaryFileImpl( \ + data_variants, \ + *ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl), \ + block_out); \ + break; \ + } - if (false) // NOLINT + switch (data_variants.type) { + APPLY_FOR_VARIANTS_TWO_LEVEL(M) + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } - APPLY_FOR_VARIANTS_TWO_LEVEL(M) + #undef M - else throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); /// NOTE Instead of freeing up memory and creating new hash tables and arenas, you can re-use the old ones. data_variants.init(data_variants.type); @@ -802,7 +1013,7 @@ bool Aggregator::checkLimits(size_t result_size, bool & no_more_keys) const void Aggregator::execute(const BlockInputStreamPtr & stream, AggregatedDataVariants & result, const FileProviderPtr & file_provider) { - if (isCancelled()) + if (is_cancelled()) return; ColumnRawPtrs key_columns(params.keys_size); @@ -825,7 +1036,7 @@ void Aggregator::execute(const BlockInputStreamPtr & stream, AggregatedDataVaria /// Read all the data while (Block block = stream->read()) { - if (isCancelled()) + if (is_cancelled()) return; src_rows += block.rows(); @@ -951,6 +1162,67 @@ inline void Aggregator::insertAggregatesIntoColumns( std::rethrow_exception(exception); } +template +struct AggregatorMethodInitKeyColumnHelper +{ + Method & method; + explicit AggregatorMethodInitKeyColumnHelper(Method & method_) + : method(method_) + {} + ALWAYS_INLINE inline void initAggKeys(size_t, std::vector &) {} + template + ALWAYS_INLINE inline void insertKeyIntoColumns(const Key & key, std::vector & key_columns, const Sizes & sizes, const TiDB::TiDBCollators & collators) + { + method.insertKeyIntoColumns(key, key_columns, sizes, collators); + } +}; + +template +struct AggregatorMethodInitKeyColumnHelper> +{ + using Method = AggregationMethodFastPathTwoKeysNoCache; + size_t index{}; + + Method & method; + explicit AggregatorMethodInitKeyColumnHelper(Method & method_) + : method(method_) + {} + + ALWAYS_INLINE inline void initAggKeys(size_t rows, std::vector & key_columns) + { + Method::template initAggKeys(rows, key_columns[0]); + Method::template initAggKeys(rows, key_columns[1]); + index = 0; + } + ALWAYS_INLINE inline void insertKeyIntoColumns(const StringRef & key, std::vector & key_columns, const Sizes &, const TiDB::TiDBCollators &) + { + method.insertKeyIntoColumns(key, key_columns, index); + ++index; + } +}; + +template +struct AggregatorMethodInitKeyColumnHelper> +{ + using Method = AggregationMethodOneKeyStringNoCache; + size_t index{}; + + Method & method; + explicit AggregatorMethodInitKeyColumnHelper(Method & method_) + : method(method_) + {} + + void initAggKeys(size_t rows, std::vector & key_columns) + { + Method::initAggKeys(rows, key_columns[0]); + index = 0; + } + ALWAYS_INLINE inline void insertKeyIntoColumns(const StringRef & key, std::vector & key_columns, const Sizes &, const TiDB::TiDBCollators &) + { + method.insertKeyIntoColumns(key, key_columns, index); + ++index; + } +}; template void NO_INLINE Aggregator::convertToBlockImplFinal( @@ -963,8 +1235,11 @@ void NO_INLINE Aggregator::convertToBlockImplFinal( auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns, key_sizes); const auto & key_sizes_ref = shuffled_key_sizes ? *shuffled_key_sizes : key_sizes; + AggregatorMethodInitKeyColumnHelper agg_keys_helper{method}; + agg_keys_helper.initAggKeys(data.size(), key_columns); + data.forEachValue([&](const auto & key, auto & mapped) { - method.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + agg_keys_helper.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); insertAggregatesIntoColumns(mapped, final_aggregate_columns, arena); }); } @@ -979,8 +1254,11 @@ void NO_INLINE Aggregator::convertToBlockImplNotFinal( auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns, key_sizes); const auto & key_sizes_ref = shuffled_key_sizes ? *shuffled_key_sizes : key_sizes; + AggregatorMethodInitKeyColumnHelper agg_keys_helper{method}; + agg_keys_helper.initAggKeys(data.size(), key_columns); + data.forEachValue([&](const auto & key, auto & mapped) { - method.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + agg_keys_helper.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); /// reserved, so push_back does not throw exceptions for (size_t i = 0; i < params.aggregates_size; ++i) @@ -1123,15 +1401,26 @@ Block Aggregator::prepareBlockAndFillSingleLevel(AggregatedDataVariants & data_v AggregateColumnsData & aggregate_columns, MutableColumns & final_aggregate_columns, bool final_) { -#define M(NAME) \ - else if (data_variants.type == AggregatedDataVariants::Type::NAME) \ - convertToBlockImpl(*data_variants.NAME, data_variants.NAME->data, key_columns, aggregate_columns, final_aggregate_columns, data_variants.aggregates_pool, final_); - if (false) // NOLINT +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + convertToBlockImpl( \ + *ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl), \ + ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl)->data, \ + key_columns, \ + aggregate_columns, \ + final_aggregate_columns, \ + data_variants.aggregates_pool, \ + final_); \ + break; \ + } + switch (data_variants.type) { + APPLY_FOR_VARIANTS_SINGLE_LEVEL(M) + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } - APPLY_FOR_VARIANTS_SINGLE_LEVEL(M) #undef M - else throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); }; return prepareBlockAndFill(data_variants, final, rows, filler); @@ -1144,15 +1433,25 @@ BlocksList Aggregator::prepareBlocksAndFillTwoLevel( ThreadPoolManager * thread_pool, size_t max_threads) const { -#define M(NAME) \ - else if (data_variants.type == AggregatedDataVariants::Type::NAME) return prepareBlocksAndFillTwoLevelImpl(data_variants, *data_variants.NAME, final, thread_pool, max_threads); +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + return prepareBlocksAndFillTwoLevelImpl( \ + data_variants, \ + *ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl), \ + final, \ + thread_pool, \ + max_threads); \ + break; \ + } - if (false) // NOLINT + switch (data_variants.type) { + APPLY_FOR_VARIANTS_TWO_LEVEL(M) + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } - APPLY_FOR_VARIANTS_TWO_LEVEL(M) #undef M - else throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } @@ -1234,7 +1533,7 @@ BlocksList Aggregator::prepareBlocksAndFillTwoLevelImpl( BlocksList Aggregator::convertToBlocks(AggregatedDataVariants & data_variants, bool final, size_t max_threads) const { - if (isCancelled()) + if (is_cancelled()) return BlocksList(); LOG_TRACE(log, "Converting aggregated data to blocks"); @@ -1252,7 +1551,7 @@ BlocksList Aggregator::convertToBlocks(AggregatedDataVariants & data_variants, b && data_variants.isTwoLevel()) /// TODO Use the shared thread pool with the `merge` function. thread_pool = newThreadPoolManager(max_threads); - if (isCancelled()) + if (is_cancelled()) return BlocksList(); if (data_variants.without_key) @@ -1261,7 +1560,7 @@ BlocksList Aggregator::convertToBlocks(AggregatedDataVariants & data_variants, b final, data_variants.type != AggregatedDataVariants::Type::without_key)); - if (isCancelled()) + if (is_cancelled()) return BlocksList(); if (data_variants.type != AggregatedDataVariants::Type::without_key) @@ -1279,7 +1578,7 @@ BlocksList Aggregator::convertToBlocks(AggregatedDataVariants & data_variants, b data_variants.aggregator = nullptr; } - if (isCancelled()) + if (is_cancelled()) return BlocksList(); size_t rows = 0; @@ -1445,8 +1744,8 @@ void NO_INLINE Aggregator::mergeSingleLevelDataImpl( } } -#define M(NAME) \ - template void NO_INLINE Aggregator::mergeSingleLevelDataImpl( \ +#define M(NAME) \ + template void NO_INLINE Aggregator::mergeSingleLevelDataImpl( \ ManyAggregatedDataVariants & non_empty_data) const; APPLY_FOR_VARIANTS_SINGLE_LEVEL(M) #undef M @@ -1461,7 +1760,7 @@ void NO_INLINE Aggregator::mergeBucketImpl( AggregatedDataVariantsPtr & res = data[0]; for (size_t result_num = 1, size = data.size(); result_num < size; ++result_num) { - if (isCancelled()) + if (is_cancelled()) return; AggregatedDataVariants & current = *data[result_num]; @@ -1552,16 +1851,19 @@ class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream ++current_bucket_num; -#define M(NAME) \ - else if (first->type == AggregatedDataVariants::Type::NAME) \ - aggregator.mergeSingleLevelDataImplNAME)::element_type>(data); - if (false) // NOLINT +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + aggregator.mergeSingleLevelDataImpl(data); \ + break; \ + } + switch (first->type) { + APPLY_FOR_VARIANTS_SINGLE_LEVEL(M) + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } - APPLY_FOR_VARIANTS_SINGLE_LEVEL(M) #undef M - else throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); - return aggregator.prepareBlockAndFillSingleLevel(*first, final); } else @@ -1653,15 +1955,24 @@ class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream size_t thread_number = static_cast(bucket_num) % threads; Arena * arena = merged_data.aggregates_pools.at(thread_number).get(); - if (false) {} // NOLINT -#define M(NAME) \ - else if (method == AggregatedDataVariants::Type::NAME) \ - { \ - aggregator.mergeBucketImpl(data, bucket_num, arena); \ - block = aggregator.convertOneBucketToBlock(merged_data, *merged_data.NAME, arena, final, bucket_num); \ - } - - APPLY_FOR_VARIANTS_TWO_LEVEL(M) +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + aggregator.mergeBucketImpl(data, bucket_num, arena); \ + block = aggregator.convertOneBucketToBlock( \ + merged_data, \ + *ToAggregationMethodPtr(NAME, merged_data.aggregation_method_impl), \ + arena, \ + final, \ + bucket_num); \ + break; \ + } + switch (method) + { + APPLY_FOR_VARIANTS_TWO_LEVEL(M) + default: + break; + } #undef M std::lock_guard lock(parallel_merge_data->mutex); @@ -1933,7 +2244,7 @@ void NO_INLINE Aggregator::mergeWithoutKeyStreamsImpl( void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataVariants & result, size_t max_threads) { - if (isCancelled()) + if (is_cancelled()) return; /** If the remote servers used a two-level aggregation method, @@ -1950,7 +2261,7 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV size_t total_input_blocks = 0; while (Block block = stream->read()) { - if (isCancelled()) + if (is_cancelled()) return; total_input_rows += block.rows(); @@ -1973,15 +2284,22 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV if (has_two_level) { #define M(NAME) \ - if (method_chosen == AggregatedDataVariants::Type::NAME) \ - method_chosen = AggregatedDataVariants::Type::NAME##_two_level; - - APPLY_FOR_VARIANTS_CONVERTIBLE_TO_TWO_LEVEL(M) + case AggregationMethodType(NAME): \ + { \ + method_chosen = AggregationMethodTypeTwoLevel(NAME); \ + break; \ + } + switch (method_chosen) + { + APPLY_FOR_VARIANTS_CONVERTIBLE_TO_TWO_LEVEL(M) + default: + break; + } #undef M } - if (isCancelled()) + if (is_cancelled()) return; /// result will destroy the states of aggregate functions in the destructor @@ -2006,19 +2324,28 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV auto merge_bucket = [&bucket_to_blocks, &result, this](Int32 bucket, Arena * aggregates_pool) { for (Block & block : bucket_to_blocks[bucket]) { - if (isCancelled()) + if (is_cancelled()) return; -#define M(NAME) \ - else if (result.type == AggregatedDataVariants::Type::NAME) \ - mergeStreamsImpl(block, aggregates_pool, *result.NAME, result.NAME->data.impls[bucket], nullptr, false); +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + mergeStreamsImpl(block, \ + aggregates_pool, \ + *ToAggregationMethodPtr(NAME, result.aggregation_method_impl), \ + ToAggregationMethodPtr(NAME, result.aggregation_method_impl)->data.impls[bucket], \ + nullptr, \ + false); \ + break; \ + } - if (false) // NOLINT + switch (result.type) { + APPLY_FOR_VARIANTS_TWO_LEVEL(M) + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } - APPLY_FOR_VARIANTS_TWO_LEVEL(M) #undef M - else throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } }; @@ -2053,7 +2380,7 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV LOG_TRACE(log, "Merged partially aggregated two-level data."); } - if (isCancelled()) + if (is_cancelled()) { result.invalidate(); return; @@ -2068,7 +2395,7 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV BlocksList & blocks = bucket_to_blocks[-1]; for (Block & block : blocks) { - if (isCancelled()) + if (is_cancelled()) { result.invalidate(); return; @@ -2080,13 +2407,28 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV if (result.type == AggregatedDataVariants::Type::without_key || block.info.is_overflows) mergeWithoutKeyStreamsImpl(block, result); -#define M(NAME, IS_TWO_LEVEL) \ - else if (result.type == AggregatedDataVariants::Type::NAME) \ - mergeStreamsImpl(block, result.aggregates_pool, *result.NAME, result.NAME->data, result.without_key, no_more_keys); - - APPLY_FOR_AGGREGATED_VARIANTS(M) +#define M(NAME, IS_TWO_LEVEL) \ + case AggregationMethodType(NAME): \ + { \ + mergeStreamsImpl(block, \ + result.aggregates_pool, \ + *ToAggregationMethodPtr(NAME, result.aggregation_method_impl), \ + ToAggregationMethodPtr(NAME, result.aggregation_method_impl)->data, \ + result.without_key, \ + no_more_keys); \ + break; \ + } + switch (result.type) + { + APPLY_FOR_AGGREGATED_VARIANTS(M) + case AggregatedDataVariants::Type::without_key: + break; + default: + { + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); + } + } #undef M - else if (result.type != AggregatedDataVariants::Type::without_key) throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } LOG_TRACE(log, "Merged partially aggregated single-level data."); @@ -2102,7 +2444,7 @@ Block Aggregator::mergeBlocks(BlocksList & blocks, bool final) auto bucket_num = blocks.front().info.bucket_num; bool is_overflows = blocks.front().info.is_overflows; - LOG_TRACE(log, "Merging partially aggregated blocks (bucket = {}).", bucket_num); + LOG_TRACE(log, "Merging partially aggregated blocks (bucket = {}). Original method `{}`.", bucket_num, AggregatedDataVariants::getMethodName(method_chosen)); Stopwatch watch; /** If possible, change 'method' to some_hash64. Otherwise, leave as is. @@ -2119,11 +2461,19 @@ Block Aggregator::mergeBlocks(BlocksList & blocks, bool final) M(keys256) \ M(serialized) -#define M(NAME) \ - if (merge_method == AggregatedDataVariants::Type::NAME) \ - merge_method = AggregatedDataVariants::Type::NAME##_hash64; +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + merge_method = AggregatedDataVariants::Type::NAME##_hash64; \ + break; \ + } - APPLY_FOR_VARIANTS_THAT_MAY_USE_BETTER_HASH_FUNCTION(M) + switch (merge_method) + { + APPLY_FOR_VARIANTS_THAT_MAY_USE_BETTER_HASH_FUNCTION(M) + default: + break; + } #undef M #undef APPLY_FOR_VARIANTS_THAT_MAY_USE_BETTER_HASH_FUNCTION @@ -2146,13 +2496,26 @@ Block Aggregator::mergeBlocks(BlocksList & blocks, bool final) if (result.type == AggregatedDataVariants::Type::without_key || is_overflows) mergeWithoutKeyStreamsImpl(block, result); -#define M(NAME, IS_TWO_LEVEL) \ - else if (result.type == AggregatedDataVariants::Type::NAME) \ - mergeStreamsImpl(block, result.aggregates_pool, *result.NAME, result.NAME->data, nullptr, false); - - APPLY_FOR_AGGREGATED_VARIANTS(M) +#define M(NAME, IS_TWO_LEVEL) \ + case AggregationMethodType(NAME): \ + { \ + mergeStreamsImpl(block, \ + result.aggregates_pool, \ + *ToAggregationMethodPtr(NAME, result.aggregation_method_impl), \ + ToAggregationMethodPtr(NAME, result.aggregation_method_impl)->data, \ + nullptr, \ + false); \ + break; \ + } + switch (result.type) + { + APPLY_FOR_AGGREGATED_VARIANTS(M) + case AggregatedDataVariants::Type::without_key: + break; + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); + } #undef M - else if (result.type != AggregatedDataVariants::Type::without_key) throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } Block block; @@ -2255,46 +2618,65 @@ std::vector Aggregator::convertBlockToTwoLevel(const Block & block) data.keys_size = params.keys_size; data.key_sizes = key_sizes; -#define M(NAME) \ - else if (type == AggregatedDataVariants::Type::NAME) \ - type \ - = AggregatedDataVariants::Type::NAME##_two_level; +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + type = AggregationMethodTypeTwoLevel(NAME); \ + break; \ + } - if (false) // NOLINT + switch (type) { + APPLY_FOR_VARIANTS_CONVERTIBLE_TO_TWO_LEVEL(M) + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } - APPLY_FOR_VARIANTS_CONVERTIBLE_TO_TWO_LEVEL(M) #undef M - else throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); data.init(type); size_t num_buckets = 0; -#define M(NAME) \ - else if (data.type == AggregatedDataVariants::Type::NAME) \ - num_buckets \ - = data.NAME->data.NUM_BUCKETS; +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + num_buckets \ + = ToAggregationMethodPtr(NAME, data.aggregation_method_impl)->data.NUM_BUCKETS; \ + break; \ + } - if (false) // NOLINT + switch (data.type) { + APPLY_FOR_VARIANTS_TWO_LEVEL(M) + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } - APPLY_FOR_VARIANTS_TWO_LEVEL(M) + #undef M - else throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); + std::vector splitted_blocks(num_buckets); -#define M(NAME) \ - else if (data.type == AggregatedDataVariants::Type::NAME) \ - convertBlockToTwoLevelImpl(*data.NAME, data.aggregates_pool, key_columns, block, splitted_blocks); +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + convertBlockToTwoLevelImpl( \ + *ToAggregationMethodPtr(NAME, data.aggregation_method_impl), \ + data.aggregates_pool, \ + key_columns, \ + block, \ + splitted_blocks); \ + break; \ + } - if (false) // NOLINT + switch (data.type) { + APPLY_FOR_VARIANTS_TWO_LEVEL(M) + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } - APPLY_FOR_VARIANTS_TWO_LEVEL(M) #undef M - else throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); + return splitted_blocks; } @@ -2344,22 +2726,29 @@ void Aggregator::destroyAllAggregateStates(AggregatedDataVariants & result) if (result.type == AggregatedDataVariants::Type::without_key || params.overflow_row) destroyWithoutKey(result); -#define M(NAME, IS_TWO_LEVEL) \ - else if (result.type == AggregatedDataVariants::Type::NAME) \ - destroyImpl(result.NAME->data); +#define M(NAME, IS_TWO_LEVEL) \ + case AggregationMethodType(NAME): \ + { \ + destroyImpl(ToAggregationMethodPtr(NAME, result.aggregation_method_impl)->data); \ + break; \ + } - if (false) // NOLINT + switch (result.type) { + APPLY_FOR_AGGREGATED_VARIANTS(M) + case AggregatedDataVariants::Type::without_key: + break; + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } - APPLY_FOR_AGGREGATED_VARIANTS(M) + #undef M - else if (result.type != AggregatedDataVariants::Type::without_key) throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); } -void Aggregator::setCancellationHook(const CancellationHook cancellation_hook) +void Aggregator::setCancellationHook(CancellationHook cancellation_hook) { - isCancelled = cancellation_hook; + is_cancelled = cancellation_hook; } diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index 688298cb14d..b779aa808ab 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -126,7 +126,7 @@ struct AggregationMethodOneNumber AggregationMethodOneNumber() = default; template - AggregationMethodOneNumber(const Other & other) + explicit AggregationMethodOneNumber(const Other & other) : data(other.data) {} @@ -161,7 +161,7 @@ struct AggregationMethodString AggregationMethodString() = default; template - AggregationMethodString(const Other & other) + explicit AggregationMethodString(const Other & other) : data(other.data) {} @@ -192,16 +192,50 @@ struct AggregationMethodStringNoCache : data(other.data) {} + // Remove last zero byte. using State = ColumnsHashing::HashMethodString; std::optional shuffleKeyColumns(std::vector &, const Sizes &) { return {}; } static void insertKeyIntoColumns(const StringRef & key, std::vector & key_columns, const Sizes &, const TiDB::TiDBCollators &) { + // Add last zero byte. static_cast(key_columns[0])->insertData(key.data, key.size); } }; +template +struct AggregationMethodOneKeyStringNoCache +{ + using Data = TData; + using Key = typename Data::key_type; + using Mapped = typename Data::mapped_type; + + Data data; + + AggregationMethodOneKeyStringNoCache() = default; + + template + explicit AggregationMethodOneKeyStringNoCache(const Other & other) + : data(other.data) + {} + + using State = ColumnsHashing::HashMethodStringBin; + + std::optional shuffleKeyColumns(std::vector &, const Sizes &) { return {}; } + + ALWAYS_INLINE static inline void insertKeyIntoColumns(const StringRef &, std::vector &, size_t) + { + // insert empty because such column will be discarded. + } + // resize offsets for column string + ALWAYS_INLINE static inline void initAggKeys(size_t rows, IColumn * key_column) + { + static_cast(key_column)->getOffsets().resize_fill(rows, 0); + } +}; + +/* /// Same as above but without cache template struct AggregationMethodMultiStringNoCache @@ -230,6 +264,91 @@ struct AggregationMethodMultiStringNoCache pos = static_cast(key_column)->deserializeAndInsertFromArena(pos, nullptr); } }; +*/ + +template +struct AggregationMethodFastPathTwoKeysNoCache +{ + using Data = TData; + using Key = typename Data::key_type; + using Mapped = typename Data::mapped_type; + + Data data; + + AggregationMethodFastPathTwoKeysNoCache() = default; + + template + explicit AggregationMethodFastPathTwoKeysNoCache(const Other & other) + : data(other.data) + {} + + using State = ColumnsHashing::HashMethodFastPathTwoKeysSerialized; + + std::optional shuffleKeyColumns(std::vector &, const Sizes &) { return {}; } + + template + ALWAYS_INLINE static inline void initAggKeys(size_t rows, IColumn * key_column) + { + auto * column = static_cast(key_column); + column->getData().resize_fill(rows, 0); + } + + // Only update offsets but DO NOT insert string data. + // Because of https://github.com/pingcap/tiflash/blob/84c2650bc4320919b954babeceb5aeaadb845770/dbms/src/Columns/IColumn.h#L160-L173, such column will be discarded. + ALWAYS_INLINE static inline const char * insertAggKeyIntoColumnString(const char * pos, IColumn *) + { + const size_t string_size = *reinterpret_cast(pos); + pos += sizeof(string_size); + return pos + string_size; + } + // resize offsets for column string + ALWAYS_INLINE static inline void initAggKeyString(size_t rows, IColumn * key_column) + { + auto * column = static_cast(key_column); + column->getOffsets().resize_fill(rows, 0); + } + + template <> + ALWAYS_INLINE static inline void initAggKeys(size_t rows, IColumn * key_column) + { + return initAggKeyString(rows, key_column); + } + template <> + ALWAYS_INLINE static inline void initAggKeys(size_t rows, IColumn * key_column) + { + return initAggKeyString(rows, key_column); + } + + template + ALWAYS_INLINE static inline const char * insertAggKeyIntoColumn(const char * pos, IColumn * key_column, size_t index) + { + auto * column = static_cast(key_column); + column->getElement(index) = *reinterpret_cast(pos); + return pos + KeyType::ElementSize; + } + template <> + ALWAYS_INLINE static inline const char * insertAggKeyIntoColumn(const char * pos, IColumn * key_column, size_t) + { + return insertAggKeyIntoColumnString(pos, key_column); + } + template <> + ALWAYS_INLINE static inline const char * insertAggKeyIntoColumn(const char * pos, IColumn * key_column, size_t) + { + return insertAggKeyIntoColumnString(pos, key_column); + } + + ALWAYS_INLINE static inline void insertKeyIntoColumns(const StringRef & key, std::vector & key_columns, size_t index) + { + const auto * pos = key.data; + { + pos = insertAggKeyIntoColumn(pos, key_columns[0], index); + } + { + pos = insertAggKeyIntoColumn(pos, key_columns[1], index); + } + } +}; + /// For the case where there is one fixed-length string key. template @@ -244,7 +363,7 @@ struct AggregationMethodFixedString AggregationMethodFixedString() = default; template - AggregationMethodFixedString(const Other & other) + explicit AggregationMethodFixedString(const Other & other) : data(other.data) {} @@ -271,7 +390,7 @@ struct AggregationMethodFixedStringNoCache AggregationMethodFixedStringNoCache() = default; template - AggregationMethodFixedStringNoCache(const Other & other) + explicit AggregationMethodFixedStringNoCache(const Other & other) : data(other.data) {} @@ -300,7 +419,7 @@ struct AggregationMethodKeysFixed AggregationMethodKeysFixed() = default; template - AggregationMethodKeysFixed(const Other & other) + explicit AggregationMethodKeysFixed(const Other & other) : data(other.data) {} @@ -336,7 +455,7 @@ struct AggregationMethodKeysFixed /// If we have a nullable column, get its nested column and its null map. if (column_nullable) { - ColumnNullable & nullable_col = assert_cast(*key_columns[i]); + auto & nullable_col = assert_cast(*key_columns[i]); observed_column = &nullable_col.getNestedColumn(); null_map = assert_cast(&nullable_col.getNullMapColumn()); } @@ -388,7 +507,7 @@ struct AggregationMethodSerialized AggregationMethodSerialized() = default; template - AggregationMethodSerialized(const Other & other) + explicit AggregationMethodSerialized(const Other & other) : data(other.data) {} @@ -407,6 +526,8 @@ struct AggregationMethodSerialized class Aggregator; +#define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME + struct AggregatedDataVariants : private boost::noncopyable { /** Working with states of aggregate functions in the pool is arranged in the following (inconvenient) way: @@ -430,94 +551,126 @@ struct AggregatedDataVariants : private boost::noncopyable size_t keys_size{}; /// Number of keys. NOTE do we need this field? Sizes key_sizes; /// Dimensions of keys, if keys of fixed length - TiDB::TiDBCollators collators; /// Pools for states of aggregate functions. Ownership will be later transferred to ColumnAggregateFunction. Arenas aggregates_pools; Arena * aggregates_pool{}; /// The pool that is currently used for allocation. + void * aggregation_method_impl{}; + /** Specialization for the case when there are no keys, and for keys not fitted into max_rows_to_group_by. */ AggregatedDataWithoutKey without_key = nullptr; - std::unique_ptr> key8; - std::unique_ptr> key16; - - std::unique_ptr> key32; - std::unique_ptr> key64; - std::unique_ptr> key_int256; - std::unique_ptr> key_string; - std::unique_ptr> multi_key_string; - std::unique_ptr> key_fixed_string; - std::unique_ptr> keys16; - std::unique_ptr> keys32; - std::unique_ptr> keys64; - std::unique_ptr> keys128; - std::unique_ptr> keys256; - std::unique_ptr> serialized; - - std::unique_ptr> key32_two_level; - std::unique_ptr> key64_two_level; - std::unique_ptr> key_int256_two_level; - std::unique_ptr> key_string_two_level; - std::unique_ptr> multi_key_string_two_level; - std::unique_ptr> key_fixed_string_two_level; - std::unique_ptr> keys32_two_level; - std::unique_ptr> keys64_two_level; - std::unique_ptr> keys128_two_level; - std::unique_ptr> keys256_two_level; - std::unique_ptr> serialized_two_level; - - std::unique_ptr> key64_hash64; - std::unique_ptr> key_string_hash64; - std::unique_ptr> key_fixed_string_hash64; - std::unique_ptr> keys128_hash64; - std::unique_ptr> keys256_hash64; - std::unique_ptr> serialized_hash64; + using AggregationMethod_key8 = AggregationMethodOneNumber; + using AggregationMethod_key16 = AggregationMethodOneNumber; + using AggregationMethod_key32 = AggregationMethodOneNumber; + using AggregationMethod_key64 = AggregationMethodOneNumber; + using AggregationMethod_key_int256 = AggregationMethodOneNumber; + using AggregationMethod_key_string = AggregationMethodStringNoCache; + using AggregationMethod_one_key_strbin = AggregationMethodOneKeyStringNoCache; + using AggregationMethod_one_key_strbinpadding = AggregationMethodOneKeyStringNoCache; + using AggregationMethod_key_fixed_string = AggregationMethodFixedStringNoCache; + using AggregationMethod_keys16 = AggregationMethodKeysFixed; + using AggregationMethod_keys32 = AggregationMethodKeysFixed; + using AggregationMethod_keys64 = AggregationMethodKeysFixed; + using AggregationMethod_keys128 = AggregationMethodKeysFixed; + using AggregationMethod_keys256 = AggregationMethodKeysFixed; + using AggregationMethod_serialized = AggregationMethodSerialized; + using AggregationMethod_key32_two_level = AggregationMethodOneNumber; + using AggregationMethod_key64_two_level = AggregationMethodOneNumber; + using AggregationMethod_key_int256_two_level = AggregationMethodOneNumber; + using AggregationMethod_key_string_two_level = AggregationMethodStringNoCache; + using AggregationMethod_one_key_strbin_two_level = AggregationMethodOneKeyStringNoCache; + using AggregationMethod_one_key_strbinpadding_two_level = AggregationMethodOneKeyStringNoCache; + using AggregationMethod_key_fixed_string_two_level = AggregationMethodFixedStringNoCache; + using AggregationMethod_keys32_two_level = AggregationMethodKeysFixed; + using AggregationMethod_keys64_two_level = AggregationMethodKeysFixed; + using AggregationMethod_keys128_two_level = AggregationMethodKeysFixed; + using AggregationMethod_keys256_two_level = AggregationMethodKeysFixed; + using AggregationMethod_serialized_two_level = AggregationMethodSerialized; + using AggregationMethod_key64_hash64 = AggregationMethodOneNumber; + using AggregationMethod_key_string_hash64 = AggregationMethodStringNoCache; + using AggregationMethod_key_fixed_string_hash64 = AggregationMethodFixedString; + using AggregationMethod_keys128_hash64 = AggregationMethodKeysFixed; + using AggregationMethod_keys256_hash64 = AggregationMethodKeysFixed; + using AggregationMethod_serialized_hash64 = AggregationMethodSerialized; /// Support for nullable keys. - std::unique_ptr> nullable_keys128; - std::unique_ptr> nullable_keys256; - std::unique_ptr> nullable_keys128_two_level; - std::unique_ptr> nullable_keys256_two_level; + using AggregationMethod_nullable_keys128 = AggregationMethodKeysFixed; + using AggregationMethod_nullable_keys256 = AggregationMethodKeysFixed; + using AggregationMethod_nullable_keys128_two_level = AggregationMethodKeysFixed; + using AggregationMethod_nullable_keys256_two_level = AggregationMethodKeysFixed; + + // 2 keys + using AggregationMethod_two_keys_num64_strbin = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_num64_strbinpadding = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_strbin_num64 = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_strbin_strbin = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_strbinpadding_num64 = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_strbinpadding_strbinpadding = AggregationMethodFastPathTwoKeysNoCache; + + using AggregationMethod_two_keys_num64_strbin_two_level = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_num64_strbinpadding_two_level = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_strbin_num64_two_level = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_strbin_strbin_two_level = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_strbinpadding_num64_two_level = AggregationMethodFastPathTwoKeysNoCache; + using AggregationMethod_two_keys_strbinpadding_strbinpadding_two_level = AggregationMethodFastPathTwoKeysNoCache; + + // 3 keys + // TODO: use 3 keys if necessary /// In this and similar macros, the option without_key is not considered. -#define APPLY_FOR_AGGREGATED_VARIANTS(M) \ - M(key8, false) \ - M(key16, false) \ - M(key32, false) \ - M(key64, false) \ - M(key_string, false) \ - M(multi_key_string, false) \ - M(key_fixed_string, false) \ - M(keys16, false) \ - M(keys32, false) \ - M(keys64, false) \ - M(keys128, false) \ - M(keys256, false) \ - M(key_int256, false) \ - M(serialized, false) \ - M(key32_two_level, true) \ - M(key64_two_level, true) \ - M(key_int256_two_level, true) \ - M(key_string_two_level, true) \ - M(multi_key_string_two_level, true) \ - M(key_fixed_string_two_level, true) \ - M(keys32_two_level, true) \ - M(keys64_two_level, true) \ - M(keys128_two_level, true) \ - M(keys256_two_level, true) \ - M(serialized_two_level, true) \ - M(key64_hash64, false) \ - M(key_string_hash64, false) \ - M(key_fixed_string_hash64, false) \ - M(keys128_hash64, false) \ - M(keys256_hash64, false) \ - M(serialized_hash64, false) \ - M(nullable_keys128, false) \ - M(nullable_keys256, false) \ - M(nullable_keys128_two_level, true) \ - M(nullable_keys256_two_level, true) +#define APPLY_FOR_AGGREGATED_VARIANTS(M) \ + M(key8, false) \ + M(key16, false) \ + M(key32, false) \ + M(key64, false) \ + M(key_string, false) \ + M(key_fixed_string, false) \ + M(keys16, false) \ + M(keys32, false) \ + M(keys64, false) \ + M(keys128, false) \ + M(keys256, false) \ + M(key_int256, false) \ + M(serialized, false) \ + M(key64_hash64, false) \ + M(key_string_hash64, false) \ + M(key_fixed_string_hash64, false) \ + M(keys128_hash64, false) \ + M(keys256_hash64, false) \ + M(serialized_hash64, false) \ + M(nullable_keys128, false) \ + M(nullable_keys256, false) \ + M(two_keys_num64_strbin, false) \ + M(two_keys_num64_strbinpadding, false) \ + M(two_keys_strbin_num64, false) \ + M(two_keys_strbin_strbin, false) \ + M(two_keys_strbinpadding_num64, false) \ + M(two_keys_strbinpadding_strbinpadding, false) \ + M(one_key_strbin, false) \ + M(one_key_strbinpadding, false) \ + M(key32_two_level, true) \ + M(key64_two_level, true) \ + M(key_int256_two_level, true) \ + M(key_string_two_level, true) \ + M(key_fixed_string_two_level, true) \ + M(keys32_two_level, true) \ + M(keys64_two_level, true) \ + M(keys128_two_level, true) \ + M(keys256_two_level, true) \ + M(serialized_two_level, true) \ + M(nullable_keys128_two_level, true) \ + M(nullable_keys256_two_level, true) \ + M(two_keys_num64_strbin_two_level, true) \ + M(two_keys_num64_strbinpadding_two_level, true) \ + M(two_keys_strbin_num64_two_level, true) \ + M(two_keys_strbin_strbin_two_level, true) \ + M(two_keys_strbinpadding_num64_two_level, true) \ + M(two_keys_strbinpadding_strbinpadding_two_level, true) \ + M(one_key_strbin_two_level, true) \ + M(one_key_strbinpadding_two_level, true) enum class Type { @@ -528,7 +681,10 @@ struct AggregatedDataVariants : private boost::noncopyable APPLY_FOR_AGGREGATED_VARIANTS(M) #undef M }; - Type type = Type::EMPTY; + + Type type{Type::EMPTY}; + + void destroyAggregationMethodImpl(); AggregatedDataVariants() : aggregates_pools(1, std::make_shared()) @@ -545,28 +701,7 @@ struct AggregatedDataVariants : private boost::noncopyable ~AggregatedDataVariants(); - void init(Type type_) - { - switch (type_) - { - case Type::EMPTY: - break; - case Type::without_key: - break; - -#define M(NAME, IS_TWO_LEVEL) \ - case Type::NAME: \ - NAME = std::make_unique(); \ - break; - APPLY_FOR_AGGREGATED_VARIANTS(M) -#undef M - - default: - throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); - } - - type = type_; - } + void init(Type variants_type); /// Number of rows (different keys). size_t size() const @@ -578,9 +713,13 @@ struct AggregatedDataVariants : private boost::noncopyable case Type::without_key: return 1; -#define M(NAME, IS_TWO_LEVEL) \ - case Type::NAME: \ - return NAME->data.size() + (without_key != nullptr); +#define M(NAME, IS_TWO_LEVEL) \ + case Type::NAME: \ + { \ + const auto * ptr = reinterpret_cast(aggregation_method_impl); \ + return ptr->data.size() + (without_key != nullptr); \ + } + APPLY_FOR_AGGREGATED_VARIANTS(M) #undef M @@ -599,9 +738,13 @@ struct AggregatedDataVariants : private boost::noncopyable case Type::without_key: return 1; -#define M(NAME, IS_TWO_LEVEL) \ - case Type::NAME: \ - return NAME->data.size(); +#define M(NAME, IS_TWO_LEVEL) \ + case Type::NAME: \ + { \ + const auto * ptr = reinterpret_cast(aggregation_method_impl); \ + return ptr->data.size(); \ + } + APPLY_FOR_AGGREGATED_VARIANTS(M) #undef M @@ -611,6 +754,10 @@ struct AggregatedDataVariants : private boost::noncopyable } const char * getMethodName() const + { + return getMethodName(type); + } + static const char * getMethodName(Type type) { switch (type) { @@ -661,9 +808,17 @@ struct AggregatedDataVariants : private boost::noncopyable M(keys128) \ M(keys256) \ M(serialized) \ - M(multi_key_string) \ M(nullable_keys128) \ - M(nullable_keys256) + M(nullable_keys256) \ + M(two_keys_num64_strbin) \ + M(two_keys_num64_strbinpadding) \ + M(two_keys_strbin_num64) \ + M(two_keys_strbin_strbin) \ + M(two_keys_strbinpadding_num64) \ + M(two_keys_strbinpadding_strbinpadding) \ + M(one_key_strbin) \ + M(one_key_strbinpadding) + #define APPLY_FOR_VARIANTS_NOT_CONVERTIBLE_TO_TWO_LEVEL(M) \ M(key8) \ @@ -698,20 +853,27 @@ struct AggregatedDataVariants : private boost::noncopyable void convertToTwoLevel(); -#define APPLY_FOR_VARIANTS_TWO_LEVEL(M) \ - M(key32_two_level) \ - M(key64_two_level) \ - M(key_int256_two_level) \ - M(key_string_two_level) \ - M(key_fixed_string_two_level) \ - M(keys32_two_level) \ - M(keys64_two_level) \ - M(keys128_two_level) \ - M(keys256_two_level) \ - M(serialized_two_level) \ - M(multi_key_string_two_level) \ - M(nullable_keys128_two_level) \ - M(nullable_keys256_two_level) +#define APPLY_FOR_VARIANTS_TWO_LEVEL(M) \ + M(key32_two_level) \ + M(key64_two_level) \ + M(key_int256_two_level) \ + M(key_string_two_level) \ + M(key_fixed_string_two_level) \ + M(keys32_two_level) \ + M(keys64_two_level) \ + M(keys128_two_level) \ + M(keys256_two_level) \ + M(serialized_two_level) \ + M(nullable_keys128_two_level) \ + M(nullable_keys256_two_level) \ + M(two_keys_num64_strbin_two_level) \ + M(two_keys_num64_strbinpadding_two_level) \ + M(two_keys_strbin_num64_two_level) \ + M(two_keys_strbin_strbin_two_level) \ + M(two_keys_strbinpadding_num64_two_level) \ + M(two_keys_strbinpadding_strbinpadding_two_level) \ + M(one_key_strbin_two_level) \ + M(one_key_strbinpadding_two_level) }; using AggregatedDataVariantsPtr = std::shared_ptr; @@ -887,7 +1049,7 @@ class Aggregator /** Set a function that checks whether the current task can be aborted. */ - void setCancellationHook(const CancellationHook cancellation_hook); + void setCancellationHook(CancellationHook cancellation_hook); /// For external aggregation. void writeToTemporaryFile(AggregatedDataVariants & data_variants, const FileProviderPtr & file_provider); @@ -920,6 +1082,8 @@ class Aggregator Params params; AggregatedDataVariants::Type method_chosen; + + Sizes key_sizes; AggregateFunctionsPlainPtrs aggregate_functions; @@ -931,13 +1095,13 @@ class Aggregator */ struct AggregateFunctionInstruction { - const IAggregateFunction * that; - IAggregateFunction::AddFunc func; - size_t state_offset; - const IColumn ** arguments; - const IAggregateFunction * batch_that; - const IColumn ** batch_arguments; - const UInt64 * offsets = nullptr; + const IAggregateFunction * that{}; + IAggregateFunction::AddFunc func{}; + size_t state_offset{}; + const IColumn ** arguments{}; + const IAggregateFunction * batch_that{}; + const IColumn ** batch_arguments{}; + const UInt64 * offsets{}; }; using AggregateFunctionInstructions = std::vector; @@ -961,7 +1125,7 @@ class Aggregator const LoggerPtr log; /// Returns true if you can abort the current task. - CancellationHook isCancelled; + CancellationHook is_cancelled; /// For external aggregation. TemporaryFiles temporary_files; @@ -1163,21 +1327,21 @@ class Aggregator bool checkLimits(size_t result_size, bool & no_more_keys) const; }; - /** Get the aggregation variant by its type. */ template Method & getDataVariant(AggregatedDataVariants & variants); -#define M(NAME, IS_TWO_LEVEL) \ - template <> \ - inline decltype(AggregatedDataVariants::NAME)::element_type & getDataVariant(AggregatedDataVariants & variants) \ - { \ - return *variants.NAME; \ +#define M(NAME, IS_TWO_LEVEL) \ + template <> \ + inline AggregationMethodName(NAME) & /*NOLINT*/ \ + getDataVariant(AggregatedDataVariants & variants) \ + { \ + return *reinterpret_cast(variants.aggregation_method_impl); \ } APPLY_FOR_AGGREGATED_VARIANTS(M) #undef M - +#undef AggregationMethodName } // namespace DB diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index 28bb81ec4cf..3e5d621254d 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -169,10 +169,15 @@ void Join::setBuildTableState(BuildTableState state_) build_table_cv.notify_all(); } +bool CanAsColumnString(const IColumn * column) +{ + return typeid_cast(column) + || (column->isColumnConst() && typeid_cast(&static_cast(column)->getDataColumn())); +} -Join::Type Join::chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes) +Join::Type Join::chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes) const { - size_t keys_size = key_columns.size(); + const size_t keys_size = key_columns.size(); if (keys_size == 0) return Type::CROSS; @@ -215,10 +220,33 @@ Join::Type Join::chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_siz return Type::keys256; /// If there is single string key, use hash table of it's values. - if (keys_size == 1 - && (typeid_cast(key_columns[0]) - || (key_columns[0]->isColumnConst() && typeid_cast(&static_cast(key_columns[0])->getDataColumn())))) - return Type::key_string; + if (keys_size == 1 && CanAsColumnString(key_columns[0])) + { + if (collators.empty() || !collators[0]) + return Type::key_strbin; + else + { + switch (collators[0]->getCollatorType()) + { + case TiDB::ITiDBCollator::CollatorType::UTF8MB4_BIN: + case TiDB::ITiDBCollator::CollatorType::UTF8_BIN: + case TiDB::ITiDBCollator::CollatorType::LATIN1_BIN: + case TiDB::ITiDBCollator::CollatorType::ASCII_BIN: + { + return Type::key_strbinpadding; + } + case TiDB::ITiDBCollator::CollatorType::BINARY: + { + return Type::key_strbin; + } + default: + { + // for CI COLLATION, use original way + return Type::key_string; + } + } + } + } if (keys_size == 1 && typeid_cast(key_columns[0])) return Type::key_fixed_string; @@ -322,6 +350,16 @@ struct KeyGetterForTypeImpl using Type = ColumnsHashing::HashMethodString; }; template +struct KeyGetterForTypeImpl +{ + using Type = ColumnsHashing::HashMethodStringBin; +}; +template +struct KeyGetterForTypeImpl +{ + using Type = ColumnsHashing::HashMethodStringBin; +}; +template struct KeyGetterForTypeImpl { using Type = ColumnsHashing::HashMethodFixedString; diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index f30479b6f46..836b631bb77 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -228,6 +228,8 @@ class Join M(key32) \ M(key64) \ M(key_string) \ + M(key_strbinpadding) \ + M(key_strbin) \ M(key_fixed_string) \ M(keys128) \ M(keys256) \ @@ -253,10 +255,13 @@ class Join std::unique_ptr>> key32; std::unique_ptr>> key64; std::unique_ptr> key_string; + std::unique_ptr> key_strbinpadding; + std::unique_ptr> key_strbin; std::unique_ptr> key_fixed_string; std::unique_ptr>> keys128; std::unique_ptr>> keys256; std::unique_ptr> serialized; + // TODO: add more cases like Aggregator }; using MapsAny = MapsTemplate>; @@ -320,7 +325,7 @@ class Join private: Type type = Type::EMPTY; - static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes); + Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes) const; Sizes key_sizes; diff --git a/dbms/src/Interpreters/Set.cpp b/dbms/src/Interpreters/Set.cpp index ffe0aab62d5..a0693cfa41f 100644 --- a/dbms/src/Interpreters/Set.cpp +++ b/dbms/src/Interpreters/Set.cpp @@ -117,7 +117,7 @@ void Set::setHeader(const Block & block) extractNestedColumnsAndNullMap(key_columns, null_map_holder, null_map); /// Choose data structure to use for the set. - data.init(data.chooseMethod(key_columns, key_sizes)); + data.init(data.chooseMethod(key_columns, key_sizes, collators)); } @@ -183,7 +183,7 @@ bool Set::insertFromBlock(const Block & block, bool fill_set_elements) static Field extractValueFromNode(ASTPtr & node, const IDataType & type, const Context & context) { - if (ASTLiteral * lit = typeid_cast(node.get())) + if (auto * lit = typeid_cast(node.get())) { return convertFieldToType(lit->value, type); } @@ -222,7 +222,7 @@ void Set::createFromAST(const DataTypes & types, ASTPtr node, const Context & co else setContainsNullValue(true); } - else if (ASTFunction * func = typeid_cast(elem.get())) + else if (auto * func = typeid_cast(elem.get())) { if (func->name != "tuple") throw Exception("Incorrect element of set. Must be tuple.", ErrorCodes::INCORRECT_ELEMENT_OF_SET); @@ -274,19 +274,19 @@ std::vector Set::createFromDAGExpr(const DataTypes & types, setHeader(header); MutableColumns columns = header.cloneEmptyColumns(); - std::vector remainingExprs; + std::vector remaining_exprs; // if left arg is null constant, just return without decode children expr if (types[0]->onlyNull()) - return remainingExprs; + return remaining_exprs; for (int i = 1; i < expr.children_size(); i++) { - auto & child = expr.children(i); + const auto & child = expr.children(i); // todo support constant expression by constant folding if (!isLiteralExpr(child)) { - remainingExprs.push_back(&child); + remaining_exprs.push_back(&child); continue; } Field value = decodeLiteral(child); @@ -301,7 +301,7 @@ std::vector Set::createFromDAGExpr(const DataTypes & types, Block block = header.cloneWithColumns(std::move(columns)); insertFromBlock(block, fill_set_elements); - return remainingExprs; + return remaining_exprs; } ColumnPtr Set::execute(const Block & block, bool negative) const diff --git a/dbms/src/Interpreters/Set.h b/dbms/src/Interpreters/Set.h index dae9f3f9374..223c18fd4f7 100644 --- a/dbms/src/Interpreters/Set.h +++ b/dbms/src/Interpreters/Set.h @@ -47,10 +47,11 @@ using FunctionBasePtr = std::shared_ptr; class Set { public: - Set(const SizeLimits & limits) + explicit Set(const SizeLimits & limits, TiDB::TiDBCollators && collators_ = {}) : log(&Poco::Logger::get("Set")) , limits(limits) , set_elements(std::make_unique()) + , collators(std::move(collators_)) { } @@ -74,7 +75,7 @@ class Set /** Create a Set from stream. * Call setHeader, then call insertFromBlock for each block. */ - void setHeader(const Block & header); + void setHeader(const Block &); /// Returns false, if some limit was exceeded and no need to insert more data. bool insertFromBlock(const Block & block, bool fill_set_elements); @@ -94,10 +95,8 @@ class Set void setContainsNullValue(bool contains_null_value_) { contains_null_value = contains_null_value_; } bool containsNullValue() const { return contains_null_value; } - void setCollators(TiDB::TiDBCollators & collators_) { collators = collators_; } - private: - size_t keys_size; + size_t keys_size{}; Sizes key_sizes; SetVariants data; diff --git a/dbms/src/Interpreters/SetVariants.cpp b/dbms/src/Interpreters/SetVariants.cpp index 95a597319a4..c6381d6d098 100644 --- a/dbms/src/Interpreters/SetVariants.cpp +++ b/dbms/src/Interpreters/SetVariants.cpp @@ -35,9 +35,9 @@ void SetVariantsTemplate::init(Type type_) case Type::EMPTY: break; -#define M(NAME) \ - case Type::NAME: \ - NAME = std::make_unique(); \ +#define M(NAME) \ + case Type::NAME: \ + (NAME) = std::make_unique(); \ break; APPLY_FOR_SET_VARIANTS(M) #undef M @@ -57,7 +57,7 @@ size_t SetVariantsTemplate::getTotalRowCount() const #define M(NAME) \ case Type::NAME: \ - return NAME->data.size(); + return (NAME)->data.size(); APPLY_FOR_SET_VARIANTS(M) #undef M @@ -76,7 +76,7 @@ size_t SetVariantsTemplate::getTotalByteCount() const #define M(NAME) \ case Type::NAME: \ - return NAME->data.getBufferSizeInBytes(); + return (NAME)->data.getBufferSizeInBytes(); APPLY_FOR_SET_VARIANTS(M) #undef M @@ -86,7 +86,7 @@ size_t SetVariantsTemplate::getTotalByteCount() const } template -typename SetVariantsTemplate::Type SetVariantsTemplate::chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes) +typename SetVariantsTemplate::Type SetVariantsTemplate::chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes, const TiDB::TiDBCollators & collators) { /// Check if at least one of the specified keys is nullable. /// Create a set of nested key columns from the corresponding key columns. @@ -100,7 +100,7 @@ typename SetVariantsTemplate::Type SetVariantsTemplate::choose { if (col->isColumnNullable()) { - const ColumnNullable & nullable_col = static_cast(*col); + const auto & nullable_col = static_cast(*col); nested_key_columns.push_back(&nullable_col.getNestedColumn()); has_nullable_key = true; } @@ -186,7 +186,32 @@ typename SetVariantsTemplate::Type SetVariantsTemplate::choose if (keys_size == 1 && (typeid_cast(nested_key_columns[0]) || (nested_key_columns[0]->isColumnConst() && typeid_cast(&static_cast(nested_key_columns[0])->getDataColumn())))) - return Type::key_string; + { + if (collators.empty() || !collators[0]) + return Type::key_strbin; + else + { + switch (collators[0]->getCollatorType()) + { + case TiDB::ITiDBCollator::CollatorType::UTF8MB4_BIN: + case TiDB::ITiDBCollator::CollatorType::UTF8_BIN: + case TiDB::ITiDBCollator::CollatorType::LATIN1_BIN: + case TiDB::ITiDBCollator::CollatorType::ASCII_BIN: + { + return Type::key_strbinpadding; + } + case TiDB::ITiDBCollator::CollatorType::BINARY: + { + return Type::key_strbin; + } + default: + { + // for CI COLLATION, use original way + return Type::key_string; + } + } + } + } if (keys_size == 1 && typeid_cast(nested_key_columns[0])) return Type::key_fixed_string; diff --git a/dbms/src/Interpreters/SetVariants.h b/dbms/src/Interpreters/SetVariants.h index 58f102555b6..07120d9afcc 100644 --- a/dbms/src/Interpreters/SetVariants.h +++ b/dbms/src/Interpreters/SetVariants.h @@ -45,11 +45,6 @@ struct SetMethodOneNumber use_cache>; }; -namespace GeneralCI -{ -using WeightType = uint16_t; -} // namespace GeneralCI - /// For the case where there is one string key. template struct SetMethodString @@ -62,6 +57,17 @@ struct SetMethodString using State = ColumnsHashing::HashMethodString; }; +template +struct SetMethodStringBinNoCache +{ + using Data = TData; + using Key = typename Data::key_type; + + Data data; + + using State = ColumnsHashing::HashMethodStringBin; +}; + /// For the case when there is one fixed-length string key. template struct SetMethodFixedString @@ -214,6 +220,8 @@ struct NonClearableSet std::unique_ptr>>> key32; std::unique_ptr>>> key64; std::unique_ptr>> key_string; + std::unique_ptr, true>> key_strbinpadding; + std::unique_ptr, false>> key_strbin; std::unique_ptr>> key_fixed_string; std::unique_ptr>>> keys128; std::unique_ptr>>> keys256; @@ -237,6 +245,8 @@ struct ClearableSet std::unique_ptr>>> key32; std::unique_ptr>>> key64; std::unique_ptr>> key_string; + std::unique_ptr, true>> key_strbinpadding; + std::unique_ptr, false>> key_strbin; std::unique_ptr>> key_fixed_string; std::unique_ptr>>> keys128; std::unique_ptr>>> keys256; @@ -262,6 +272,8 @@ struct SetVariantsTemplate : public Variant M(key32) \ M(key64) \ M(key_string) \ + M(key_strbinpadding) \ + M(key_strbin) \ M(key_fixed_string) \ M(keys128) \ M(keys256) \ @@ -286,7 +298,7 @@ struct SetVariantsTemplate : public Variant bool empty() const { return type == Type::EMPTY; } - static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes); + static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes, const TiDB::TiDBCollators & collators = {}); void init(Type type_); diff --git a/dbms/src/Interpreters/sortBlock.cpp b/dbms/src/Interpreters/sortBlock.cpp index 655fe17af0d..9995329b833 100644 --- a/dbms/src/Interpreters/sortBlock.cpp +++ b/dbms/src/Interpreters/sortBlock.cpp @@ -70,8 +70,8 @@ ALWAYS_INLINE static inline bool NeedCollation(const IColumn * column, const Sor #define APPLY_FOR_TYPE(M) \ M(UInt64) \ M(Int64) \ + M(StringBin) \ M(StringBinPadding) \ - M(StringBinNoPadding) \ M(StringWithCollatorGeneric) #define CONCAT(x, y) x##y @@ -124,7 +124,7 @@ struct ColumnStringCompare { ret = BinCollatorCompare(str_a.data, str_a.size, str_b.data, str_b.size); } - else if constexpr (type == FastPathType::StringBinNoPadding) + else if constexpr (type == FastPathType::StringBin) { ret = BinCollatorCompare(str_a.data, str_a.size, str_b.data, str_b.size); } @@ -140,7 +140,7 @@ struct ColumnStringCompare using ColumnCompareUInt64 = ColumnVecCompare; using ColumnCompareInt64 = ColumnVecCompare; using ColumnCompareStringBinPadding = ColumnStringCompare; -using ColumnCompareStringBinNoPadding = ColumnStringCompare; +using ColumnCompareStringBin = ColumnStringCompare; using ColumnCompareStringWithCollatorGeneric = ColumnStringCompare; // only for uint64, int64, string @@ -235,7 +235,7 @@ struct FastSortDesc : boost::noncopyable } case TiDB::ITiDBCollator::CollatorType::BINARY: { - addFastPathType(FastPathType::StringBinNoPadding); + addFastPathType(FastPathType::StringBin); break; } default: