diff --git a/src/include/optimizer/factorization_rewriter.h b/src/include/optimizer/factorization_rewriter.h new file mode 100644 index 00000000000..9f171b9f216 --- /dev/null +++ b/src/include/optimizer/factorization_rewriter.h @@ -0,0 +1,38 @@ +#include "planner/logical_plan/logical_plan.h" + +namespace kuzu { +namespace optimizer { + +class FactorizationRewriter { +public: + void rewrite(planner::LogicalPlan* plan); + +private: + void visitOperator(planner::LogicalOperator* op); + void visitExtend(planner::LogicalOperator* op); + void visitHashJoin(planner::LogicalOperator* op); + void visitIntersect(planner::LogicalOperator* op); + void visitProjection(planner::LogicalOperator* op); + void visitAggregate(planner::LogicalOperator* op); + void visitOrderBy(planner::LogicalOperator* op); + void visitSkip(planner::LogicalOperator* op); + void visitLimit(planner::LogicalOperator* op); + void visitDistinct(planner::LogicalOperator* op); + void visitUnwind(planner::LogicalOperator* op); + void visitUnion(planner::LogicalOperator* op); + void visitFilter(planner::LogicalOperator* op); + void visitSetNodeProperty(planner::LogicalOperator* op); + void visitSetRelProperty(planner::LogicalOperator* op); + void visitDeleteRel(planner::LogicalOperator* op); + void visitCreateNode(planner::LogicalOperator* op); + void visitCreateRel(planner::LogicalOperator* op); + + std::shared_ptr appendFlattens( + std::shared_ptr op, + const std::unordered_set& groupsPos); + std::shared_ptr appendFlattenIfNecessary( + std::shared_ptr op, planner::f_group_pos groupPos); +}; + +} // namespace optimizer +} // namespace kuzu diff --git a/src/include/optimizer/remove_factorization_rewriter.h b/src/include/optimizer/remove_factorization_rewriter.h new file mode 100644 index 00000000000..3c32b74a68e --- /dev/null +++ b/src/include/optimizer/remove_factorization_rewriter.h @@ -0,0 +1,18 @@ +#include "planner/logical_plan/logical_plan.h" + +namespace kuzu { +namespace optimizer { + +class RemoveFactorizationRewriter { +public: + void rewrite(planner::LogicalPlan* plan); + +private: + std::shared_ptr rewriteOperator( + std::shared_ptr op); + + bool subPlanHasFlatten(planner::LogicalOperator* op); +}; + +} // namespace optimizer +} // namespace kuzu diff --git a/src/include/planner/join_order_enumerator.h b/src/include/planner/join_order_enumerator.h index dab571dd40d..ac919c45a85 100644 --- a/src/include/planner/join_order_enumerator.h +++ b/src/include/planner/join_order_enumerator.h @@ -113,8 +113,6 @@ class JoinOrderEnumerator { void appendIndexScanNode(std::shared_ptr& node, std::shared_ptr indexExpression, LogicalPlan& plan); - bool needFlatInput( - RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction); bool needExtendToNewGroup( RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction); void appendExtend(std::shared_ptr boundNode, @@ -129,8 +127,8 @@ class JoinOrderEnumerator { static void appendMarkJoin(const binder::expression_vector& joinNodeIDs, const std::shared_ptr& mark, bool isProbeAcc, LogicalPlan& probePlan, LogicalPlan& buildPlan); - static void appendIntersect(const std::shared_ptr& intersectNode, - std::vector>& boundNodes, LogicalPlan& probePlan, + static void appendIntersect(const std::shared_ptr& intersectNodeID, + binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan, std::vector>& buildPlans); static void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan); diff --git a/src/include/planner/logical_plan/logical_operator/base_logical_operator.h b/src/include/planner/logical_plan/logical_operator/base_logical_operator.h index b7fe405cc63..da24bdcaa8e 100644 --- a/src/include/planner/logical_plan/logical_operator/base_logical_operator.h +++ b/src/include/planner/logical_plan/logical_operator/base_logical_operator.h @@ -69,6 +69,7 @@ class LogicalOperator { // Used for operators with more than two children e.g. Union inline void addChild(std::shared_ptr op) { children.push_back(std::move(op)); } inline std::shared_ptr getChild(uint64_t idx) const { return children[idx]; } + inline std::vector> getChildren() const { return children; } inline void setChild(uint64_t idx, std::shared_ptr child) { children[idx] = std::move(child); } diff --git a/src/include/planner/logical_plan/logical_operator/flatten_resolver.h b/src/include/planner/logical_plan/logical_operator/flatten_resolver.h new file mode 100644 index 00000000000..d1831ce8bc3 --- /dev/null +++ b/src/include/planner/logical_plan/logical_operator/flatten_resolver.h @@ -0,0 +1,19 @@ +#pragma once + +#include "planner/logical_plan/logical_operator/base_logical_operator.h" + +namespace kuzu { +namespace planner { +namespace factorization { + +struct FlattenAllButOne { + static f_group_pos_set getGroupsPosToFlatten(const f_group_pos_set& groupsPos, Schema* schema); +}; + +struct FlattenAll { + static f_group_pos_set getGroupsPosToFlatten(const f_group_pos_set& groupsPos, Schema* schema); +}; + +} // namespace factorization +} // namespace planner +} // namespace kuzu diff --git a/src/include/planner/logical_plan/logical_operator/logical_aggregate.h b/src/include/planner/logical_plan/logical_operator/logical_aggregate.h index 76222207f11..5544fefaced 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_aggregate.h +++ b/src/include/planner/logical_plan/logical_operator/logical_aggregate.h @@ -13,6 +13,9 @@ class LogicalAggregate : public LogicalOperator { expressionsToGroupBy{std::move(expressionsToGroupBy)}, expressionsToAggregate{std::move( expressionsToAggregate)} {} + f_group_pos_set getGroupsPosToFlattenForGroupBy(); + f_group_pos_set getGroupsPosToFlattenForAggregate(); + void computeSchema() override; std::string getExpressionsForPrinting() const override; @@ -31,6 +34,9 @@ class LogicalAggregate : public LogicalOperator { expressionsToGroupBy, expressionsToAggregate, children[0]->copy()); } +private: + bool hasDistinctAggregate(); + private: binder::expression_vector expressionsToGroupBy; binder::expression_vector expressionsToAggregate; diff --git a/src/include/planner/logical_plan/logical_operator/logical_create.h b/src/include/planner/logical_plan/logical_operator/logical_create.h index dda6e14f417..5184135c6b2 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_create.h +++ b/src/include/planner/logical_plan/logical_operator/logical_create.h @@ -1,5 +1,6 @@ #pragma once +#include "flatten_resolver.h" #include "logical_update.h" namespace kuzu { @@ -15,6 +16,14 @@ class LogicalCreateNode : public LogicalUpdateNode { void computeSchema() override; + inline f_group_pos_set getGroupsPosToFlatten() { + // Flatten all inputs. E.g. MATCH (a) CREATE (b). We need to create b for each tuple in the + // match clause. This is to simplify operator implementation. + auto childSchema = children[0]->getSchema(); + return factorization::FlattenAll::getGroupsPosToFlatten( + childSchema->getGroupsPosInScope(), childSchema); + } + inline std::shared_ptr getPrimaryKey(size_t idx) const { return primaryKeys[idx]; } @@ -36,6 +45,12 @@ class LogicalCreateRel : public LogicalUpdateRel { setItemsPerRel{std::move(setItemsPerRel)} {} ~LogicalCreateRel() override = default; + inline f_group_pos_set getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + return factorization::FlattenAll::getGroupsPosToFlatten( + childSchema->getGroupsPosInScope(), childSchema); + } + inline std::vector getSetItems(uint32_t idx) const { return setItemsPerRel[idx]; } diff --git a/src/include/planner/logical_plan/logical_operator/logical_delete.h b/src/include/planner/logical_plan/logical_operator/logical_delete.h index a2a2a50a755..fd5f0c61aa1 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_delete.h +++ b/src/include/planner/logical_plan/logical_operator/logical_delete.h @@ -1,6 +1,7 @@ #pragma once #include "logical_update.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" namespace kuzu { namespace planner { @@ -32,6 +33,15 @@ class LogicalDeleteRel : public LogicalUpdateRel { : LogicalUpdateRel{LogicalOperatorType::DELETE_REL, std::move(rels), std::move(child)} {} ~LogicalDeleteRel() override = default; + inline f_group_pos_set getGroupsPosToFlatten(uint32_t relIdx) { + f_group_pos_set result; + auto rel = rels[relIdx]; + auto childSchema = children[0]->getSchema(); + result.insert(childSchema->getGroupPos(*rel->getSrcNode()->getInternalIDProperty())); + result.insert(childSchema->getGroupPos(*rel->getDstNode()->getInternalIDProperty())); + return factorization::FlattenAll::getGroupsPosToFlatten(result, childSchema); + } + inline std::unique_ptr copy() override { return make_unique(rels, children[0]->copy()); } diff --git a/src/include/planner/logical_plan/logical_operator/logical_distinct.h b/src/include/planner/logical_plan/logical_operator/logical_distinct.h index 289281565bc..afbc6a6f440 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_distinct.h +++ b/src/include/planner/logical_plan/logical_operator/logical_distinct.h @@ -13,6 +13,8 @@ class LogicalDistinct : public LogicalOperator { : LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)}, expressionsToDistinct{std::move(expressionsToDistinct)} {} + f_group_pos_set getGroupsPosToFlatten(); + void computeSchema() override; std::string getExpressionsForPrinting() const override; diff --git a/src/include/planner/logical_plan/logical_operator/logical_extend.h b/src/include/planner/logical_plan/logical_operator/logical_extend.h index 1941590fbab..6db00c67565 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_extend.h +++ b/src/include/planner/logical_plan/logical_operator/logical_extend.h @@ -17,6 +17,8 @@ class LogicalExtend : public LogicalOperator { nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction}, properties{std::move(properties)}, extendToNewGroup{extendToNewGroup} {} + f_group_pos_set getGroupsPosToFlatten(); + void computeSchema() override; inline std::string getExpressionsForPrinting() const override { diff --git a/src/include/planner/logical_plan/logical_operator/logical_filter.h b/src/include/planner/logical_plan/logical_operator/logical_filter.h index 86e6a95f32f..7e8749cba8d 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_filter.h +++ b/src/include/planner/logical_plan/logical_operator/logical_filter.h @@ -13,6 +13,8 @@ class LogicalFilter : public LogicalOperator { : LogicalOperator{LogicalOperatorType::FILTER, std::move(child)}, expression{std::move( expression)} {} + f_group_pos_set getGroupsPosToFlatten(); + inline void computeSchema() override { copyChildSchema(0); } inline std::string getExpressionsForPrinting() const override { diff --git a/src/include/planner/logical_plan/logical_operator/logical_flatten.h b/src/include/planner/logical_plan/logical_operator/logical_flatten.h index 3a4f400e043..44cb401bc62 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_flatten.h +++ b/src/include/planner/logical_plan/logical_operator/logical_flatten.h @@ -8,25 +8,21 @@ namespace planner { class LogicalFlatten : public LogicalOperator { public: - LogicalFlatten( - std::shared_ptr expression, std::shared_ptr child) - : LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child)}, expression{std::move( - expression)} {} + LogicalFlatten(f_group_pos groupPos, std::shared_ptr child) + : LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child)}, groupPos{groupPos} {} void computeSchema() override; - inline std::string getExpressionsForPrinting() const override { - return expression->getUniqueName(); - } + inline std::string getExpressionsForPrinting() const override { return std::string{}; } - inline std::shared_ptr getExpression() const { return expression; } + inline f_group_pos getGroupPos() const { return groupPos; } inline std::unique_ptr copy() override { - return make_unique(expression, children[0]->copy()); + return make_unique(groupPos, children[0]->copy()); } private: - std::shared_ptr expression; + f_group_pos groupPos; }; } // namespace planner diff --git a/src/include/planner/logical_plan/logical_operator/logical_hash_join.h b/src/include/planner/logical_plan/logical_operator/logical_hash_join.h index 5645ffe9c1d..1bf73a09d7a 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_hash_join.h +++ b/src/include/planner/logical_plan/logical_operator/logical_hash_join.h @@ -39,6 +39,9 @@ class LogicalHashJoin : public LogicalOperator { joinNodeIDs(std::move(joinNodeIDs)), joinType{joinType}, mark{std::move(mark)}, isProbeAcc{isProbeAcc}, expressionsToMaterialize{std::move(expressionsToMaterialize)} {} + f_group_pos_set getGroupsPosToFlattenOnProbeSide(); + f_group_pos_set getGroupsPosToFlattenOnBuildSide(); + void computeSchema() override; inline std::string getExpressionsForPrinting() const override { @@ -63,6 +66,18 @@ class LogicalHashJoin : public LogicalOperator { expressionsToMaterialize, children[0]->copy(), children[1]->copy()); } +private: + // Flat probe side key group in either of the following two cases: + // 1. there are multiple join nodes; + // 2. if the build side contains more than one group or the build side has projected out data + // chunks, which may increase the multiplicity of data chunks in the build side. The key is to + // keep probe side key unflat only when we know that there is only 0 or 1 match for each key. + // TODO(Guodong): when the build side has only flat payloads, we should consider getting rid of + // flattening probe key, instead duplicating keys as in vectorized processing if necessary. + bool requireFlatProbeKeys(); + + bool isJoinKeyUniqueOnBuildSide(const binder::Expression& joinNodeID); + private: binder::expression_vector joinNodeIDs; common::JoinType joinType; diff --git a/src/include/planner/logical_plan/logical_operator/logical_intersect.h b/src/include/planner/logical_plan/logical_operator/logical_intersect.h index 50409911f48..041d3d6f45c 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_intersect.h +++ b/src/include/planner/logical_plan/logical_operator/logical_intersect.h @@ -33,6 +33,9 @@ class LogicalIntersect : public LogicalOperator { } } + f_group_pos_set getGroupsPosToFlattenOnProbeSide(); + f_group_pos_set getGroupsPosToFlattenOnBuildSide(uint32_t buildIdx); + void computeSchema() override; std::string getExpressionsForPrinting() const override { return intersectNodeID->getRawName(); } diff --git a/src/include/planner/logical_plan/logical_operator/logical_limit.h b/src/include/planner/logical_plan/logical_operator/logical_limit.h index 2a0c31e0b0e..97a8e7f199f 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_limit.h +++ b/src/include/planner/logical_plan/logical_operator/logical_limit.h @@ -1,6 +1,7 @@ #pragma once #include "base_logical_operator.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" namespace kuzu { namespace planner { @@ -10,6 +11,8 @@ class LogicalLimit : public LogicalOperator { LogicalLimit(uint64_t limitNumber, std::shared_ptr child) : LogicalOperator{LogicalOperatorType::LIMIT, std::move(child)}, limitNumber{limitNumber} {} + f_group_pos_set getGroupsPosToFlatten(); + inline void computeSchema() override { copyChildSchema(0); } inline std::string getExpressionsForPrinting() const override { diff --git a/src/include/planner/logical_plan/logical_operator/logical_order_by.h b/src/include/planner/logical_plan/logical_operator/logical_order_by.h index f2208d463ed..ebe2c2af6a5 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_order_by.h +++ b/src/include/planner/logical_plan/logical_operator/logical_order_by.h @@ -13,6 +13,8 @@ class LogicalOrderBy : public LogicalOperator { expressionsToOrderBy{std::move(expressionsToOrderBy)}, isAscOrders{std::move(sortOrders)}, expressionsToMaterialize{std::move(expressionsToMaterialize)} {} + f_group_pos_set getGroupsPosToFlatten(); + void computeSchema() override; inline std::string getExpressionsForPrinting() const override { diff --git a/src/include/planner/logical_plan/logical_operator/logical_set.h b/src/include/planner/logical_plan/logical_operator/logical_set.h index a0c22688c72..89e4d1d225d 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_set.h +++ b/src/include/planner/logical_plan/logical_operator/logical_set.h @@ -1,6 +1,7 @@ #pragma once #include "logical_update.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" namespace kuzu { namespace planner { @@ -39,6 +40,8 @@ class LogicalSetRelProperty : public LogicalUpdateRel { std::move(child)}, setItems{std::move(setItems)} {} + f_group_pos_set getGroupsPosToFlatten(uint32_t setItemIdx); + inline std::string getExpressionsForPrinting() const override { std::string result; for (auto& [lhs, rhs] : setItems) { diff --git a/src/include/planner/logical_plan/logical_operator/logical_skip.h b/src/include/planner/logical_plan/logical_operator/logical_skip.h index 6e5cb4aa801..fc84dc32b43 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_skip.h +++ b/src/include/planner/logical_plan/logical_operator/logical_skip.h @@ -1,6 +1,7 @@ #pragma once #include "base_logical_operator.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" namespace kuzu { namespace planner { @@ -10,6 +11,8 @@ class LogicalSkip : public LogicalOperator { LogicalSkip(uint64_t skipNumber, std::shared_ptr child) : LogicalOperator(LogicalOperatorType::SKIP, std::move(child)), skipNumber{skipNumber} {} + f_group_pos_set getGroupsPosToFlatten(); + inline void computeSchema() override { copyChildSchema(0); } inline std::string getExpressionsForPrinting() const override { diff --git a/src/include/planner/logical_plan/logical_operator/logical_union.h b/src/include/planner/logical_plan/logical_operator/logical_union.h index 522ad54459e..e19783ae985 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_union.h +++ b/src/include/planner/logical_plan/logical_operator/logical_union.h @@ -12,9 +12,11 @@ class LogicalUnion : public LogicalOperator { : LogicalOperator{LogicalOperatorType::UNION_ALL, std::move(children)}, expressionsToUnion{std::move(expressions)} {} + f_group_pos_set getGroupsPosToFlatten(uint32_t childIdx); + void computeSchema() override; - inline std::string getExpressionsForPrinting() const override { return std::string(); } + inline std::string getExpressionsForPrinting() const override { return std::string{}; } inline binder::expression_vector getExpressionsToUnion() { return expressionsToUnion; } @@ -22,6 +24,11 @@ class LogicalUnion : public LogicalOperator { std::unique_ptr copy() override; +private: + // If an expression to union has different flat/unflat state in different child, we + // need to flatten that expression in all the single queries. + bool requireFlatExpression(uint32_t expressionIdx); + private: binder::expression_vector expressionsToUnion; }; diff --git a/src/include/planner/logical_plan/logical_operator/logical_unwind.h b/src/include/planner/logical_plan/logical_operator/logical_unwind.h index 8daff1adabe..4237b3228ce 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_unwind.h +++ b/src/include/planner/logical_plan/logical_operator/logical_unwind.h @@ -13,6 +13,8 @@ class LogicalUnwind : public LogicalOperator { : LogicalOperator{LogicalOperatorType::UNWIND, std::move(childOperator)}, expression{std::move(expression)}, aliasExpression{std::move(aliasExpression)} {} + f_group_pos_set getGroupsPosToFlatten(); + void computeSchema() override; inline std::shared_ptr getExpression() { return expression; } diff --git a/src/include/planner/logical_plan/logical_operator/schema.h b/src/include/planner/logical_plan/logical_operator/schema.h index bd0660d2b75..8850aa1c90e 100644 --- a/src/include/planner/logical_plan/logical_operator/schema.h +++ b/src/include/planner/logical_plan/logical_operator/schema.h @@ -8,6 +8,7 @@ namespace kuzu { namespace planner { typedef uint32_t f_group_pos; +typedef std::unordered_set f_group_pos_set; constexpr f_group_pos INVALID_F_GROUP_POS = UINT32_MAX; class FactorizationGroup { diff --git a/src/include/planner/query_planner.h b/src/include/planner/query_planner.h index e5bd2049666..3c970896ea7 100644 --- a/src/include/planner/query_planner.h +++ b/src/include/planner/query_planner.h @@ -60,18 +60,8 @@ class QueryPlanner { static void appendUnwind(BoundUnwindClause& boundUnwindClause, LogicalPlan& plan); - static void appendFlattens(const std::unordered_set& groupsPos, LogicalPlan& plan); - // return position of the only unFlat group - // or position of any flat group if there is no unFlat group. - static uint32_t appendFlattensButOne( - const std::unordered_set& groupsPos, LogicalPlan& plan); - static void appendFlattenIfNecessary( - const std::shared_ptr& expression, LogicalPlan& plan); - static inline void appendFlattenIfNecessary(uint32_t groupPos, LogicalPlan& plan) { - auto expressions = plan.getSchema()->getExpressionsInScope(groupPos); - assert(!expressions.empty()); - appendFlattenIfNecessary(expressions[0], plan); - } + static void appendFlattens(const f_group_pos_set& groupsPos, LogicalPlan& plan); + static void appendFlattenIfNecessary(f_group_pos groupPos, LogicalPlan& plan); void appendFilter(const std::shared_ptr& expression, LogicalPlan& plan); diff --git a/src/include/planner/update_planner.h b/src/include/planner/update_planner.h index b8f1201daca..6d8024cdfce 100644 --- a/src/include/planner/update_planner.h +++ b/src/include/planner/update_planner.h @@ -43,8 +43,6 @@ class UpdatePlanner { LogicalPlan& plan); void appendDeleteRel( const std::vector>& deleteRels, LogicalPlan& plan); - - void flattenRel(const binder::RelExpression& rel, LogicalPlan& plan); }; } // namespace planner diff --git a/src/optimizer/CMakeLists.txt b/src/optimizer/CMakeLists.txt index fe77a954e69..b8849a1b78f 100644 --- a/src/optimizer/CMakeLists.txt +++ b/src/optimizer/CMakeLists.txt @@ -1,7 +1,9 @@ add_library(kuzu_optimizer OBJECT + factorization_rewriter.cpp index_nested_loop_join_optimizer.cpp - optimizer.cpp) + optimizer.cpp + remove_factorization_rewriter.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/optimizer/factorization_rewriter.cpp b/src/optimizer/factorization_rewriter.cpp new file mode 100644 index 00000000000..5d1de8b1272 --- /dev/null +++ b/src/optimizer/factorization_rewriter.cpp @@ -0,0 +1,253 @@ +#include "optimizer/factorization_rewriter.h" + +#include "planner/logical_plan/logical_operator/flatten_resolver.h" +#include "planner/logical_plan/logical_operator/logical_aggregate.h" +#include "planner/logical_plan/logical_operator/logical_create.h" +#include "planner/logical_plan/logical_operator/logical_delete.h" +#include "planner/logical_plan/logical_operator/logical_distinct.h" +#include "planner/logical_plan/logical_operator/logical_extend.h" +#include "planner/logical_plan/logical_operator/logical_filter.h" +#include "planner/logical_plan/logical_operator/logical_flatten.h" +#include "planner/logical_plan/logical_operator/logical_hash_join.h" +#include "planner/logical_plan/logical_operator/logical_intersect.h" +#include "planner/logical_plan/logical_operator/logical_limit.h" +#include "planner/logical_plan/logical_operator/logical_order_by.h" +#include "planner/logical_plan/logical_operator/logical_projection.h" +#include "planner/logical_plan/logical_operator/logical_set.h" +#include "planner/logical_plan/logical_operator/logical_skip.h" +#include "planner/logical_plan/logical_operator/logical_union.h" +#include "planner/logical_plan/logical_operator/logical_unwind.h" + +using namespace kuzu::planner; + +namespace kuzu { +namespace optimizer { + +void FactorizationRewriter::rewrite(planner::LogicalPlan* plan) { + visitOperator(plan->getLastOperator().get()); +} + +void FactorizationRewriter::visitOperator(planner::LogicalOperator* op) { + // bottom-up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + visitOperator(op->getChild(i).get()); + } + switch (op->getOperatorType()) { + case LogicalOperatorType::EXTEND: { + visitExtend(op); + } break; + case LogicalOperatorType::HASH_JOIN: { + visitHashJoin(op); + } break; + case LogicalOperatorType::INTERSECT: { + visitIntersect(op); + } break; + case LogicalOperatorType::PROJECTION: { + visitProjection(op); + } break; + case LogicalOperatorType::AGGREGATE: { + visitAggregate(op); + } break; + case LogicalOperatorType::ORDER_BY: { + visitOrderBy(op); + } break; + case LogicalOperatorType::SKIP: { + visitSkip(op); + } break; + case LogicalOperatorType::LIMIT: { + visitLimit(op); + } break; + case LogicalOperatorType::DISTINCT: { + visitDistinct(op); + } break; + case LogicalOperatorType::UNWIND: { + visitUnwind(op); + } break; + case LogicalOperatorType::UNION_ALL: { + visitUnion(op); + } break; + case LogicalOperatorType::FILTER: { + visitFilter(op); + } break; + case LogicalOperatorType::SET_NODE_PROPERTY: { + visitSetNodeProperty(op); + } break; + case LogicalOperatorType::SET_REL_PROPERTY: { + visitSetRelProperty(op); + } break; + case LogicalOperatorType::DELETE_REL: { + visitDeleteRel(op); + } break; + case LogicalOperatorType::CREATE_NODE: { + visitCreateNode(op); + } break; + case LogicalOperatorType::CREATE_REL: { + visitCreateRel(op); + } break; + default: + break; + } + op->computeSchema(); +} + +void FactorizationRewriter::visitExtend(planner::LogicalOperator* op) { + auto extend = (LogicalExtend*)op; + auto groupsPosToFlatten = extend->getGroupsPosToFlatten(); + extend->setChild(0, appendFlattens(extend->getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitHashJoin(planner::LogicalOperator* op) { + auto hashJoin = (LogicalHashJoin*)op; + auto groupsPosToFlattenOnProbeSide = hashJoin->getGroupsPosToFlattenOnProbeSide(); + hashJoin->setChild(0, appendFlattens(hashJoin->getChild(0), groupsPosToFlattenOnProbeSide)); + auto groupsPosToFlattenOnBuildSide = hashJoin->getGroupsPosToFlattenOnBuildSide(); + hashJoin->setChild(1, appendFlattens(hashJoin->getChild(1), groupsPosToFlattenOnBuildSide)); +} + +void FactorizationRewriter::visitIntersect(planner::LogicalOperator* op) { + auto intersect = (LogicalIntersect*)op; + auto groupsPosToFlattenOnProbeSide = intersect->getGroupsPosToFlattenOnProbeSide(); + intersect->setChild(0, appendFlattens(intersect->getChild(0), groupsPosToFlattenOnProbeSide)); + for (auto i = 0u; i < intersect->getNumBuilds(); ++i) { + auto groupPosToFlatten = intersect->getGroupsPosToFlattenOnBuildSide(i); + auto childIdx = i + 1; // skip probe + intersect->setChild( + childIdx, appendFlattens(intersect->getChild(childIdx), groupPosToFlatten)); + } +} + +void FactorizationRewriter::visitProjection(planner::LogicalOperator* op) { + auto projection = (LogicalProjection*)op; + for (auto& expression : projection->getExpressionsToProject()) { + auto dependentGroupsPos = op->getChild(0)->getSchema()->getDependentGroupsPos(expression); + auto groupsPosToFlatten = factorization::FlattenAllButOne::getGroupsPosToFlatten( + dependentGroupsPos, op->getChild(0)->getSchema()); + projection->setChild(0, appendFlattens(projection->getChild(0), groupsPosToFlatten)); + } +} + +void FactorizationRewriter::visitAggregate(planner::LogicalOperator* op) { + auto aggregate = (LogicalAggregate*)op; + auto groupsPosToFlattenForGroupBy = aggregate->getGroupsPosToFlattenForGroupBy(); + aggregate->setChild(0, appendFlattens(aggregate->getChild(0), groupsPosToFlattenForGroupBy)); + auto groupsPosToFlattenForAggregate = aggregate->getGroupsPosToFlattenForAggregate(); + aggregate->setChild(0, appendFlattens(aggregate->getChild(0), groupsPosToFlattenForAggregate)); +} + +void FactorizationRewriter::visitOrderBy(planner::LogicalOperator* op) { + auto orderBy = (LogicalOrderBy*)op; + auto groupsPosToFlatten = orderBy->getGroupsPosToFlatten(); + orderBy->setChild(0, appendFlattens(orderBy->getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitSkip(planner::LogicalOperator* op) { + auto skip = (LogicalSkip*)op; + auto groupsPosToFlatten = skip->getGroupsPosToFlatten(); + skip->setChild(0, appendFlattens(skip->getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitLimit(planner::LogicalOperator* op) { + auto limit = (LogicalLimit*)op; + auto groupsPosToFlatten = limit->getGroupsPosToFlatten(); + limit->setChild(0, appendFlattens(limit->getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitDistinct(planner::LogicalOperator* op) { + auto distinct = (LogicalDistinct*)op; + auto groupsPosToFlatten = distinct->getGroupsPosToFlatten(); + distinct->setChild(0, appendFlattens(distinct->getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitUnwind(planner::LogicalOperator* op) { + auto unwind = (LogicalUnwind*)op; + auto groupsPosToFlatten = unwind->getGroupsPosToFlatten(); + unwind->setChild(0, appendFlattens(unwind->getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitUnion(planner::LogicalOperator* op) { + auto union_ = (LogicalUnion*)op; + for (auto i = 0u; i < union_->getNumChildren(); ++i) { + auto groupsPosToFlatten = union_->getGroupsPosToFlatten(i); + union_->setChild(i, appendFlattens(union_->getChild(i), groupsPosToFlatten)); + } +} + +void FactorizationRewriter::visitFilter(planner::LogicalOperator* op) { + auto filter = (LogicalFilter*)op; + auto groupsPosToFlatten = filter->getGroupsPosToFlatten(); + filter->setChild(0, appendFlattens(filter->getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitSetNodeProperty(planner::LogicalOperator* op) { + auto setNodeProperty = (LogicalSetNodeProperty*)op; + for (auto i = 0u; i < setNodeProperty->getNumNodes(); ++i) { + auto lhsNodeID = setNodeProperty->getNode(i)->getInternalIDProperty(); + auto rhs = setNodeProperty->getSetItem(i).second; + // flatten rhs + auto rhsDependentGroupsPos = op->getChild(0)->getSchema()->getDependentGroupsPos(rhs); + auto rhsGroupsPosToFlatten = factorization::FlattenAllButOne::getGroupsPosToFlatten( + rhsDependentGroupsPos, op->getChild(0)->getSchema()); + setNodeProperty->setChild( + 0, appendFlattens(setNodeProperty->getChild(0), rhsGroupsPosToFlatten)); + // flatten lhs if needed + auto lhsGroupPos = op->getChild(0)->getSchema()->getGroupPos(*lhsNodeID); + auto rhsLeadingGroupPos = + SchemaUtils::getLeadingGroupPos(rhsDependentGroupsPos, *op->getChild(0)->getSchema()); + if (lhsGroupPos != rhsLeadingGroupPos) { + setNodeProperty->setChild( + 0, appendFlattenIfNecessary(setNodeProperty->getChild(0), lhsGroupPos)); + } + } +} + +void FactorizationRewriter::visitSetRelProperty(planner::LogicalOperator* op) { + auto setRelProperty = (LogicalSetRelProperty*)op; + for (auto i = 0u; i < setRelProperty->getNumRels(); ++i) { + auto groupsPosToFlatten = setRelProperty->getGroupsPosToFlatten(i); + setRelProperty->setChild( + 0, appendFlattens(setRelProperty->getChild(0), groupsPosToFlatten)); + } +} + +void FactorizationRewriter::visitDeleteRel(planner::LogicalOperator* op) { + auto deleteRel = (LogicalDeleteRel*)op; + for (auto i = 0u; i < deleteRel->getNumRels(); ++i) { + auto groupsPosToFlatten = deleteRel->getGroupsPosToFlatten(i); + deleteRel->setChild(0, appendFlattens(deleteRel->getChild(0), groupsPosToFlatten)); + } +} + +void FactorizationRewriter::visitCreateNode(planner::LogicalOperator* op) { + auto createNode = (LogicalCreateNode*)op; + auto groupsPosToFlatten = createNode->getGroupsPosToFlatten(); + createNode->setChild(0, appendFlattens(createNode->getChild(0), groupsPosToFlatten)); +} + +void FactorizationRewriter::visitCreateRel(planner::LogicalOperator* op) { + auto createRel = (LogicalCreateRel*)op; + auto groupsPosToFlatten = createRel->getGroupsPosToFlatten(); + createRel->setChild(0, appendFlattens(createRel->getChild(0), groupsPosToFlatten)); +} + +std::shared_ptr FactorizationRewriter::appendFlattens( + std::shared_ptr op, + const std::unordered_set& groupsPos) { + auto currentChild = std::move(op); + for (auto groupPos : groupsPos) { + currentChild = appendFlattenIfNecessary(std::move(currentChild), groupPos); + } + return currentChild; +} + +std::shared_ptr FactorizationRewriter::appendFlattenIfNecessary( + std::shared_ptr op, planner::f_group_pos groupPos) { + if (op->getSchema()->getGroup(groupPos)->isFlat()) { + return op; + } + auto flatten = std::make_shared(groupPos, std::move(op)); + flatten->computeSchema(); + return flatten; +} + +} // namespace optimizer +} // namespace kuzu diff --git a/src/optimizer/index_nested_loop_join_optimizer.cpp b/src/optimizer/index_nested_loop_join_optimizer.cpp index 7e89a8d07ca..6284f1f547b 100644 --- a/src/optimizer/index_nested_loop_join_optimizer.cpp +++ b/src/optimizer/index_nested_loop_join_optimizer.cpp @@ -18,7 +18,6 @@ std::shared_ptr IndexNestedLoopJoinOptimizer::rewrite( for (auto i = 0u; i < op->getNumChildren(); ++i) { op->setChild(i, rewrite(op->getChild(i))); } - op->computeSchema(); return op; } @@ -50,11 +49,6 @@ std::shared_ptr IndexNestedLoopJoinOptimizer::rewriteFilter( return op; } auto currentOp = op->getChild(0); - // Match flatten. TODO(Xiyang): remove flatten as a logical operator. - if (currentOp->getOperatorType() != LogicalOperatorType::FLATTEN) { - return op; - } - currentOp = currentOp->getChild(0); // Match cross product if (currentOp->getOperatorType() != LogicalOperatorType::CROSS_PRODUCT) { return op; @@ -92,7 +86,6 @@ std::shared_ptr IndexNestedLoopJoinOptimizer::rewriteCrossProdu rightOp = rightOp->getChild(0); } auto newRoot = op->getChild(1); - newRoot->computeSchemaRecursive(); return newRoot; } diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index 80005cf8c3c..3d48d9d53cd 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -1,12 +1,20 @@ #include "optimizer/optimizer.h" +#include "optimizer/factorization_rewriter.h" #include "optimizer/index_nested_loop_join_optimizer.h" +#include "optimizer/remove_factorization_rewriter.h" namespace kuzu { namespace optimizer { void Optimizer::optimize(planner::LogicalPlan* plan) { + auto removeFactorizationRewriter = RemoveFactorizationRewriter(); + removeFactorizationRewriter.rewrite(plan); + IndexNestedLoopJoinOptimizer::rewrite(plan->getLastOperator()); + + auto factorizationRewriter = FactorizationRewriter(); + factorizationRewriter.rewrite(plan); } } // namespace optimizer diff --git a/src/optimizer/remove_factorization_rewriter.cpp b/src/optimizer/remove_factorization_rewriter.cpp new file mode 100644 index 00000000000..8a19e7ce614 --- /dev/null +++ b/src/optimizer/remove_factorization_rewriter.cpp @@ -0,0 +1,43 @@ +#include "optimizer/remove_factorization_rewriter.h" + +#include "common/exception.h" + +using namespace kuzu::planner; + +namespace kuzu { +namespace optimizer { + +void RemoveFactorizationRewriter::rewrite(planner::LogicalPlan* plan) { + auto root = plan->getLastOperator(); + rewriteOperator(root); + if (subPlanHasFlatten(root.get())) { + throw common::InternalException("Remove factorization rewriter failed."); + } +} + +std::shared_ptr RemoveFactorizationRewriter::rewriteOperator( + std::shared_ptr op) { + // bottom-up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + op->setChild(i, rewriteOperator(op->getChild(i))); + } + if (op->getOperatorType() == planner::LogicalOperatorType::FLATTEN) { + return op->getChild(0); + } + return op; +} + +bool RemoveFactorizationRewriter::subPlanHasFlatten(planner::LogicalOperator* op) { + if (op->getOperatorType() == planner::LogicalOperatorType::FLATTEN) { + return true; + } + for (auto& child : op->getChildren()) { + if (subPlanHasFlatten(child.get())) { + return true; + } + } + return false; +} + +} // namespace optimizer +} // namespace kuzu diff --git a/src/planner/join_order_enumerator.cpp b/src/planner/join_order_enumerator.cpp index 9b579aa24b4..62644ac4d84 100644 --- a/src/planner/join_order_enumerator.cpp +++ b/src/planner/join_order_enumerator.cpp @@ -338,13 +338,13 @@ void JoinOrderEnumerator::planWCOJoin(const SubqueryGraph& subgraph, auto newSubgraph = subgraph; std::vector prevSubgraphs; prevSubgraphs.push_back(subgraph); - std::vector> boundNodes; + expression_vector boundNodeIDs; std::vector> relPlans; for (auto& rel : rels) { auto boundNode = rel->getSrcNodeName() == intersectNode->getUniqueName() ? rel->getDstNode() : rel->getSrcNode(); - boundNodes.push_back(boundNode); + boundNodeIDs.push_back(boundNode->getInternalIDProperty()); auto relPos = context->getQueryGraph()->getQueryRelPos(rel->getUniqueName()); auto prevSubgraph = context->getEmptySubqueryGraph(); prevSubgraph.addQueryRel(relPos); @@ -366,7 +366,8 @@ void JoinOrderEnumerator::planWCOJoin(const SubqueryGraph& subgraph, for (auto& relPlan : relPlans) { rightPlansCopy.push_back(relPlan->shallowCopy()); } - appendIntersect(intersectNode, boundNodes, *leftPlanCopy, rightPlansCopy); + appendIntersect( + intersectNode->getInternalIDProperty(), boundNodeIDs, *leftPlanCopy, rightPlansCopy); for (auto& predicate : predicates) { queryPlanner->appendFilter(predicate, *leftPlanCopy); } @@ -544,6 +545,7 @@ void JoinOrderEnumerator::appendIndexScanNode(std::shared_ptr& n plan.setLastOperator(std::move(scan)); } +// When extend might increase cardinality (i.e. n * m), we extend to a new factorization group. bool JoinOrderEnumerator::needExtendToNewGroup( RelExpression& rel, NodeExpression& boundNode, RelDirection direction) { auto extendToNewGroup = false; @@ -557,22 +559,14 @@ bool JoinOrderEnumerator::needExtendToNewGroup( return extendToNewGroup; } -bool JoinOrderEnumerator::needFlatInput( - RelExpression& rel, NodeExpression& boundNode, RelDirection direction) { - auto needFlatInput = needExtendToNewGroup(rel, boundNode, direction); - needFlatInput |= rel.isVariableLength(); - return needFlatInput; -} - void JoinOrderEnumerator::appendExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, RelDirection direction, const expression_vector& properties, LogicalPlan& plan) { auto extendToNewGroup = needExtendToNewGroup(*rel, *boundNode, direction); - if (needFlatInput(*rel, *boundNode, direction)) { - QueryPlanner::appendFlattenIfNecessary(boundNode->getInternalIDProperty(), plan); - } auto extend = make_shared( boundNode, nbrNode, rel, direction, properties, extendToNewGroup, plan.getLastOperator()); + QueryPlanner::appendFlattens(extend->getGroupsPosToFlatten(), plan); + extend->setChild(0, plan.getLastOperator()); extend->computeSchema(); plan.setLastOperator(std::move(extend)); // update cardinality estimation info @@ -608,64 +602,21 @@ void JoinOrderEnumerator::planJoin(const expression_vector& joinNodeIDs, JoinTyp } } -static bool isJoinKeyUniqueOnBuildSide(const Expression& joinNodeID, LogicalPlan& buildPlan) { - auto buildSchema = buildPlan.getSchema(); - auto numGroupsInScope = buildSchema->getGroupsPosInScope().size(); - bool hasProjectedOutGroups = buildSchema->getNumGroups() > numGroupsInScope; - if (numGroupsInScope > 1 || hasProjectedOutGroups) { - return false; - } - // Now there is a single factorization group, we need to further make sure joinNodeID comes from - // ScanNodeID operator. Because if joinNodeID comes from a ColExtend we cannot guarantee the - // reverse mapping is still many-to-one. We look for the most simple pattern where build plan is - // linear. - auto firstop = buildPlan.getLastOperator().get(); - while (firstop->getNumChildren() != 0) { - if (firstop->getNumChildren() > 1) { - return false; - } - firstop = firstop->getChild(0).get(); - } - if (firstop->getOperatorType() != LogicalOperatorType::SCAN_NODE) { - return false; - } - auto scanNodeID = (LogicalScanNode*)firstop; - if (scanNodeID->getNode()->getInternalIDPropertyName() != joinNodeID.getUniqueName()) { - return false; - } - return true; -} - void JoinOrderEnumerator::appendHashJoin(const expression_vector& joinNodeIDs, JoinType joinType, bool isProbeAcc, LogicalPlan& probePlan, LogicalPlan& buildPlan) { - probePlan.increaseCost(probePlan.getCardinality() + buildPlan.getCardinality()); - // Flat probe side key group in either of the following two cases: - // 1. there are multiple join nodes; - // 2. if the build side contains more than one group or the build side has projected out data - // chunks, which may increase the multiplicity of data chunks in the build side. The key is to - // keep probe side key unflat only when we know that there is only 0 or 1 match for each key. - // TODO(Guodong): when the build side has only flat payloads, we should consider getting rid of - // flattening probe key, instead duplicating keys as in vectorized processing if necessary. - auto needFlattenProbeJoinKey = false; - needFlattenProbeJoinKey |= joinNodeIDs.size() > 1; - needFlattenProbeJoinKey |= !isJoinKeyUniqueOnBuildSide(*joinNodeIDs[0], buildPlan); - if (needFlattenProbeJoinKey) { - for (auto& joinNodeID : joinNodeIDs) { - auto probeSideKeyGroupPos = probePlan.getSchema()->getGroupPos(*joinNodeID); - QueryPlanner::appendFlattenIfNecessary(probeSideKeyGroupPos, probePlan); - } - } - // Flat all but one build side key groups. - std::unordered_set joinNodesGroupPos; - for (auto& joinNodeID : joinNodeIDs) { - joinNodesGroupPos.insert(buildPlan.getSchema()->getGroupPos(*joinNodeID)); - } - QueryPlanner::appendFlattensButOne(joinNodesGroupPos, buildPlan); auto hashJoin = make_shared(joinNodeIDs, joinType, isProbeAcc, buildPlan.getSchema()->getExpressionsInScope(), probePlan.getLastOperator(), buildPlan.getLastOperator()); + // Apply flattening to probe side + auto groupsPosToFlattenOnProbeSide = hashJoin->getGroupsPosToFlattenOnProbeSide(); + QueryPlanner::appendFlattens(groupsPosToFlattenOnProbeSide, probePlan); + hashJoin->setChild(0, probePlan.getLastOperator()); + // Apply flattening to build side + QueryPlanner::appendFlattens(hashJoin->getGroupsPosToFlattenOnBuildSide(), buildPlan); + hashJoin->setChild(1, buildPlan.getLastOperator()); hashJoin->computeSchema(); - if (needFlattenProbeJoinKey) { + probePlan.increaseCost(probePlan.getCardinality() + buildPlan.getCardinality()); + if (!groupsPosToFlattenOnProbeSide.empty()) { probePlan.multiplyCardinality( buildPlan.getCardinality() * EnumeratorKnobs::PREDICATE_SELECTIVITY); probePlan.multiplyCost(EnumeratorKnobs::FLAT_PROBE_PENALTY); @@ -676,54 +627,44 @@ void JoinOrderEnumerator::appendHashJoin(const expression_vector& joinNodeIDs, J void JoinOrderEnumerator::appendMarkJoin(const expression_vector& joinNodeIDs, const std::shared_ptr& mark, bool isProbeAcc, LogicalPlan& probePlan, LogicalPlan& buildPlan) { - // Apply flattening to probe side - std::unordered_set joinNodeGroupsPosInProbeSide; - auto needFlattenProbeJoinKey = false; - needFlattenProbeJoinKey |= joinNodeIDs.size() > 1; - needFlattenProbeJoinKey |= !isJoinKeyUniqueOnBuildSide(*joinNodeIDs[0], buildPlan); - for (auto& joinNodeID : joinNodeIDs) { - auto probeKeyGroupPos = probePlan.getSchema()->getGroupPos(*joinNodeID); - if (needFlattenProbeJoinKey) { - QueryPlanner::appendFlattenIfNecessary(probeKeyGroupPos, probePlan); - } - joinNodeGroupsPosInProbeSide.insert(probeKeyGroupPos); - } - // Apply flattening to build side - std::unordered_set joinNodeGroupsPosInBuildSide; - for (auto& joinNodeID : joinNodeIDs) { - joinNodeGroupsPosInBuildSide.insert(buildPlan.getSchema()->getGroupPos(*joinNodeID)); - } - QueryPlanner::appendFlattensButOne(joinNodeGroupsPosInBuildSide, buildPlan); - probePlan.increaseCost(probePlan.getCardinality() + buildPlan.getCardinality()); auto hashJoin = make_shared( joinNodeIDs, mark, isProbeAcc, probePlan.getLastOperator(), buildPlan.getLastOperator()); + // Apply flattening to probe side + QueryPlanner::appendFlattens(hashJoin->getGroupsPosToFlattenOnProbeSide(), probePlan); + hashJoin->setChild(0, probePlan.getLastOperator()); + // Apply flattening to build side + QueryPlanner::appendFlattens(hashJoin->getGroupsPosToFlattenOnBuildSide(), buildPlan); + hashJoin->setChild(1, buildPlan.getLastOperator()); hashJoin->computeSchema(); + probePlan.increaseCost(probePlan.getCardinality() + buildPlan.getCardinality()); probePlan.setLastOperator(std::move(hashJoin)); } -void JoinOrderEnumerator::appendIntersect(const std::shared_ptr& intersectNode, - std::vector>& boundNodes, LogicalPlan& probePlan, +void JoinOrderEnumerator::appendIntersect(const std::shared_ptr& intersectNodeID, + binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan, std::vector>& buildPlans) { - auto intersectNodeID = intersectNode->getInternalIDProperty(); - assert(boundNodes.size() == buildPlans.size()); + assert(boundNodeIDs.size() == buildPlans.size()); std::vector> buildChildren; std::vector> buildInfos; for (auto i = 0u; i < buildPlans.size(); ++i) { - auto boundNodeID = boundNodes[i]->getInternalIDProperty(); - QueryPlanner::appendFlattenIfNecessary( - probePlan.getSchema()->getGroupPos(*boundNodeID), probePlan); + auto boundNodeID = boundNodeIDs[i]; auto buildPlan = buildPlans[i].get(); - auto buildSchema = buildPlan->getSchema(); - QueryPlanner::appendFlattenIfNecessary(buildSchema->getGroupPos(*boundNodeID), *buildPlan); - auto expressions = buildSchema->getExpressionsInScope(); - auto buildInfo = std::make_unique(boundNodeID, expressions); + auto buildInfo = std::make_unique( + boundNodeID, buildPlan->getSchema()->getExpressionsInScope()); buildChildren.push_back(buildPlan->getLastOperator()); buildInfos.push_back(std::move(buildInfo)); } - auto logicalIntersect = make_shared(intersectNodeID, - probePlan.getLastOperator(), std::move(buildChildren), std::move(buildInfos)); - logicalIntersect->computeSchema(); - probePlan.setLastOperator(std::move(logicalIntersect)); + auto intersect = make_shared(intersectNodeID, probePlan.getLastOperator(), + std::move(buildChildren), std::move(buildInfos)); + QueryPlanner::appendFlattens(intersect->getGroupsPosToFlattenOnProbeSide(), probePlan); + intersect->setChild(0, probePlan.getLastOperator()); + for (auto i = 0u; i < buildPlans.size(); ++i) { + QueryPlanner::appendFlattens( + intersect->getGroupsPosToFlattenOnBuildSide(i), *buildPlans[i]); + intersect->setChild(i + 1, buildPlans[i]->getLastOperator()); + } + intersect->computeSchema(); + probePlan.setLastOperator(std::move(intersect)); } void JoinOrderEnumerator::appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan) { diff --git a/src/planner/operator/CMakeLists.txt b/src/planner/operator/CMakeLists.txt index 986e323b2d6..1cf3e690909 100644 --- a/src/planner/operator/CMakeLists.txt +++ b/src/planner/operator/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(kuzu_planner_operator OBJECT base_logical_operator.cpp + flatten_resolver.cpp logical_accumulate.cpp logical_aggregate.cpp logical_create.cpp @@ -20,6 +21,7 @@ add_library(kuzu_planner_operator logical_projection.cpp logical_scan_node.cpp logical_scan_node_property.cpp + logical_set.cpp logical_skip.cpp logical_union.cpp logical_unwind.cpp diff --git a/src/planner/operator/flatten_resolver.cpp b/src/planner/operator/flatten_resolver.cpp new file mode 100644 index 00000000000..ab57ba728d1 --- /dev/null +++ b/src/planner/operator/flatten_resolver.cpp @@ -0,0 +1,36 @@ +#include "planner/logical_plan/logical_operator/flatten_resolver.h" + +namespace kuzu { +namespace planner { +namespace factorization { + +f_group_pos_set FlattenAllButOne::getGroupsPosToFlatten( + const f_group_pos_set& groupsPos, Schema* schema) { + std::vector unFlatGroupsPos; + for (auto groupPos : groupsPos) { + if (!schema->getGroup(groupPos)->isFlat()) { + unFlatGroupsPos.push_back(groupPos); + } + } + f_group_pos_set result; + // Keep the first group as unFlat. + for (auto i = 1u; i < unFlatGroupsPos.size(); ++i) { + result.insert(unFlatGroupsPos[i]); + } + return result; +} + +f_group_pos_set FlattenAll::getGroupsPosToFlatten( + const f_group_pos_set& groupsPos, Schema* schema) { + f_group_pos_set result; + for (auto groupPos : groupsPos) { + if (!schema->getGroup(groupPos)->isFlat()) { + result.insert(groupPos); + } + } + return result; +} + +} // namespace factorization +} // namespace planner +} // namespace kuzu diff --git a/src/planner/operator/logical_aggregate.cpp b/src/planner/operator/logical_aggregate.cpp index e8d2d29d86e..7e035591025 100644 --- a/src/planner/operator/logical_aggregate.cpp +++ b/src/planner/operator/logical_aggregate.cpp @@ -1,8 +1,41 @@ #include "planner/logical_plan/logical_operator/logical_aggregate.h" +#include "binder/expression/function_expression.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" + namespace kuzu { namespace planner { +using namespace factorization; + +f_group_pos_set LogicalAggregate::getGroupsPosToFlattenForGroupBy() { + f_group_pos_set dependentGroupsPos; + for (auto& expression : expressionsToGroupBy) { + for (auto groupPos : children[0]->getSchema()->getDependentGroupsPos(expression)) { + dependentGroupsPos.insert(groupPos); + } + } + if (hasDistinctAggregate()) { + return FlattenAll::getGroupsPosToFlatten(dependentGroupsPos, children[0]->getSchema()); + } else { + return FlattenAllButOne::getGroupsPosToFlatten( + dependentGroupsPos, children[0]->getSchema()); + } +} + +f_group_pos_set LogicalAggregate::getGroupsPosToFlattenForAggregate() { + if (hasDistinctAggregate() || expressionsToAggregate.size() > 1) { + f_group_pos_set dependentGroupsPos; + for (auto& expression : expressionsToAggregate) { + for (auto groupPos : children[0]->getSchema()->getDependentGroupsPos(expression)) { + dependentGroupsPos.insert(groupPos); + } + } + return FlattenAll::getGroupsPosToFlatten(dependentGroupsPos, children[0]->getSchema()); + } + return f_group_pos_set{}; +} + void LogicalAggregate::computeSchema() { createEmptySchema(); auto groupPos = schema->createGroup(); @@ -27,5 +60,15 @@ std::string LogicalAggregate::getExpressionsForPrinting() const { return result; } +bool LogicalAggregate::hasDistinctAggregate() { + for (auto& expressionToAggregate : expressionsToAggregate) { + auto& functionExpression = (binder::AggregateFunctionExpression&)*expressionToAggregate; + if (functionExpression.isDistinct()) { + return true; + } + } + return false; +} + } // namespace planner } // namespace kuzu diff --git a/src/planner/operator/logical_distinct.cpp b/src/planner/operator/logical_distinct.cpp index fd01477313f..9d813864b1f 100644 --- a/src/planner/operator/logical_distinct.cpp +++ b/src/planner/operator/logical_distinct.cpp @@ -1,8 +1,21 @@ #include "planner/logical_plan/logical_operator/logical_distinct.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" + namespace kuzu { namespace planner { +f_group_pos_set LogicalDistinct::getGroupsPosToFlatten() { + f_group_pos_set dependentGroupsPos; + auto childSchema = children[0]->getSchema(); + for (auto& expression : expressionsToDistinct) { + for (auto groupPos : childSchema->getDependentGroupsPos(expression)) { + dependentGroupsPos.insert(groupPos); + } + } + return factorization::FlattenAll::getGroupsPosToFlatten(dependentGroupsPos, childSchema); +} + std::string LogicalDistinct::getExpressionsForPrinting() const { std::string result; for (auto& expression : expressionsToDistinct) { diff --git a/src/planner/operator/logical_extend.cpp b/src/planner/operator/logical_extend.cpp index 2fd2a847f16..345ab518d84 100644 --- a/src/planner/operator/logical_extend.cpp +++ b/src/planner/operator/logical_extend.cpp @@ -1,8 +1,23 @@ #include "planner/logical_plan/logical_operator/logical_extend.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" + namespace kuzu { namespace planner { +f_group_pos_set LogicalExtend::getGroupsPosToFlatten() { + f_group_pos_set result; + auto requireFlatBoundNode = extendToNewGroup || rel->isVariableLength(); + if (requireFlatBoundNode) { + auto inSchema = children[0]->getSchema(); + auto boundNodeGroupPos = inSchema->getGroupPos(*boundNode->getInternalIDProperty()); + if (!inSchema->getGroup(boundNodeGroupPos)->isFlat()) { + result.insert(boundNodeGroupPos); + } + } + return result; +} + void LogicalExtend::computeSchema() { copyChildSchema(0); auto boundGroupPos = schema->getGroupPos(boundNode->getInternalIDPropertyName()); diff --git a/src/planner/operator/logical_filter.cpp b/src/planner/operator/logical_filter.cpp index 54afb301bef..ef5f1ba0252 100644 --- a/src/planner/operator/logical_filter.cpp +++ b/src/planner/operator/logical_filter.cpp @@ -1,8 +1,16 @@ #include "planner/logical_plan/logical_operator/logical_filter.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" + namespace kuzu { namespace planner { +f_group_pos_set LogicalFilter::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + auto dependentGroupsPos = childSchema->getDependentGroupsPos(expression); + return factorization::FlattenAllButOne::getGroupsPosToFlatten(dependentGroupsPos, childSchema); +} + f_group_pos LogicalFilter::getGroupPosToSelect() const { auto childSchema = children[0]->getSchema(); auto dependentGroupsPos = childSchema->getDependentGroupsPos(expression); diff --git a/src/planner/operator/logical_flatten.cpp b/src/planner/operator/logical_flatten.cpp index e89b584c766..3f808312349 100644 --- a/src/planner/operator/logical_flatten.cpp +++ b/src/planner/operator/logical_flatten.cpp @@ -5,7 +5,6 @@ namespace planner { void LogicalFlatten::computeSchema() { copyChildSchema(0); - auto groupPos = schema->getGroupPos(expression->getUniqueName()); schema->flattenGroup(groupPos); } diff --git a/src/planner/operator/logical_hash_join.cpp b/src/planner/operator/logical_hash_join.cpp index 853977be1e6..b1352d89cb6 100644 --- a/src/planner/operator/logical_hash_join.cpp +++ b/src/planner/operator/logical_hash_join.cpp @@ -1,10 +1,33 @@ #include "planner/logical_plan/logical_operator/logical_hash_join.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" +#include "planner/logical_plan/logical_operator/logical_scan_node.h" #include "planner/logical_plan/logical_operator/sink_util.h" namespace kuzu { namespace planner { +f_group_pos_set LogicalHashJoin::getGroupsPosToFlattenOnProbeSide() { + f_group_pos_set result; + if (!requireFlatProbeKeys()) { + return result; + } + auto probeSchema = children[0]->getSchema(); + for (auto& joinNodeID : joinNodeIDs) { + result.insert(probeSchema->getGroupPos(*joinNodeID)); + } + return result; +} + +f_group_pos_set LogicalHashJoin::getGroupsPosToFlattenOnBuildSide() { + auto buildSchema = children[1]->getSchema(); + f_group_pos_set joinNodesGroupPos; + for (auto& joinNodeID : joinNodeIDs) { + joinNodesGroupPos.insert(buildSchema->getGroupPos(*joinNodeID)); + } + return factorization::FlattenAllButOne::getGroupsPosToFlatten(joinNodesGroupPos, buildSchema); +} + void LogicalHashJoin::computeSchema() { auto probeSchema = children[0]->getSchema(); auto buildSchema = children[1]->getSchema(); @@ -56,5 +79,41 @@ void LogicalHashJoin::computeSchema() { } } +bool LogicalHashJoin::requireFlatProbeKeys() { + if (joinNodeIDs.size() > 1) { + return true; + } + auto joinNodeID = joinNodeIDs[0].get(); + return !isJoinKeyUniqueOnBuildSide(*joinNodeID); +} + +bool LogicalHashJoin::isJoinKeyUniqueOnBuildSide(const binder::Expression& joinNodeID) { + auto buildSchema = children[1]->getSchema(); + auto numGroupsInScope = buildSchema->getGroupsPosInScope().size(); + bool hasProjectedOutGroups = buildSchema->getNumGroups() > numGroupsInScope; + if (numGroupsInScope > 1 || hasProjectedOutGroups) { + return false; + } + // Now there is a single factorization group, we need to further make sure joinNodeID comes from + // ScanNodeID operator. Because if joinNodeID comes from a ColExtend we cannot guarantee the + // reverse mapping is still many-to-one. We look for the most simple pattern where build plan is + // linear. + auto op = children[1].get(); + while (op->getNumChildren() != 0) { + if (op->getNumChildren() > 1) { + return false; + } + op = op->getChild(0).get(); + } + if (op->getOperatorType() != LogicalOperatorType::SCAN_NODE) { + return false; + } + auto scanNodeID = (LogicalScanNode*)op; + if (scanNodeID->getNode()->getInternalIDPropertyName() != joinNodeID.getUniqueName()) { + return false; + } + return true; +} + } // namespace planner } // namespace kuzu diff --git a/src/planner/operator/logical_intersect.cpp b/src/planner/operator/logical_intersect.cpp index 1cbd6de750f..d4c5e41ed12 100644 --- a/src/planner/operator/logical_intersect.cpp +++ b/src/planner/operator/logical_intersect.cpp @@ -3,6 +3,21 @@ namespace kuzu { namespace planner { +f_group_pos_set LogicalIntersect::getGroupsPosToFlattenOnProbeSide() { + f_group_pos_set result; + for (auto& buildInfo : buildInfos) { + result.insert(children[0]->getSchema()->getGroupPos(*buildInfo->keyNodeID)); + } + return result; +} + +f_group_pos_set LogicalIntersect::getGroupsPosToFlattenOnBuildSide(uint32_t buildIdx) { + f_group_pos_set result; + auto childIdx = buildIdx + 1; // skip probe + result.insert(children[childIdx]->getSchema()->getGroupPos(*buildInfos[buildIdx]->keyNodeID)); + return result; +} + void LogicalIntersect::computeSchema() { auto probeSchema = children[0]->getSchema(); schema = probeSchema->copy(); diff --git a/src/planner/operator/logical_limit.cpp b/src/planner/operator/logical_limit.cpp index d73266a84bd..c799989e9b6 100644 --- a/src/planner/operator/logical_limit.cpp +++ b/src/planner/operator/logical_limit.cpp @@ -3,6 +3,12 @@ namespace kuzu { namespace planner { +f_group_pos_set LogicalLimit::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + return factorization::FlattenAllButOne::getGroupsPosToFlatten( + childSchema->getGroupsPosInScope(), childSchema); +} + f_group_pos LogicalLimit::getGroupPosToSelect() const { auto childSchema = children[0]->getSchema(); auto groupsPosInScope = childSchema->getGroupsPosInScope(); diff --git a/src/planner/operator/logical_order_by.cpp b/src/planner/operator/logical_order_by.cpp index 94a3db862cb..00636f82fab 100644 --- a/src/planner/operator/logical_order_by.cpp +++ b/src/planner/operator/logical_order_by.cpp @@ -1,10 +1,38 @@ #include "planner/logical_plan/logical_operator/logical_order_by.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" #include "planner/logical_plan/logical_operator/sink_util.h" namespace kuzu { namespace planner { +f_group_pos_set LogicalOrderBy::getGroupsPosToFlatten() { + // We only allow orderby key(s) to be unflat, if they are all part of the same factorization + // group and there is no other factorized group in the schema, so any payload is also unflat + // and part of the same factorization group. The rationale for this limitation is this: (1) + // to keep both the frontend and orderby operators simpler, we want order by to not change + // the schema, so the input and output of order by should have the same factorization + // structure. (2) Because orderby needs to flatten the keys to sort, if a key column that is + // unflat is the input, we need to somehow flatten it in the factorized table. However + // whenever we can we want to avoid adding an explicit flatten operator as this makes us + // fall back to tuple-at-a-time processing. However in the specified limited case, we can + // give factorized table a set of unflat vectors (all in the same datachunk/factorization + // group), sort the table, and scan into unflat vectors, so the schema remains the same. In + // more complicated cases, e.g., when there are 2 factorization groups, FactorizedTable + // cannot read back a flat column into an unflat std::vector. + auto childSchema = children[0]->getSchema(); + if (childSchema->getNumGroups() > 1) { + f_group_pos_set dependentGroupsPos; + for (auto& expression : expressionsToOrderBy) { + for (auto groupPos : childSchema->getDependentGroupsPos(expression)) { + dependentGroupsPos.insert(groupPos); + } + } + return factorization::FlattenAll::getGroupsPosToFlatten(dependentGroupsPos, childSchema); + } + return f_group_pos_set{}; +} + void LogicalOrderBy::computeSchema() { auto childSchema = children[0]->getSchema(); schema = std::make_unique(); diff --git a/src/planner/operator/logical_set.cpp b/src/planner/operator/logical_set.cpp new file mode 100644 index 00000000000..51d5eb89c33 --- /dev/null +++ b/src/planner/operator/logical_set.cpp @@ -0,0 +1,22 @@ +#include "planner/logical_plan/logical_operator/logical_set.h" + +#include "planner/logical_plan/logical_operator/flatten_resolver.h" + +namespace kuzu { +namespace planner { + +f_group_pos_set LogicalSetRelProperty::getGroupsPosToFlatten(uint32_t setItemIdx) { + f_group_pos_set result; + auto rel = rels[setItemIdx]; + auto rhs = setItems[setItemIdx].second; + auto childSchema = children[0]->getSchema(); + result.insert(childSchema->getGroupPos(*rel->getSrcNode()->getInternalIDProperty())); + result.insert(childSchema->getGroupPos(*rel->getDstNode()->getInternalIDProperty())); + for (auto groupPos : childSchema->getDependentGroupsPos(rhs)) { + result.insert(groupPos); + } + return factorization::FlattenAll::getGroupsPosToFlatten(result, childSchema); +} + +} // namespace planner +} // namespace kuzu diff --git a/src/planner/operator/logical_skip.cpp b/src/planner/operator/logical_skip.cpp index 5819fd4fc0c..20f7464dffb 100644 --- a/src/planner/operator/logical_skip.cpp +++ b/src/planner/operator/logical_skip.cpp @@ -3,6 +3,12 @@ namespace kuzu { namespace planner { +f_group_pos_set LogicalSkip::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + return factorization::FlattenAllButOne::getGroupsPosToFlatten( + childSchema->getGroupsPosInScope(), childSchema); +} + f_group_pos LogicalSkip::getGroupPosToSelect() const { auto childSchema = children[0]->getSchema(); auto groupsPosInScope = childSchema->getGroupsPosInScope(); diff --git a/src/planner/operator/logical_union.cpp b/src/planner/operator/logical_union.cpp index 7fbb297d451..08fe932ca79 100644 --- a/src/planner/operator/logical_union.cpp +++ b/src/planner/operator/logical_union.cpp @@ -1,10 +1,23 @@ #include "planner/logical_plan/logical_operator/logical_union.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" #include "planner/logical_plan/logical_operator/sink_util.h" namespace kuzu { namespace planner { +f_group_pos_set LogicalUnion::getGroupsPosToFlatten(uint32_t childIdx) { + f_group_pos_set groupsPos; + auto childSchema = children[childIdx]->getSchema(); + for (auto i = 0u; i < expressionsToUnion.size(); ++i) { + if (requireFlatExpression(i)) { + auto expression = childSchema->getExpressionsInScope()[i]; + groupsPos.insert(childSchema->getGroupPos(*expression)); + } + } + return factorization::FlattenAll::getGroupsPosToFlatten(groupsPos, childSchema); +} + void LogicalUnion::computeSchema() { auto firstChildSchema = children[0]->getSchema(); schema = std::make_unique(); @@ -20,5 +33,16 @@ std::unique_ptr LogicalUnion::copy() { return make_unique(expressionsToUnion, std::move(copiedChildren)); } +bool LogicalUnion::requireFlatExpression(uint32_t expressionIdx) { + for (auto& child : children) { + auto childSchema = child->getSchema(); + auto expression = childSchema->getExpressionsInScope()[expressionIdx]; + if (childSchema->getGroup(expression)->isFlat()) { + return true; + } + } + return false; +} + } // namespace planner } // namespace kuzu diff --git a/src/planner/operator/logical_unwind.cpp b/src/planner/operator/logical_unwind.cpp index 5ec4345ba1d..d3aa0b63baf 100644 --- a/src/planner/operator/logical_unwind.cpp +++ b/src/planner/operator/logical_unwind.cpp @@ -1,8 +1,16 @@ #include "planner/logical_plan/logical_operator/logical_unwind.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" + namespace kuzu { namespace planner { +f_group_pos_set LogicalUnwind::getGroupsPosToFlatten() { + auto childSchema = children[0]->getSchema(); + auto dependentGroupsPos = childSchema->getDependentGroupsPos(expression); + return factorization::FlattenAll::getGroupsPosToFlatten(dependentGroupsPos, childSchema); +} + void LogicalUnwind::computeSchema() { copyChildSchema(0); auto groupPos = schema->createGroup(); diff --git a/src/planner/projection_planner.cpp b/src/planner/projection_planner.cpp index c2d9b222793..afa0495fa8a 100644 --- a/src/planner/projection_planner.cpp +++ b/src/planner/projection_planner.cpp @@ -1,6 +1,7 @@ #include "planner/projection_planner.h" #include "binder/expression/function_expression.h" +#include "planner/logical_plan/logical_operator/flatten_resolver.h" #include "planner/logical_plan/logical_operator/logical_aggregate.h" #include "planner/logical_plan/logical_operator/logical_limit.h" #include "planner/logical_plan/logical_operator/logical_multiplcity_reducer.h" @@ -90,6 +91,12 @@ void ProjectionPlanner::appendProjection( for (auto& expression : expressionsToProject) { queryPlanner->planSubqueryIfNecessary(expression, plan); } + for (auto& expression : expressionsToProject) { + auto dependentGroupsPos = plan.getSchema()->getDependentGroupsPos(expression); + auto groupsPosToFlatten = factorization::FlattenAllButOne::getGroupsPosToFlatten( + dependentGroupsPos, plan.getSchema()); + QueryPlanner::appendFlattens(groupsPosToFlatten, plan); + } auto projection = make_shared(expressionsToProject, plan.getLastOperator()); projection->computeSchema(); plan.setLastOperator(std::move(projection)); @@ -97,68 +104,22 @@ void ProjectionPlanner::appendProjection( void ProjectionPlanner::appendAggregate(const expression_vector& expressionsToGroupBy, const expression_vector& expressionsToAggregate, LogicalPlan& plan) { - bool hasDistinctFunc = false; - for (auto& expressionToAggregate : expressionsToAggregate) { - auto& functionExpression = (AggregateFunctionExpression&)*expressionToAggregate; - if (functionExpression.isDistinct()) { - hasDistinctFunc = true; - } - } - if (hasDistinctFunc) { // Flatten all inputs. - for (auto& expressionToGroupBy : expressionsToGroupBy) { - auto dependentGroupsPos = plan.getSchema()->getDependentGroupsPos(expressionToGroupBy); - QueryPlanner::appendFlattens(dependentGroupsPos, plan); - } - for (auto& expressionToAggregate : expressionsToAggregate) { - auto dependentGroupsPos = - plan.getSchema()->getDependentGroupsPos(expressionToAggregate); - QueryPlanner::appendFlattens(dependentGroupsPos, plan); - } - } else { - // Flatten all but one for ALL group by keys. - std::unordered_set groupByPoses; - for (auto& expressionToGroupBy : expressionsToGroupBy) { - auto dependentGroupsPos = plan.getSchema()->getDependentGroupsPos(expressionToGroupBy); - groupByPoses.insert(dependentGroupsPos.begin(), dependentGroupsPos.end()); - } - QueryPlanner::appendFlattensButOne(groupByPoses, plan); - if (expressionsToAggregate.size() > 1) { - for (auto& expressionToAggregate : expressionsToAggregate) { - auto dependentGroupsPos = - plan.getSchema()->getDependentGroupsPos(expressionToAggregate); - QueryPlanner::appendFlattens(dependentGroupsPos, plan); - } - } - } auto aggregate = make_shared( expressionsToGroupBy, expressionsToAggregate, plan.getLastOperator()); + QueryPlanner::appendFlattens(aggregate->getGroupsPosToFlattenForGroupBy(), plan); + aggregate->setChild(0, plan.getLastOperator()); + QueryPlanner::appendFlattens(aggregate->getGroupsPosToFlattenForAggregate(), plan); + aggregate->setChild(0, plan.getLastOperator()); aggregate->computeSchema(); plan.setLastOperator(std::move(aggregate)); } void ProjectionPlanner::appendOrderBy( const expression_vector& expressions, const std::vector& isAscOrders, LogicalPlan& plan) { - for (auto& expression : expressions) { - // We only allow orderby key(s) to be unflat, if they are all part of the same factorization - // group and there is no other factorized group in the schema, so any payload is also unflat - // and part of the same factorization group. The rationale for this limitation is this: (1) - // to keep both the frontend and orderby operators simpler, we want order by to not change - // the schema, so the input and output of order by should have the same factorization - // structure. (2) Because orderby needs to flatten the keys to sort, if a key column that is - // unflat is the input, we need to somehow flatten it in the factorized table. However - // whenever we can we want to avoid adding an explicit flatten operator as this makes us - // fall back to tuple-at-a-time processing. However in the specified limited case, we can - // give factorized table a set of unflat vectors (all in the same datachunk/factorization - // group), sort the table, and scan into unflat vectors, so the schema remains the same. In - // more complicated cases, e.g., when there are 2 factorization groups, FactorizedTable - // cannot read back a flat column into an unflat std::vector. - if (plan.getSchema()->getNumGroups() > 1) { - auto dependentGroupsPos = plan.getSchema()->getDependentGroupsPos(expression); - QueryPlanner::appendFlattens(dependentGroupsPos, plan); - } - } auto orderBy = make_shared(expressions, isAscOrders, plan.getSchema()->getExpressionsInScope(), plan.getLastOperator()); + QueryPlanner::appendFlattens(orderBy->getGroupsPosToFlatten(), plan); + orderBy->setChild(0, plan.getLastOperator()); orderBy->computeSchema(); plan.setLastOperator(std::move(orderBy)); } @@ -170,16 +131,18 @@ void ProjectionPlanner::appendMultiplicityReducer(LogicalPlan& plan) { } void ProjectionPlanner::appendLimit(uint64_t limitNumber, LogicalPlan& plan) { - QueryPlanner::appendFlattensButOne(plan.getSchema()->getGroupsPosInScope(), plan); auto limit = make_shared(limitNumber, plan.getLastOperator()); + QueryPlanner::appendFlattens(limit->getGroupsPosToFlatten(), plan); + limit->setChild(0, plan.getLastOperator()); limit->computeSchema(); plan.setCardinality(limitNumber); plan.setLastOperator(std::move(limit)); } void ProjectionPlanner::appendSkip(uint64_t skipNumber, LogicalPlan& plan) { - QueryPlanner::appendFlattensButOne(plan.getSchema()->getGroupsPosInScope(), plan); auto skip = make_shared(skipNumber, plan.getLastOperator()); + QueryPlanner::appendFlattens(skip->getGroupsPosToFlatten(), plan); + skip->setChild(0, plan.getLastOperator()); skip->computeSchema(); plan.setCardinality(plan.getCardinality() - skipNumber); plan.setLastOperator(std::move(skip)); diff --git a/src/planner/query_planner.cpp b/src/planner/query_planner.cpp index 0d845577e2e..7158a4a2111 100644 --- a/src/planner/query_planner.cpp +++ b/src/planner/query_planner.cpp @@ -194,9 +194,6 @@ void QueryPlanner::planOptionalMatch(const QueryGraphCollection& queryGraphColle auto innerPlans = joinOrderEnumerator.enumerate(queryGraphCollection, predicates); auto bestInnerPlan = getBestPlan(std::move(innerPlans)); joinOrderEnumerator.exitSubquery(std::move(prevContext)); - for (auto& joinNodeID : joinNodeIDs) { - appendFlattenIfNecessary(joinNodeID, outerPlan); - } JoinOrderEnumerator::planLeftHashJoin(joinNodeIDs, outerPlan, *bestInnerPlan); } else { throw NotImplementedException("Correlated optional match is not supported."); @@ -286,62 +283,34 @@ void QueryPlanner::appendExpressionsScan(const expression_vector& expressions, L void QueryPlanner::appendDistinct( const expression_vector& expressionsToDistinct, LogicalPlan& plan) { - for (auto& expression : expressionsToDistinct) { - auto dependentGroupsPos = plan.getSchema()->getDependentGroupsPos(expression); - appendFlattens(dependentGroupsPos, plan); - } auto distinct = make_shared(expressionsToDistinct, plan.getLastOperator()); + QueryPlanner::appendFlattens(distinct->getGroupsPosToFlatten(), plan); + distinct->setChild(0, plan.getLastOperator()); distinct->computeSchema(); plan.setLastOperator(std::move(distinct)); } void QueryPlanner::appendUnwind(BoundUnwindClause& boundUnwindClause, LogicalPlan& plan) { - auto dependentGroupPos = - plan.getSchema()->getDependentGroupsPos(boundUnwindClause.getExpression()); - if (!dependentGroupPos.empty()) { - appendFlattens(dependentGroupPos, plan); - } - auto logicalUnwind = make_shared(boundUnwindClause.getExpression(), + auto unwind = make_shared(boundUnwindClause.getExpression(), boundUnwindClause.getAliasExpression(), plan.getLastOperator()); - logicalUnwind->computeSchema(); - plan.setLastOperator(logicalUnwind); + QueryPlanner::appendFlattens(unwind->getGroupsPosToFlatten(), plan); + unwind->setChild(0, plan.getLastOperator()); + unwind->computeSchema(); + plan.setLastOperator(unwind); } -void QueryPlanner::appendFlattens( - const std::unordered_set& groupsPos, LogicalPlan& plan) { - for (auto& groupPos : groupsPos) { +void QueryPlanner::appendFlattens(const f_group_pos_set& groupsPos, LogicalPlan& plan) { + for (auto groupPos : groupsPos) { appendFlattenIfNecessary(groupPos, plan); } } -uint32_t QueryPlanner::appendFlattensButOne( - const std::unordered_set& groupsPos, LogicalPlan& plan) { - if (groupsPos.empty()) { - // an expression may not depend on any group. E.g. COUNT(*). - return UINT32_MAX; - } - std::vector unFlatGroupsPos; - for (auto& groupPos : groupsPos) { - if (!plan.getSchema()->getGroup(groupPos)->isFlat()) { - unFlatGroupsPos.push_back(groupPos); - } - } - if (unFlatGroupsPos.empty()) { - return *groupsPos.begin(); - } - for (auto i = 1u; i < unFlatGroupsPos.size(); ++i) { - appendFlattenIfNecessary(unFlatGroupsPos[i], plan); - } - return unFlatGroupsPos[0]; -} - -void QueryPlanner::appendFlattenIfNecessary( - const std::shared_ptr& expression, LogicalPlan& plan) { - auto group = plan.getSchema()->getGroup(expression); +void QueryPlanner::appendFlattenIfNecessary(f_group_pos groupPos, LogicalPlan& plan) { + auto group = plan.getSchema()->getGroup(groupPos); if (group->isFlat()) { return; } - auto flatten = make_shared(expression, plan.getLastOperator()); + auto flatten = make_shared(groupPos, plan.getLastOperator()); flatten->computeSchema(); // update cardinality estimation info plan.multiplyCardinality(group->getMultiplier()); @@ -350,9 +319,9 @@ void QueryPlanner::appendFlattenIfNecessary( void QueryPlanner::appendFilter(const std::shared_ptr& expression, LogicalPlan& plan) { planSubqueryIfNecessary(expression, plan); - auto dependentGroupsPos = plan.getSchema()->getDependentGroupsPos(expression); - appendFlattensButOne(dependentGroupsPos, plan); auto filter = make_shared(expression, plan.getLastOperator()); + QueryPlanner::appendFlattens(filter->getGroupsPosToFlatten(), plan); + filter->setChild(0, plan.getLastOperator()); filter->computeSchema(); plan.multiplyCardinality(EnumeratorKnobs::PREDICATE_SELECTIVITY); plan.setLastOperator(std::move(filter)); @@ -378,38 +347,24 @@ void QueryPlanner::appendScanNodePropIfNecessary(const expression_vector& proper std::unique_ptr QueryPlanner::createUnionPlan( std::vector>& childrenPlans, bool isUnionAll) { - // If an expression to union has different flat/unflat state in different child, we - // need to flatten that expression in all the single queries. assert(!childrenPlans.empty()); - auto numExpressionsToUnion = childrenPlans[0]->getSchema()->getExpressionsInScope().size(); - for (auto i = 0u; i < numExpressionsToUnion; i++) { - bool hasFlatExpression = false; - for (auto& childPlan : childrenPlans) { - auto childSchema = childPlan->getSchema(); - auto expressionName = childSchema->getExpressionsInScope()[i]->getUniqueName(); - hasFlatExpression |= childSchema->getGroup(expressionName)->isFlat(); - } - if (hasFlatExpression) { - for (auto& childPlan : childrenPlans) { - auto childSchema = childPlan->getSchema(); - auto expressionName = childSchema->getExpressionsInScope()[i]->getUniqueName(); - appendFlattenIfNecessary(childSchema->getGroupPos(expressionName), *childPlan); - } - } - } - // we compute the schema based on first child auto plan = std::make_unique(); std::vector> children; for (auto& childPlan : childrenPlans) { plan->increaseCost(childPlan->getCost()); children.push_back(childPlan->getLastOperator()); } - auto logicalUnion = make_shared( + // we compute the schema based on first child + auto union_ = make_shared( childrenPlans[0]->getSchema()->getExpressionsInScope(), std::move(children)); - logicalUnion->computeSchema(); - plan->setLastOperator(logicalUnion); + for (auto i = 0u; i < childrenPlans.size(); ++i) { + appendFlattens(union_->getGroupsPosToFlatten(i), *childrenPlans[i]); + union_->setChild(i, childrenPlans[i]->getLastOperator()); + } + union_->computeSchema(); + plan->setLastOperator(union_); if (!isUnionAll) { - appendDistinct(logicalUnion->getExpressionsToUnion(), *plan); + appendDistinct(union_->getExpressionsToUnion(), *plan); } return plan; } diff --git a/src/planner/update_planner.cpp b/src/planner/update_planner.cpp index d482904d025..28ec014e01e 100644 --- a/src/planner/update_planner.cpp +++ b/src/planner/update_planner.cpp @@ -43,11 +43,6 @@ void UpdatePlanner::planUpdatingClause(BoundUpdatingClause& updatingClause, Logi } void UpdatePlanner::planCreate(BoundCreateClause& createClause, LogicalPlan& plan) { - // Flatten all inputs. E.g. MATCH (a) CREATE (b). We need to create b for each tuple in the - // match clause. This is to simplify operator implementation. - for (auto groupPos = 0u; groupPos < plan.getSchema()->getNumGroups(); ++groupPos) { - QueryPlanner::appendFlattenIfNecessary(groupPos, plan); - } if (createClause.hasCreateNode()) { appendCreateNode(createClause.getCreateNodes(), plan); } @@ -71,6 +66,8 @@ void UpdatePlanner::appendCreateNode( } auto createNode = make_shared( std::move(nodes), std::move(primaryKeys), plan.getLastOperator()); + QueryPlanner::appendFlattens(createNode->getGroupsPosToFlatten(), plan); + createNode->setChild(0, plan.getLastOperator()); createNode->computeSchema(); plan.setLastOperator(createNode); appendSetNodeProperty(setNodeProperties, plan); @@ -86,6 +83,8 @@ void UpdatePlanner::appendCreateRel( } auto createRel = make_shared( std::move(rels), std::move(setItemsPerRel), plan.getLastOperator()); + QueryPlanner::appendFlattens(createRel->getGroupsPosToFlatten(), plan); + createRel->setChild(0, plan.getLastOperator()); createRel->computeSchema(); plan.setLastOperator(createRel); } @@ -105,21 +104,24 @@ void UpdatePlanner::appendSetNodeProperty( std::vector> nodes; std::vector setItems; for (auto& setNodeProperty : setNodeProperties) { - auto node = setNodeProperty->getNode(); - auto lhsGroupPos = plan.getSchema()->getGroupPos(node->getInternalIDPropertyName()); - auto isLhsFlat = plan.getSchema()->getGroup(lhsGroupPos)->isFlat(); - auto rhs = setNodeProperty->getSetItem().second; + nodes.push_back(setNodeProperty->getNode()); + setItems.push_back(setNodeProperty->getSetItem()); + } + for (auto i = 0u; i < setItems.size(); ++i) { + auto lhsNodeID = nodes[i]->getInternalIDProperty(); + auto rhs = setItems[i].second; + // flatten rhs auto rhsDependentGroupsPos = plan.getSchema()->getDependentGroupsPos(rhs); - if (!rhsDependentGroupsPos.empty()) { // RHS is not constant - auto rhsPos = QueryPlanner::appendFlattensButOne(rhsDependentGroupsPos, plan); - auto isRhsFlat = plan.getSchema()->getGroup(rhsPos)->isFlat(); - // If both are unflat and from different groups, we flatten LHS. - if (!isRhsFlat && !isLhsFlat && lhsGroupPos != rhsPos) { - QueryPlanner::appendFlattenIfNecessary(lhsGroupPos, plan); - } + auto rhsGroupsPosToFlatten = factorization::FlattenAllButOne::getGroupsPosToFlatten( + rhsDependentGroupsPos, plan.getSchema()); + QueryPlanner::appendFlattens(rhsGroupsPosToFlatten, plan); + // flatten lhs if needed + auto lhsGroupPos = plan.getSchema()->getGroupPos(*lhsNodeID); + auto rhsLeadingGroupPos = + SchemaUtils::getLeadingGroupPos(rhsDependentGroupsPos, *plan.getSchema()); + if (lhsGroupPos != rhsLeadingGroupPos) { + QueryPlanner::appendFlattenIfNecessary(lhsGroupPos, plan); } - nodes.push_back(node); - setItems.push_back(setNodeProperty->getSetItem()); } auto setNodeProperty = make_shared( std::move(nodes), std::move(setItems), plan.getLastOperator()); @@ -132,15 +134,15 @@ void UpdatePlanner::appendSetRelProperty( std::vector> rels; std::vector setItems; for (auto& setRelProperty : setRelProperties) { - flattenRel(*setRelProperty->getRel(), plan); - auto rhs = setRelProperty->getSetItem().second; - auto rhsDependentGroupsPos = plan.getSchema()->getDependentGroupsPos(rhs); - QueryPlanner::appendFlattens(rhsDependentGroupsPos, plan); rels.push_back(setRelProperty->getRel()); setItems.push_back(setRelProperty->getSetItem()); } auto setRelProperty = make_shared( std::move(rels), std::move(setItems), plan.getLastOperator()); + for (auto i = 0u; i < setRelProperty->getNumRels(); ++i) { + QueryPlanner::appendFlattens(setRelProperty->getGroupsPosToFlatten(i), plan); + setRelProperty->setChild(0, plan.getLastOperator()); + } setRelProperty->computeSchema(); plan.setLastOperator(setRelProperty); } @@ -170,21 +172,14 @@ void UpdatePlanner::appendDeleteNode( void UpdatePlanner::appendDeleteRel( const std::vector>& deleteRels, LogicalPlan& plan) { - // Delete one rel at a time so we flatten for each rel. - for (auto& rel : deleteRels) { - flattenRel(*rel, plan); - } auto deleteRel = make_shared(deleteRels, plan.getLastOperator()); + for (auto i = 0u; i < deleteRel->getNumRels(); ++i) { + QueryPlanner::appendFlattens(deleteRel->getGroupsPosToFlatten(i), plan); + deleteRel->setChild(0, plan.getLastOperator()); + } deleteRel->computeSchema(); plan.setLastOperator(std::move(deleteRel)); } -void UpdatePlanner::flattenRel(const RelExpression& rel, LogicalPlan& plan) { - auto srcNodeID = rel.getSrcNode()->getInternalIDProperty(); - QueryPlanner::appendFlattenIfNecessary(srcNodeID, plan); - auto dstNodeID = rel.getDstNode()->getInternalIDProperty(); - QueryPlanner::appendFlattenIfNecessary(dstNodeID, plan); -} - } // namespace planner } // namespace kuzu diff --git a/src/processor/mapper/map_flatten.cpp b/src/processor/mapper/map_flatten.cpp index 5bd6a5cc86d..2c87c96f1c6 100644 --- a/src/processor/mapper/map_flatten.cpp +++ b/src/processor/mapper/map_flatten.cpp @@ -10,11 +10,9 @@ namespace processor { std::unique_ptr PlanMapper::mapLogicalFlattenToPhysical( LogicalOperator* logicalOperator) { auto flatten = (LogicalFlatten*)logicalOperator; - auto inSchema = flatten->getChild(0)->getSchema(); auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0)); - auto dataChunkPos = inSchema->getExpressionPos(*flatten->getExpression()).first; - return make_unique( - dataChunkPos, move(prevOperator), getOperatorID(), flatten->getExpressionsForPrinting()); + return make_unique(flatten->getGroupPos(), std::move(prevOperator), getOperatorID(), + flatten->getExpressionsForPrinting()); } } // namespace processor diff --git a/test/test_files/tinysnb/projection/single_label.test b/test/test_files/tinysnb/projection/single_label.test index dee3d49105b..4f9266ed77b 100644 --- a/test/test_files/tinysnb/projection/single_label.test +++ b/test/test_files/tinysnb/projection/single_label.test @@ -132,6 +132,19 @@ False [[10]] [[7],[10],[6,7]] +-NAME CrossProductReturn +-QUERY MATCH (a:organisation), (b:organisation) RETURN a.orgCode = b.orgCode +---- 9 +False +False +False +False +False +False +True +True +True + -NAME KnowsOneHopTest1 -QUERY MATCH (a:person)-[e:knows]->(b:person) WHERE b.age=20 RETURN b.age ---- 3