diff --git a/src/binder/bind/bind_updating_clause.cpp b/src/binder/bind/bind_updating_clause.cpp index fc548701f41..f758734f895 100644 --- a/src/binder/bind/bind_updating_clause.cpp +++ b/src/binder/bind/bind_updating_clause.cpp @@ -279,18 +279,19 @@ std::unique_ptr Binder::bindSetClause(const UpdatingClause& BoundSetPropertyInfo Binder::bindSetPropertyInfo(parser::ParsedExpression* lhs, parser::ParsedExpression* rhs) { - auto pattern = expressionBinder.bindExpression(*lhs->getChild(0)); - auto isNode = ExpressionUtil::isNodePattern(*pattern); - auto isRel = ExpressionUtil::isRelPattern(*pattern); + auto expr = expressionBinder.bindExpression(*lhs->getChild(0)); + auto isNode = ExpressionUtil::isNodePattern(*expr); + auto isRel = ExpressionUtil::isRelPattern(*expr); if (!isNode && !isRel) { throw BinderException( stringFormat("Cannot set expression {} with type {}. Expect node or rel pattern.", - pattern->toString(), expressionTypeToString(pattern->expressionType))); + expr->toString(), expressionTypeToString(expr->expressionType))); } - auto patternExpr = ku_dynamic_cast(pattern.get()); + auto& patternExpr = expr->constCast(); auto boundSetItem = bindSetItem(lhs, rhs); + // Validate not updating tables belong to RDFGraph. auto catalog = clientContext->getCatalog(); - for (auto tableID : patternExpr->getTableIDs()) { + for (auto tableID : patternExpr.getTableIDs()) { auto tableName = catalog->getTableCatalogEntry(clientContext->getTx(), tableID)->getName(); for (auto& rdfGraphEntry : catalog->getRdfGraphEntries(clientContext->getTx())) { if (rdfGraphEntry->isParent(tableID)) { @@ -301,8 +302,20 @@ BoundSetPropertyInfo Binder::bindSetPropertyInfo(parser::ParsedExpression* lhs, } } } - return BoundSetPropertyInfo(isNode ? UpdateTableType::NODE : UpdateTableType::REL, pattern, - std::move(boundSetItem)); + if (isNode) { + auto info = BoundSetPropertyInfo(TableType::NODE, expr, boundSetItem); + auto& property = boundSetItem.first->constCast(); + for (auto id : patternExpr.getTableIDs()) { + if (property.isPrimaryKey(id)) { + info.pkExpr = boundSetItem.first; + if (info.pkExpr->dataType.getLogicalTypeID() == LogicalTypeID::SERIAL) { + throw BinderException("Updating SERIAL primary key is not supported."); + } + } + } + return info; + } + return BoundSetPropertyInfo(TableType::REL, expr, std::move(boundSetItem)); } expression_pair Binder::bindSetItem(parser::ParsedExpression* lhs, parser::ParsedExpression* rhs) { diff --git a/src/binder/query/bound_merge_clause.cpp b/src/binder/query/bound_merge_clause.cpp index 3ac88c8281e..ab1d27a4017 100644 --- a/src/binder/query/bound_merge_clause.cpp +++ b/src/binder/query/bound_merge_clause.cpp @@ -36,12 +36,12 @@ bool BoundMergeClause::hasOnMatchSetInfo( return false; } -std::vector BoundMergeClause::getOnMatchSetInfos( +std::vector BoundMergeClause::getOnMatchSetInfos( const std::function& check) const { - std::vector result; + std::vector result; for (auto& info : onMatchSetPropertyInfos) { if (check(info)) { - result.push_back(&info); + result.push_back(info.copy()); } } return result; @@ -57,12 +57,12 @@ bool BoundMergeClause::hasOnCreateSetInfo( return false; } -std::vector BoundMergeClause::getOnCreateSetInfos( +std::vector BoundMergeClause::getOnCreateSetInfos( const std::function& check) const { - std::vector result; + std::vector result; for (auto& info : onCreateSetPropertyInfos) { if (check(info)) { - result.push_back(&info); + result.push_back(info.copy()); } } return result; diff --git a/src/binder/query/bound_set_clause.cpp b/src/binder/query/bound_set_clause.cpp index 013382fb7bf..821e80b57f0 100644 --- a/src/binder/query/bound_set_clause.cpp +++ b/src/binder/query/bound_set_clause.cpp @@ -14,12 +14,12 @@ bool BoundSetClause::hasInfo(const std::function BoundSetClause::getInfos( +std::vector BoundSetClause::getInfos( const std::function& check) const { - std::vector result; + std::vector result; for (auto& info : infos) { if (check(info)) { - result.push_back(&info); + result.push_back(info.copy()); } } return result; diff --git a/src/binder/visitor/property_collector.cpp b/src/binder/visitor/property_collector.cpp index 2de56b4c6bc..1a26bafac5a 100644 --- a/src/binder/visitor/property_collector.cpp +++ b/src/binder/visitor/property_collector.cpp @@ -60,8 +60,11 @@ void PropertyCollector::visitTableFunctionCall(const BoundReadingClause& reading } void PropertyCollector::visitSet(const BoundUpdatingClause& updatingClause) { - auto& boundSetClause = (BoundSetClause&)updatingClause; - for (auto& info : boundSetClause.getInfosRef()) { + auto& boundSetClause = updatingClause.constCast(); + for (auto& info : boundSetClause.getInfos()) { + if (info.pkExpr != nullptr) { + properties.insert(info.pkExpr); + } collectPropertyExpressions(info.setItem.second); } } diff --git a/src/include/binder/query/updating_clause/bound_merge_clause.h b/src/include/binder/query/updating_clause/bound_merge_clause.h index 4e24fd8bf17..bc6c804e5a2 100644 --- a/src/include/binder/query/updating_clause/bound_merge_clause.h +++ b/src/include/binder/query/updating_clause/bound_merge_clause.h @@ -52,43 +52,43 @@ class BoundMergeClause : public BoundUpdatingClause { bool hasOnMatchSetNodeInfo() const { return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::NODE; + return info.tableType == common::TableType::NODE; }); } - std::vector getOnMatchSetNodeInfos() const { + std::vector getOnMatchSetNodeInfos() const { return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::NODE; + return info.tableType == common::TableType::NODE; }); } bool hasOnMatchSetRelInfo() const { return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::REL; + return info.tableType == common::TableType::REL; }); } - std::vector getOnMatchSetRelInfos() const { + std::vector getOnMatchSetRelInfos() const { return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::REL; + return info.tableType == common::TableType::REL; }); } bool hasOnCreateSetNodeInfo() const { return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::NODE; + return info.tableType == common::TableType::NODE; }); } - std::vector getOnCreateSetNodeInfos() const { + std::vector getOnCreateSetNodeInfos() const { return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::NODE; + return info.tableType == common::TableType::NODE; }); } bool hasOnCreateSetRelInfo() const { return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::REL; + return info.tableType == common::TableType::REL; }); } - std::vector getOnCreateSetRelInfos() const { + std::vector getOnCreateSetRelInfos() const { return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::REL; + return info.tableType == common::TableType::REL; }); } @@ -106,12 +106,12 @@ class BoundMergeClause : public BoundUpdatingClause { bool hasOnMatchSetInfo( const std::function& check) const; - std::vector getOnMatchSetInfos( + std::vector getOnMatchSetInfos( const std::function& check) const; bool hasOnCreateSetInfo( const std::function& check) const; - std::vector getOnCreateSetInfos( + std::vector getOnCreateSetInfos( const std::function& check) const; private: diff --git a/src/include/binder/query/updating_clause/bound_set_clause.h b/src/include/binder/query/updating_clause/bound_set_clause.h index afe505fd269..d700d75f773 100644 --- a/src/include/binder/query/updating_clause/bound_set_clause.h +++ b/src/include/binder/query/updating_clause/bound_set_clause.h @@ -10,33 +10,33 @@ class BoundSetClause : public BoundUpdatingClause { public: BoundSetClause() : BoundUpdatingClause{common::ClauseType::SET} {} - inline void addInfo(BoundSetPropertyInfo info) { infos.push_back(std::move(info)); } - inline const std::vector& getInfosRef() { return infos; } + void addInfo(BoundSetPropertyInfo info) { infos.push_back(std::move(info)); } + const std::vector& getInfos() const { return infos; } - inline bool hasNodeInfo() const { + bool hasNodeInfo() const { return hasInfo([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::NODE; + return info.tableType == common::TableType::NODE; }); } - inline std::vector getNodeInfos() const { + std::vector getNodeInfos() const { return getInfos([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::NODE; + return info.tableType == common::TableType::NODE; }); } - inline bool hasRelInfo() const { + bool hasRelInfo() const { return hasInfo([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::REL; + return info.tableType == common::TableType::REL; }); } - inline std::vector getRelInfos() const { + std::vector getRelInfos() const { return getInfos([](const BoundSetPropertyInfo& info) { - return info.updateTableType == UpdateTableType::REL; + return info.tableType == common::TableType::REL; }); } private: bool hasInfo(const std::function& check) const; - std::vector getInfos( + std::vector getInfos( const std::function& check) const; private: diff --git a/src/include/binder/query/updating_clause/bound_set_info.h b/src/include/binder/query/updating_clause/bound_set_info.h index e2e8d43b76e..dd15183611d 100644 --- a/src/include/binder/query/updating_clause/bound_set_info.h +++ b/src/include/binder/query/updating_clause/bound_set_info.h @@ -1,26 +1,26 @@ #pragma once #include "binder/expression/expression.h" -#include "update_table_type.h" +#include "common/enums/table_type.h" namespace kuzu { namespace binder { struct BoundSetPropertyInfo { - UpdateTableType updateTableType; - std::shared_ptr nodeOrRel; + common::TableType tableType; + std::shared_ptr pattern; expression_pair setItem; + std::shared_ptr pkExpr = nullptr; - BoundSetPropertyInfo(UpdateTableType updateTableType, std::shared_ptr nodeOrRel, + BoundSetPropertyInfo(common::TableType tableType, std::shared_ptr pattern, expression_pair setItem) - : updateTableType{updateTableType}, nodeOrRel{std::move(nodeOrRel)}, - setItem{std::move(setItem)} {} + : tableType{tableType}, pattern{std::move(pattern)}, setItem{std::move(setItem)} {} EXPLICIT_COPY_DEFAULT_MOVE(BoundSetPropertyInfo); private: BoundSetPropertyInfo(const BoundSetPropertyInfo& other) - : updateTableType{other.updateTableType}, nodeOrRel{other.nodeOrRel}, - setItem{other.setItem} {} + : tableType{other.tableType}, pattern{other.pattern}, setItem{other.setItem}, + pkExpr{other.pkExpr} {} }; } // namespace binder diff --git a/src/include/binder/query/updating_clause/update_table_type.h b/src/include/binder/query/updating_clause/update_table_type.h deleted file mode 100644 index 632cc715a0d..00000000000 --- a/src/include/binder/query/updating_clause/update_table_type.h +++ /dev/null @@ -1,14 +0,0 @@ -#pragma once - -#include - -namespace kuzu { -namespace binder { - -enum class UpdateTableType : uint8_t { - NODE = 0, - REL = 1, -}; - -} -} // namespace kuzu diff --git a/src/include/optimizer/factorization_rewriter.h b/src/include/optimizer/factorization_rewriter.h index 0b62d51066d..d5e2761eeb8 100644 --- a/src/include/optimizer/factorization_rewriter.h +++ b/src/include/optimizer/factorization_rewriter.h @@ -27,8 +27,7 @@ class FactorizationRewriter final : public LogicalOperatorVisitor { void visitUnwind(planner::LogicalOperator* op) override; void visitUnion(planner::LogicalOperator* op) override; void visitFilter(planner::LogicalOperator* op) override; - void visitSetNodeProperty(planner::LogicalOperator* op) override; - void visitSetRelProperty(planner::LogicalOperator* op) override; + void visitSetProperty(planner::LogicalOperator* op) override; void visitDelete(planner::LogicalOperator* op) override; void visitInsert(planner::LogicalOperator* op) override; void visitMerge(planner::LogicalOperator* op) override; diff --git a/src/include/optimizer/logical_operator_visitor.h b/src/include/optimizer/logical_operator_visitor.h index b1876335235..0e4ee3f5e91 100644 --- a/src/include/optimizer/logical_operator_visitor.h +++ b/src/include/optimizer/logical_operator_visitor.h @@ -135,14 +135,8 @@ class LogicalOperatorVisitor { return op; } - virtual void visitSetNodeProperty(planner::LogicalOperator* /*op*/) {} - virtual std::shared_ptr visitSetNodePropertyReplace( - std::shared_ptr op) { - return op; - } - - virtual void visitSetRelProperty(planner::LogicalOperator* /*op*/) {} - virtual std::shared_ptr visitSetRelPropertyReplace( + virtual void visitSetProperty(planner::LogicalOperator*) {} + virtual std::shared_ptr visitSetPropertyReplace( std::shared_ptr op) { return op; } diff --git a/src/include/optimizer/projection_push_down_optimizer.h b/src/include/optimizer/projection_push_down_optimizer.h index 05012aee937..91f259dcd30 100644 --- a/src/include/optimizer/projection_push_down_optimizer.h +++ b/src/include/optimizer/projection_push_down_optimizer.h @@ -7,7 +7,9 @@ namespace kuzu { namespace main { class ClientContext; } - +namespace binder { +struct BoundSetPropertyInfo; +} namespace planner { struct LogicalInsertInfo; } @@ -35,14 +37,15 @@ class ProjectionPushDownOptimizer : public LogicalOperatorVisitor { void visitProjection(planner::LogicalOperator* op) override; void visitOrderBy(planner::LogicalOperator* op) override; void visitUnwind(planner::LogicalOperator* op) override; - void visitSetNodeProperty(planner::LogicalOperator* op) override; - void visitSetRelProperty(planner::LogicalOperator* op) override; + void visitSetProperty(planner::LogicalOperator* op) override; void visitInsert(planner::LogicalOperator* op) override; - void visitInsertInfo(const planner::LogicalInsertInfo* info); void visitDelete(planner::LogicalOperator* op) override; void visitMerge(planner::LogicalOperator* op) override; void visitCopyFrom(planner::LogicalOperator* op) override; + void visitSetInfo(const binder::BoundSetPropertyInfo& info); + void visitInsertInfo(const planner::LogicalInsertInfo& info); + void collectExpressionsInUse(std::shared_ptr expression); binder::expression_vector pruneExpressions(const binder::expression_vector& expressions); diff --git a/src/include/planner/operator/logical_operator.h b/src/include/planner/operator/logical_operator.h index 4454a37c8c5..40c3bd4064c 100644 --- a/src/include/planner/operator/logical_operator.h +++ b/src/include/planner/operator/logical_operator.h @@ -53,8 +53,7 @@ enum class LogicalOperatorType : uint8_t { SCAN_FRONTIER, SCAN_NODE_TABLE, SEMI_MASKER, - SET_NODE_PROPERTY, - SET_REL_PROPERTY, + SET_PROPERTY, STANDALONE_CALL, TABLE_FUNCTION_CALL, TRANSACTION, diff --git a/src/include/planner/operator/persistent/logical_merge.h b/src/include/planner/operator/persistent/logical_merge.h index 213ae51173c..cd0a5dc2d8b 100644 --- a/src/include/planner/operator/persistent/logical_merge.h +++ b/src/include/planner/operator/persistent/logical_merge.h @@ -1,30 +1,20 @@ #pragma once +#include "binder/query/updating_clause/bound_set_info.h" #include "planner/operator/logical_operator.h" #include "planner/operator/persistent/logical_insert.h" -#include "planner/operator/persistent/logical_set.h" namespace kuzu { namespace planner { -class LogicalMerge : public LogicalOperator { +class LogicalMerge final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::MERGE; + public: LogicalMerge(std::shared_ptr existenceMark, - std::shared_ptr distinctMark, - std::vector insertNodeInfos, - std::vector insertRelInfos, - std::vector> onCreateSetNodeInfos, - std::vector> onCreateSetRelInfos, - std::vector> onMatchSetNodeInfos, - std::vector> onMatchSetRelInfos, - std::shared_ptr child) - : LogicalOperator{LogicalOperatorType::MERGE, std::move(child)}, - existenceMark{std::move(existenceMark)}, distinctMark{std::move(distinctMark)}, - insertNodeInfos{std::move(insertNodeInfos)}, insertRelInfos{std::move(insertRelInfos)}, - onCreateSetNodeInfos{std::move(onCreateSetNodeInfos)}, - onCreateSetRelInfos(std::move(onCreateSetRelInfos)), - onMatchSetNodeInfos{std::move(onMatchSetNodeInfos)}, - onMatchSetRelInfos{std::move(onMatchSetRelInfos)} {} + std::shared_ptr distinctMark, std::shared_ptr child) + : LogicalOperator{type_, std::move(child)}, existenceMark{std::move(existenceMark)}, + distinctMark{std::move(distinctMark)} {} void computeFactorizedSchema() final; void computeFlatSchema() final; @@ -37,29 +27,41 @@ class LogicalMerge : public LogicalOperator { bool hasDistinctMark() const { return distinctMark != nullptr; } std::shared_ptr getDistinctMark() const { return distinctMark; } - const std::vector& getInsertNodeInfosRef() const { return insertNodeInfos; } - const std::vector& getInsertRelInfosRef() const { return insertRelInfos; } - const std::vector>& getOnCreateSetNodeInfosRef() const { + void addInsertNodeInfo(LogicalInsertInfo info) { insertNodeInfos.push_back(std::move(info)); } + const std::vector& getInsertNodeInfos() const { return insertNodeInfos; } + + void addInsertRelInfo(LogicalInsertInfo info) { insertRelInfos.push_back(std::move(info)); } + const std::vector& getInsertRelInfos() const { return insertRelInfos; } + + void addOnCreateSetNodeInfo(binder::BoundSetPropertyInfo info) { + onCreateSetNodeInfos.push_back(std::move(info)); + } + const std::vector& getOnCreateSetNodeInfos() const { return onCreateSetNodeInfos; } - const std::vector>& getOnCreateSetRelInfosRef() const { + + void addOnCreateSetRelInfo(binder::BoundSetPropertyInfo info) { + onCreateSetRelInfos.push_back(std::move(info)); + } + const std::vector& getOnCreateSetRelInfos() const { return onCreateSetRelInfos; } - const std::vector>& getOnMatchSetNodeInfosRef() const { + + void addOnMatchSetNodeInfo(binder::BoundSetPropertyInfo info) { + onMatchSetNodeInfos.push_back(std::move(info)); + } + const std::vector& getOnMatchSetNodeInfos() const { return onMatchSetNodeInfos; } - const std::vector>& getOnMatchSetRelInfosRef() const { + + void addOnMatchSetRelInfo(binder::BoundSetPropertyInfo info) { + onMatchSetRelInfos.push_back(std::move(info)); + } + const std::vector& getOnMatchSetRelInfos() const { return onMatchSetRelInfos; } - std::unique_ptr copy() final { - return std::make_unique(existenceMark, distinctMark, - copyVector(insertNodeInfos), copyVector(insertRelInfos), - LogicalSetPropertyInfo::copy(onCreateSetNodeInfos), - LogicalSetPropertyInfo::copy(onCreateSetRelInfos), - LogicalSetPropertyInfo::copy(onMatchSetNodeInfos), - LogicalSetPropertyInfo::copy(onMatchSetRelInfos), children[0]->copy()); - } + std::unique_ptr copy() override; private: std::shared_ptr existenceMark; @@ -68,11 +70,11 @@ class LogicalMerge : public LogicalOperator { std::vector insertNodeInfos; std::vector insertRelInfos; // On Create infos - std::vector> onCreateSetNodeInfos; - std::vector> onCreateSetRelInfos; + std::vector onCreateSetNodeInfos; + std::vector onCreateSetRelInfos; // On Match infos - std::vector> onMatchSetNodeInfos; - std::vector> onMatchSetRelInfos; + std::vector onMatchSetNodeInfos; + std::vector onMatchSetRelInfos; }; } // namespace planner diff --git a/src/include/planner/operator/persistent/logical_set.h b/src/include/planner/operator/persistent/logical_set.h index ab2455dd864..03f07111379 100644 --- a/src/include/planner/operator/persistent/logical_set.h +++ b/src/include/planner/operator/persistent/logical_set.h @@ -1,80 +1,35 @@ #pragma once +#include "binder/query/updating_clause/bound_set_info.h" #include "planner/operator/logical_operator.h" namespace kuzu { namespace planner { -struct LogicalSetPropertyInfo { - std::shared_ptr nodeOrRel; - binder::expression_pair setItem; +class LogicalSetProperty final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::SET_PROPERTY; - LogicalSetPropertyInfo(std::shared_ptr nodeOrRel, - binder::expression_pair setItem) - : nodeOrRel{std::move(nodeOrRel)}, setItem{std::move(setItem)} {} - LogicalSetPropertyInfo(const LogicalSetPropertyInfo& other) - : nodeOrRel{other.nodeOrRel}, setItem{other.setItem} {} - - inline std::unique_ptr copy() const { - return std::make_unique(*this); - } - - static std::vector> copy( - const std::vector>& infos); -}; - -class LogicalSetNodeProperty : public LogicalOperator { public: - LogicalSetNodeProperty(std::vector> infos, + LogicalSetProperty(std::vector infos, std::shared_ptr child) - : LogicalOperator{LogicalOperatorType::SET_NODE_PROPERTY, std::move(child)}, - infos{std::move(infos)} {} + : LogicalOperator{type_, std::move(child)}, infos{std::move(infos)} {} - inline void computeFactorizedSchema() final { copyChildSchema(0); } - inline void computeFlatSchema() final { copyChildSchema(0); } + void computeFactorizedSchema() override; + void computeFlatSchema() override; - inline const std::vector>& getInfosRef() const { - return infos; - } - - f_group_pos_set getGroupsPosToFlatten(uint32_t idx); - - std::string getExpressionsForPrinting() const final; - - inline std::unique_ptr copy() final { - return std::make_unique(LogicalSetPropertyInfo::copy(infos), - children[0]->copy()); - } - -private: - std::vector> infos; -}; - -class LogicalSetRelProperty : public LogicalOperator { -public: - LogicalSetRelProperty(std::vector> infos, - std::shared_ptr child) - : LogicalOperator{LogicalOperatorType::SET_REL_PROPERTY, std::move(child)}, - infos{std::move(infos)} {} - - inline void computeFactorizedSchema() final { copyChildSchema(0); } - inline void computeFlatSchema() final { copyChildSchema(0); } - - inline const std::vector>& getInfosRef() const { - return infos; - } + f_group_pos_set getGroupsPosToFlatten(uint32_t idx) const; - f_group_pos_set getGroupsPosToFlatten(uint32_t idx); + std::string getExpressionsForPrinting() const override; - std::string getExpressionsForPrinting() const final; + common::TableType getTableType() const; + const std::vector& getInfos() const { return infos; } - inline std::unique_ptr copy() final { - return std::make_unique(LogicalSetPropertyInfo::copy(infos), - children[0]->copy()); + std::unique_ptr copy() override { + return std::make_unique(copyVector(infos), children[0]->copy()); } private: - std::vector> infos; + std::vector infos; }; } // namespace planner diff --git a/src/include/planner/planner.h b/src/include/planner/planner.h index 510369542ed..dc7dd028351 100644 --- a/src/include/planner/planner.h +++ b/src/include/planner/planner.h @@ -28,7 +28,6 @@ class BoundProjectionBody; namespace planner { struct LogicalInsertInfo; -struct LogicalSetPropertyInfo; class Planner { public: @@ -194,14 +193,11 @@ class Planner { LogicalPlan& plan); void appendInsertRel(const std::vector& boundInsertInfos, LogicalPlan& plan); - void appendSetNodeProperty(const std::vector& boundInfos, - LogicalPlan& plan); - void appendSetRelProperty(const std::vector& boundInfos, + + void appendSetProperty(const std::vector& infos, LogicalPlan& plan); void appendDelete(const std::vector& infos, LogicalPlan& plan); std::unique_ptr createLogicalInsertInfo(const binder::BoundInsertInfo* info); - std::unique_ptr createLogicalSetPropertyInfo( - const binder::BoundSetPropertyInfo* boundSetPropertyInfo); // Append projection operators void appendProjection(const binder::expression_vector& expressionsToProject, LogicalPlan& plan); diff --git a/src/include/processor/operator/persistent/delete.h b/src/include/processor/operator/persistent/delete.h index edbf9264c01..5a76d71ee0e 100644 --- a/src/include/processor/operator/persistent/delete.h +++ b/src/include/processor/operator/persistent/delete.h @@ -7,10 +7,12 @@ namespace kuzu { namespace processor { class DeleteNode : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::DELETE_; + public: DeleteNode(std::vector> executors, std::unique_ptr child, uint32_t id, const std::string& paramsString) - : PhysicalOperator{PhysicalOperatorType::DELETE_, std::move(child), id, paramsString}, + : PhysicalOperator{type_, std::move(child), id, paramsString}, executors{std::move(executors)} {} bool isParallel() const final { return false; } @@ -26,10 +28,12 @@ class DeleteNode : public PhysicalOperator { }; class DeleteRel : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::DELETE_; + public: DeleteRel(std::vector> executors, std::unique_ptr child, uint32_t id, const std::string& paramsString) - : PhysicalOperator{PhysicalOperatorType::DELETE_, std::move(child), id, paramsString}, + : PhysicalOperator{type_, std::move(child), id, paramsString}, executors{std::move(executors)} {} bool isParallel() const final { return false; } diff --git a/src/include/processor/operator/persistent/merge.h b/src/include/processor/operator/persistent/merge.h index 49058823189..94c031b7892 100644 --- a/src/include/processor/operator/persistent/merge.h +++ b/src/include/processor/operator/persistent/merge.h @@ -26,19 +26,13 @@ class Merge : public PhysicalOperator { onMatchNodeSetExecutors{std::move(onMatchNodeSetExecutors)}, onMatchRelSetExecutors{std::move(onMatchRelSetExecutors)} {} - inline bool isParallel() const final { return false; } + bool isParallel() const final { return false; } void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final; bool getNextTuplesInternal(ExecutionContext* context) final; - inline std::unique_ptr clone() final { - return std::make_unique(existenceMark, distinctMark, copyVector(nodeInsertExecutors), - copyVector(relInsertExecutors), NodeSetExecutor::copy(onCreateNodeSetExecutors), - RelSetExecutor::copy(onCreateRelSetExecutors), - NodeSetExecutor::copy(onMatchNodeSetExecutors), - RelSetExecutor::copy(onMatchRelSetExecutors), children[0]->clone(), id, paramsString); - } + std::unique_ptr clone() final; private: DataPos existenceMark; diff --git a/src/include/processor/operator/persistent/set.h b/src/include/processor/operator/persistent/set.h index be7fd3f56f9..18c2e47f796 100644 --- a/src/include/processor/operator/persistent/set.h +++ b/src/include/processor/operator/persistent/set.h @@ -7,34 +7,33 @@ namespace kuzu { namespace processor { class SetNodeProperty : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::SET_PROPERTY; + public: SetNodeProperty(std::vector> executors, std::unique_ptr child, uint32_t id, const std::string& paramsString) - : PhysicalOperator{PhysicalOperatorType::SET_NODE_PROPERTY, std::move(child), id, - paramsString}, + : PhysicalOperator{type_, std::move(child), id, paramsString}, executors{std::move(executors)} {} - inline bool isParallel() const final { return false; } + bool isParallel() const final { return false; } void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) final; bool getNextTuplesInternal(ExecutionContext* context) final; - inline std::unique_ptr clone() final { - return make_unique(NodeSetExecutor::copy(executors), children[0]->clone(), - id, paramsString); - } + std::unique_ptr clone() final; private: std::vector> executors; }; class SetRelProperty : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::SET_PROPERTY; + public: SetRelProperty(std::vector> executors, std::unique_ptr child, uint32_t id, const std::string& paramsString) - : PhysicalOperator{PhysicalOperatorType::SET_NODE_PROPERTY, std::move(child), id, - paramsString}, + : PhysicalOperator{type_, std::move(child), id, paramsString}, executors{std::move(executors)} {} inline bool isParallel() const final { return false; } @@ -43,10 +42,7 @@ class SetRelProperty : public PhysicalOperator { bool getNextTuplesInternal(ExecutionContext* context) final; - inline std::unique_ptr clone() final { - return make_unique(RelSetExecutor::copy(executors), children[0]->clone(), - id, paramsString); - } + std::unique_ptr clone() final; private: std::vector> executors; diff --git a/src/include/processor/operator/persistent/set_executor.h b/src/include/processor/operator/persistent/set_executor.h index 48c6a34c029..7b33f1d72ce 100644 --- a/src/include/processor/operator/persistent/set_executor.h +++ b/src/include/processor/operator/persistent/set_executor.h @@ -9,11 +9,25 @@ namespace kuzu { namespace processor { +struct NodeSetInfo { + DataPos nodeIDPos; + DataPos lhsPos; + DataPos pkPos = DataPos::getInvalidPos(); + + NodeSetInfo(DataPos nodeIDPos, DataPos lhsPos) : nodeIDPos{nodeIDPos}, lhsPos{lhsPos} {} + EXPLICIT_COPY_DEFAULT_MOVE(NodeSetInfo); + +private: + NodeSetInfo(const NodeSetInfo& other) + : nodeIDPos{other.nodeIDPos}, lhsPos{other.lhsPos}, pkPos{other.pkPos} {} +}; + class NodeSetExecutor { public: - NodeSetExecutor(const DataPos& nodeIDPos, const DataPos& lhsVectorPos, - std::unique_ptr evaluator) - : nodeIDPos{nodeIDPos}, lhsVectorPos{lhsVectorPos}, evaluator{std::move(evaluator)} {} + NodeSetExecutor(NodeSetInfo info, std::unique_ptr evaluator) + : info{std::move(info)}, evaluator{std::move(evaluator)} {} + NodeSetExecutor(const NodeSetExecutor& other) + : info{other.info.copy()}, evaluator{other.evaluator->clone()} {} virtual ~NodeSetExecutor() = default; void init(ResultSet* resultSet, ExecutionContext* context); @@ -23,11 +37,10 @@ class NodeSetExecutor { virtual std::unique_ptr copy() const = 0; static std::vector> copy( - const std::vector>& executors); + const std::vector>& others); protected: - DataPos nodeIDPos; - DataPos lhsVectorPos; + NodeSetInfo info; std::unique_ptr evaluator; common::ValueVector* nodeIDVector = nullptr; @@ -35,52 +48,60 @@ class NodeSetExecutor { // (rhs) common::ValueVector* lhsVector = nullptr; common::ValueVector* rhsVector = nullptr; + common::ValueVector* pkVector = nullptr; }; -struct NodeSetInfo { +struct ExtraNodeSetInfo { storage::NodeTable* table; common::column_id_t columnID; + + ExtraNodeSetInfo(storage::NodeTable* table, common::column_id_t columnID) + : table{table}, columnID{columnID} {} + EXPLICIT_COPY_DEFAULT_MOVE(ExtraNodeSetInfo); + +private: + ExtraNodeSetInfo(const ExtraNodeSetInfo& other) + : table{other.table}, columnID{other.columnID} {} }; class SingleLabelNodeSetExecutor final : public NodeSetExecutor { public: - SingleLabelNodeSetExecutor(NodeSetInfo setInfo, const DataPos& nodeIDPos, - const DataPos& lhsVectorPos, std::unique_ptr evaluator) - : NodeSetExecutor{nodeIDPos, lhsVectorPos, std::move(evaluator)}, setInfo{setInfo} {} + SingleLabelNodeSetExecutor(NodeSetInfo setInfo, + std::unique_ptr evaluator, ExtraNodeSetInfo extraInfo) + : NodeSetExecutor{std::move(setInfo), std::move(evaluator)}, + extraInfo{std::move(extraInfo)} {} SingleLabelNodeSetExecutor(const SingleLabelNodeSetExecutor& other) - : NodeSetExecutor{other.nodeIDPos, other.lhsVectorPos, other.evaluator->clone()}, - setInfo(other.setInfo) {} + : NodeSetExecutor{other}, extraInfo(other.extraInfo.copy()) {} void set(ExecutionContext* context) override; - inline std::unique_ptr copy() const override { + std::unique_ptr copy() const override { return std::make_unique(*this); } private: - NodeSetInfo setInfo; + ExtraNodeSetInfo extraInfo; }; class MultiLabelNodeSetExecutor final : public NodeSetExecutor { public: - MultiLabelNodeSetExecutor(std::unordered_map tableIDToSetInfo, - const DataPos& nodeIDPos, const DataPos& lhsVectorPos, - std::unique_ptr evaluator) - : NodeSetExecutor{nodeIDPos, lhsVectorPos, std::move(evaluator)}, - tableIDToSetInfo{std::move(tableIDToSetInfo)} {} + MultiLabelNodeSetExecutor(NodeSetInfo info, + std::unique_ptr evaluator, + common::table_id_map_t extraInfos) + : NodeSetExecutor{std::move(info), std::move(evaluator)}, + extraInfos{std::move(extraInfos)} {} MultiLabelNodeSetExecutor(const MultiLabelNodeSetExecutor& other) - : NodeSetExecutor{other.nodeIDPos, other.lhsVectorPos, other.evaluator->clone()}, - tableIDToSetInfo{other.tableIDToSetInfo} {} + : NodeSetExecutor{other}, extraInfos{copyMap(other.extraInfos)} {} void set(ExecutionContext* context) override; - inline std::unique_ptr copy() const override { + std::unique_ptr copy() const override { return std::make_unique(*this); } private: - std::unordered_map tableIDToSetInfo; + common::table_id_map_t extraInfos; }; class RelSetExecutor { @@ -130,7 +151,7 @@ class SingleLabelRelSetExecutor final : public RelSetExecutor { void set(ExecutionContext* context) override; - inline std::unique_ptr copy() const override { + std::unique_ptr copy() const override { return std::make_unique(*this); } @@ -155,7 +176,7 @@ class MultiLabelRelSetExecutor final : public RelSetExecutor { void set(ExecutionContext* context) override; - inline std::unique_ptr copy() const override { + std::unique_ptr copy() const override { return std::make_unique(*this); } diff --git a/src/include/processor/operator/physical_operator.h b/src/include/processor/operator/physical_operator.h index 2295005ff76..e8c2e2e5dde 100644 --- a/src/include/processor/operator/physical_operator.h +++ b/src/include/processor/operator/physical_operator.h @@ -54,8 +54,7 @@ enum class PhysicalOperatorType : uint8_t { SCAN_MULTI_REL_TABLES, SCAN_REL_TABLE, SEMI_MASKER, - SET_NODE_PROPERTY, - SET_REL_PROPERTY, + SET_PROPERTY, SKIP, STANDALONE_CALL, TOP_K, diff --git a/src/include/processor/plan_mapper.h b/src/include/processor/plan_mapper.h index a4c0a752630..b6abc3a23bd 100644 --- a/src/include/processor/plan_mapper.h +++ b/src/include/processor/plan_mapper.h @@ -14,10 +14,10 @@ class ClientContext; namespace binder { struct BoundDeleteInfo; -} +struct BoundSetPropertyInfo; +} // namespace binder namespace planner { -struct LogicalSetPropertyInfo; struct LogicalInsertInfo; class LogicalCopyFrom; } // namespace planner @@ -79,6 +79,7 @@ class PlanMapper { std::unique_ptr mapMarkAccumulate(planner::LogicalOperator* logicalOperator); std::unique_ptr mapDummyScan(planner::LogicalOperator* logicalOperator); std::unique_ptr mapInsert(planner::LogicalOperator* logicalOperator); + std::unique_ptr mapSetProperty(planner::LogicalOperator* logicalOperator); std::unique_ptr mapSetNodeProperty(planner::LogicalOperator* logicalOperator); std::unique_ptr mapSetRelProperty(planner::LogicalOperator* logicalOperator); std::unique_ptr mapDelete(planner::LogicalOperator* logicalOperator); @@ -183,10 +184,10 @@ class PlanMapper { const planner::Schema& outSchema) const; std::unique_ptr getRelInsertExecutor(const planner::LogicalInsertInfo* info, const planner::Schema& inSchema, const planner::Schema& outSchema) const; - std::unique_ptr getNodeSetExecutor(planner::LogicalSetPropertyInfo* info, - const planner::Schema& inSchema) const; - std::unique_ptr getRelSetExecutor(planner::LogicalSetPropertyInfo* info, - const planner::Schema& inSchema) const; + std::unique_ptr getNodeSetExecutor(const binder::BoundSetPropertyInfo& info, + const planner::Schema& schema) const; + std::unique_ptr getRelSetExecutor(const binder::BoundSetPropertyInfo& info, + const planner::Schema& schema) const; std::unique_ptr getNodeDeleteExecutor(const binder::BoundDeleteInfo& info, const planner::Schema& schema) const; std::unique_ptr getRelDeleteExecutor(const binder::BoundDeleteInfo& info, diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index e9e06ac2f48..f9bf543a7cb 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -43,6 +43,8 @@ struct NodeTableInsertState final : TableInsertState { struct NodeTableUpdateState final : TableUpdateState { common::ValueVector& nodeIDVector; + // pkVector is nullptr if we are not updating primary key column. + common::ValueVector* pkVector; NodeTableUpdateState(common::column_id_t columnID, common::ValueVector& nodeIDVector, const common::ValueVector& propertyVector) @@ -124,8 +126,6 @@ class NodeTable final : public Table { } private: - void updatePK(transaction::Transaction* transaction, common::column_id_t columnID, - common::ValueVector& nodeIDVector, const common::ValueVector& payloadVector); void insertPK(const common::ValueVector& nodeIDVector, const common::ValueVector& pkVector) const; bool scanCommitted(transaction::Transaction* transaction, NodeTableScanState& scanState); diff --git a/src/optimizer/factorization_rewriter.cpp b/src/optimizer/factorization_rewriter.cpp index a731540d2ab..b03cabdea6f 100644 --- a/src/optimizer/factorization_rewriter.cpp +++ b/src/optimizer/factorization_rewriter.cpp @@ -159,21 +159,11 @@ void FactorizationRewriter::visitFilter(planner::LogicalOperator* op) { filter->setChild(0, appendFlattens(filter->getChild(0), groupsPosToFlatten)); } -void FactorizationRewriter::visitSetNodeProperty(planner::LogicalOperator* op) { - auto setNodeProperty = (LogicalSetNodeProperty*)op; - for (auto i = 0u; i < setNodeProperty->getInfosRef().size(); ++i) { - auto groupsPosToFlatten = setNodeProperty->getGroupsPosToFlatten(i); - setNodeProperty->setChild(0, - appendFlattens(setNodeProperty->getChild(0), groupsPosToFlatten)); - } -} - -void FactorizationRewriter::visitSetRelProperty(planner::LogicalOperator* op) { - auto setRelProperty = (LogicalSetRelProperty*)op; - for (auto i = 0u; i < setRelProperty->getInfosRef().size(); ++i) { - auto groupsPosToFlatten = setRelProperty->getGroupsPosToFlatten(i); - setRelProperty->setChild(0, - appendFlattens(setRelProperty->getChild(0), groupsPosToFlatten)); +void FactorizationRewriter::visitSetProperty(planner::LogicalOperator* op) { + auto set = op->ptrCast(); + for (auto i = 0u; i < set->getInfos().size(); ++i) { + auto groupsPos = set->getGroupsPosToFlatten(i); + set->setChild(0, appendFlattens(set->getChild(0), groupsPos)); } } diff --git a/src/optimizer/logical_operator_visitor.cpp b/src/optimizer/logical_operator_visitor.cpp index b94218d7a06..1a5db5f6851 100644 --- a/src/optimizer/logical_operator_visitor.cpp +++ b/src/optimizer/logical_operator_visitor.cpp @@ -67,11 +67,8 @@ void LogicalOperatorVisitor::visitOperatorSwitch(LogicalOperator* op) { case LogicalOperatorType::FILTER: { visitFilter(op); } break; - case LogicalOperatorType::SET_NODE_PROPERTY: { - visitSetNodeProperty(op); - } break; - case LogicalOperatorType::SET_REL_PROPERTY: { - visitSetRelProperty(op); + case LogicalOperatorType::SET_PROPERTY: { + visitSetProperty(op); } break; case LogicalOperatorType::DELETE: { visitDelete(op); @@ -156,11 +153,8 @@ std::shared_ptr LogicalOperatorVisitor::visitOperatorReplaceSwi case LogicalOperatorType::FILTER: { return visitFilterReplace(op); } - case LogicalOperatorType::SET_NODE_PROPERTY: { - return visitSetNodePropertyReplace(op); - } - case LogicalOperatorType::SET_REL_PROPERTY: { - return visitSetRelPropertyReplace(op); + case LogicalOperatorType::SET_PROPERTY: { + return visitSetPropertyReplace(op); } case LogicalOperatorType::DELETE: { return visitDeleteReplace(op); diff --git a/src/optimizer/projection_push_down_optimizer.cpp b/src/optimizer/projection_push_down_optimizer.cpp index 62c4a1676f1..73cdf3379e5 100644 --- a/src/optimizer/projection_push_down_optimizer.cpp +++ b/src/optimizer/projection_push_down_optimizer.cpp @@ -162,22 +162,7 @@ void ProjectionPushDownOptimizer::visitUnwind(planner::LogicalOperator* op) { void ProjectionPushDownOptimizer::visitInsert(planner::LogicalOperator* op) { auto insert = (LogicalInsert*)op; for (auto& info : insert->getInfosRef()) { - visitInsertInfo(&info); - } -} - -void ProjectionPushDownOptimizer::visitInsertInfo(const planner::LogicalInsertInfo* info) { - if (info->tableType == common::TableType::REL) { - auto rel = (RelExpression*)info->pattern.get(); - collectExpressionsInUse(rel->getSrcNode()->getInternalID()); - collectExpressionsInUse(rel->getDstNode()->getInternalID()); - collectExpressionsInUse(rel->getInternalIDProperty()); - } - for (auto i = 0u; i < info->columnExprs.size(); ++i) { - if (info->isReturnColumnExprs[i]) { - collectExpressionsInUse(info->columnExprs[i]); - } - collectExpressionsInUse(info->columnDataExprs[i]); + visitInsertInfo(info); } } @@ -208,62 +193,36 @@ void ProjectionPushDownOptimizer::visitDelete(planner::LogicalOperator* op) { } } -// TODO(Xiyang): come back and refactor this after changing insert interface void ProjectionPushDownOptimizer::visitMerge(planner::LogicalOperator* op) { auto merge = (LogicalMerge*)op; if (merge->hasDistinctMark()) { collectExpressionsInUse(merge->getDistinctMark()); } collectExpressionsInUse(merge->getExistenceMark()); - for (auto& info : merge->getInsertNodeInfosRef()) { - visitInsertInfo(&info); - } - for (auto& info : merge->getInsertRelInfosRef()) { - visitInsertInfo(&info); + for (auto& info : merge->getInsertNodeInfos()) { + visitInsertInfo(info); } - for (auto& info : merge->getOnCreateSetNodeInfosRef()) { - auto node = (NodeExpression*)info->nodeOrRel.get(); - collectExpressionsInUse(node->getInternalID()); - collectExpressionsInUse(info->setItem.second); + for (auto& info : merge->getInsertRelInfos()) { + visitInsertInfo(info); } - for (auto& info : merge->getOnMatchSetNodeInfosRef()) { - auto node = (NodeExpression*)info->nodeOrRel.get(); - collectExpressionsInUse(node->getInternalID()); - collectExpressionsInUse(info->setItem.second); + for (auto& info : merge->getOnCreateSetNodeInfos()) { + visitSetInfo(info); } - for (auto& info : merge->getOnCreateSetRelInfosRef()) { - auto rel = (RelExpression*)info->nodeOrRel.get(); - collectExpressionsInUse(rel->getSrcNode()->getInternalID()); - collectExpressionsInUse(rel->getDstNode()->getInternalID()); - collectExpressionsInUse(rel->getInternalIDProperty()); - collectExpressionsInUse(info->setItem.second); + for (auto& info : merge->getOnMatchSetNodeInfos()) { + visitSetInfo(info); } - for (auto& info : merge->getOnMatchSetRelInfosRef()) { - auto rel = (RelExpression*)info->nodeOrRel.get(); - collectExpressionsInUse(rel->getSrcNode()->getInternalID()); - collectExpressionsInUse(rel->getDstNode()->getInternalID()); - collectExpressionsInUse(rel->getInternalIDProperty()); - collectExpressionsInUse(info->setItem.second); + for (auto& info : merge->getOnCreateSetRelInfos()) { + visitSetInfo(info); } -} - -void ProjectionPushDownOptimizer::visitSetNodeProperty(planner::LogicalOperator* op) { - auto setNodeProperty = (LogicalSetNodeProperty*)op; - for (auto& info : setNodeProperty->getInfosRef()) { - auto node = (NodeExpression*)info->nodeOrRel.get(); - collectExpressionsInUse(node->getInternalID()); - collectExpressionsInUse(info->setItem.second); + for (auto& info : merge->getOnMatchSetRelInfos()) { + visitSetInfo(info); } } -void ProjectionPushDownOptimizer::visitSetRelProperty(planner::LogicalOperator* op) { - auto setRelProperty = (LogicalSetRelProperty*)op; - for (auto& info : setRelProperty->getInfosRef()) { - auto rel = (RelExpression*)info->nodeOrRel.get(); - collectExpressionsInUse(rel->getSrcNode()->getInternalID()); - collectExpressionsInUse(rel->getDstNode()->getInternalID()); - collectExpressionsInUse(rel->getInternalIDProperty()); - collectExpressionsInUse(info->setItem.second); +void ProjectionPushDownOptimizer::visitSetProperty(planner::LogicalOperator* op) { + auto set = op->ptrCast(); + for (auto& info : set->getInfos()) { + visitSetInfo(info); } } @@ -275,6 +234,42 @@ void ProjectionPushDownOptimizer::visitCopyFrom(planner::LogicalOperator* op) { collectExpressionsInUse(copyFrom->getInfo()->offset); } +void ProjectionPushDownOptimizer::visitSetInfo(const binder::BoundSetPropertyInfo& info) { + switch (info.tableType) { + case TableType::NODE: { + auto& node = info.pattern->constCast(); + collectExpressionsInUse(node.getInternalID()); + if (info.pkExpr != nullptr) { + collectExpressionsInUse(info.pkExpr); + } + } break; + case TableType::REL: { + auto& rel = info.pattern->constCast(); + collectExpressionsInUse(rel.getSrcNode()->getInternalID()); + collectExpressionsInUse(rel.getDstNode()->getInternalID()); + collectExpressionsInUse(rel.getInternalIDProperty()); + } break; + default: + KU_UNREACHABLE; + } + collectExpressionsInUse(info.setItem.second); +} + +void ProjectionPushDownOptimizer::visitInsertInfo(const planner::LogicalInsertInfo& info) { + if (info.tableType == common::TableType::REL) { + auto& rel = info.pattern->constCast(); + collectExpressionsInUse(rel.getSrcNode()->getInternalID()); + collectExpressionsInUse(rel.getDstNode()->getInternalID()); + collectExpressionsInUse(rel.getInternalIDProperty()); + } + for (auto i = 0u; i < info.columnExprs.size(); ++i) { + if (info.isReturnColumnExprs[i]) { + collectExpressionsInUse(info.columnExprs[i]); + } + collectExpressionsInUse(info.columnDataExprs[i]); + } +} + // See comments above this class for how to collect expressions in use. void ProjectionPushDownOptimizer::collectExpressionsInUse( std::shared_ptr expression) { diff --git a/src/planner/operator/logical_operator.cpp b/src/planner/operator/logical_operator.cpp index 00d9d27e87e..7e06a3fabfb 100644 --- a/src/planner/operator/logical_operator.cpp +++ b/src/planner/operator/logical_operator.cpp @@ -96,10 +96,8 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp return "SCAN_NODE_TABLE"; case LogicalOperatorType::SEMI_MASKER: return "SEMI_MASKER"; - case LogicalOperatorType::SET_NODE_PROPERTY: - return "SET_NODE_PROPERTY"; - case LogicalOperatorType::SET_REL_PROPERTY: - return "SET_REL_PROPERTY"; + case LogicalOperatorType::SET_PROPERTY: + return "SET_PROPERTY"; case LogicalOperatorType::STANDALONE_CALL: return "STANDALONE_CALL"; case LogicalOperatorType::TABLE_FUNCTION_CALL: @@ -122,8 +120,7 @@ bool LogicalOperatorUtils::isUpdate(LogicalOperatorType type) { switch (type) { case LogicalOperatorType::INSERT: case LogicalOperatorType::DELETE: - case LogicalOperatorType::SET_NODE_PROPERTY: - case LogicalOperatorType::SET_REL_PROPERTY: + case LogicalOperatorType::SET_PROPERTY: case LogicalOperatorType::MERGE: return true; default: diff --git a/src/planner/operator/persistent/logical_merge.cpp b/src/planner/operator/persistent/logical_merge.cpp index 5bc92535de5..882e0111bed 100644 --- a/src/planner/operator/persistent/logical_merge.cpp +++ b/src/planner/operator/persistent/logical_merge.cpp @@ -37,5 +37,16 @@ f_group_pos_set LogicalMerge::getGroupsPosToFlatten() { childSchema); } +std::unique_ptr LogicalMerge::copy() { + auto merge = std::make_unique(existenceMark, distinctMark, children[0]->copy()); + merge->insertNodeInfos = copyVector(insertNodeInfos); + merge->insertRelInfos = copyVector(insertRelInfos); + merge->onCreateSetNodeInfos = copyVector(onCreateSetNodeInfos); + merge->onCreateSetRelInfos = copyVector(onCreateSetRelInfos); + merge->onMatchSetNodeInfos = copyVector(onMatchSetNodeInfos); + merge->onMatchSetRelInfos = copyVector(onMatchSetRelInfos); + return merge; +} + } // namespace planner } // namespace kuzu diff --git a/src/planner/operator/persistent/logical_set.cpp b/src/planner/operator/persistent/logical_set.cpp index d31afcd5dbc..55f7477093a 100644 --- a/src/planner/operator/persistent/logical_set.cpp +++ b/src/planner/operator/persistent/logical_set.cpp @@ -1,63 +1,57 @@ #include "planner/operator/persistent/logical_set.h" +#include "binder/expression/expression_util.h" #include "binder/expression/rel_expression.h" -#include "common/cast.h" #include "planner/operator/factorization/flatten_resolver.h" using namespace kuzu::binder; +using namespace kuzu::common; namespace kuzu { namespace planner { -std::vector> LogicalSetPropertyInfo::copy( - const std::vector>& infos) { - std::vector> infosCopy; - infosCopy.reserve(infos.size()); - for (auto& info : infos) { - infosCopy.push_back(info->copy()); - } - return infosCopy; +void LogicalSetProperty::computeFactorizedSchema() { + copyChildSchema(0); } -std::string LogicalSetNodeProperty::getExpressionsForPrinting() const { - std::string result; - for (auto& info : infos) { - result += info->setItem.first->toString() + " = " + info->setItem.second->toString() + ","; - } - return result; +void LogicalSetProperty::computeFlatSchema() { + copyChildSchema(0); } -f_group_pos_set LogicalSetNodeProperty::getGroupsPosToFlatten(uint32_t idx) { +f_group_pos_set LogicalSetProperty::getGroupsPosToFlatten(uint32_t idx) const { f_group_pos_set result; - auto node = common::ku_dynamic_cast(infos[idx]->nodeOrRel.get()); - auto rhs = infos[idx]->setItem.second; auto childSchema = children[0]->getSchema(); - result.insert(childSchema->getGroupPos(*node->getInternalID())); - for (auto groupPos : childSchema->getDependentGroupsPos(rhs)) { + auto& info = infos[idx]; + switch (getTableType()) { + case TableType::NODE: { + auto node = info.pattern->constPtrCast(); + result.insert(childSchema->getGroupPos(*node->getInternalID())); + } break; + case TableType::REL: { + auto rel = info.pattern->constPtrCast(); + result.insert(childSchema->getGroupPos(*rel->getSrcNode()->getInternalID())); + result.insert(childSchema->getGroupPos(*rel->getDstNode()->getInternalID())); + } break; + default: + KU_UNREACHABLE; + } + for (auto& groupPos : childSchema->getDependentGroupsPos(info.setItem.second)) { result.insert(groupPos); } return factorization::FlattenAll::getGroupsPosToFlatten(result, childSchema); } -std::string LogicalSetRelProperty::getExpressionsForPrinting() const { - std::string result; - for (auto& info : infos) { - result += info->setItem.first->toString() + " = " + info->setItem.second->toString() + ","; +std::string LogicalSetProperty::getExpressionsForPrinting() const { + std::string result = ExpressionUtil::toString(infos[0].setItem); + for (auto i = 1u; i < infos.size(); ++i) { + result += ExpressionUtil::toString(infos[i].setItem); } return result; } -f_group_pos_set LogicalSetRelProperty::getGroupsPosToFlatten(uint32_t idx) { - f_group_pos_set result; - auto rel = common::ku_dynamic_cast(infos[idx]->nodeOrRel.get()); - auto rhs = infos[idx]->setItem.second; - auto childSchema = children[0]->getSchema(); - result.insert(childSchema->getGroupPos(*rel->getSrcNode()->getInternalID())); - result.insert(childSchema->getGroupPos(*rel->getDstNode()->getInternalID())); - for (auto groupPos : childSchema->getDependentGroupsPos(rhs)) { - result.insert(groupPos); - } - return factorization::FlattenAll::getGroupsPosToFlatten(result, childSchema); +common::TableType LogicalSetProperty::getTableType() const { + KU_ASSERT(!infos.empty()); + return infos[0].tableType; } } // namespace planner diff --git a/src/planner/plan/append_set.cpp b/src/planner/plan/append_set.cpp index 364606a8b44..bb74d754029 100644 --- a/src/planner/plan/append_set.cpp +++ b/src/planner/plan/append_set.cpp @@ -7,44 +7,15 @@ using namespace kuzu::binder; namespace kuzu { namespace planner { -std::unique_ptr Planner::createLogicalSetPropertyInfo( - const BoundSetPropertyInfo* boundSetPropertyInfo) { - return std::make_unique(boundSetPropertyInfo->nodeOrRel, - boundSetPropertyInfo->setItem); -} - -void Planner::appendSetNodeProperty( - const std::vector& boundInfos, LogicalPlan& plan) { - std::vector> logicalInfos; - logicalInfos.reserve(boundInfos.size()); - for (auto& boundInfo : boundInfos) { - logicalInfos.push_back(createLogicalSetPropertyInfo(boundInfo)); - } - auto setNodeProperty = - std::make_shared(std::move(logicalInfos), plan.getLastOperator()); - for (auto i = 0u; i < boundInfos.size(); ++i) { - appendFlattens(setNodeProperty->getGroupsPosToFlatten(i), plan); - setNodeProperty->setChild(0, plan.getLastOperator()); - } - setNodeProperty->computeFactorizedSchema(); - plan.setLastOperator(setNodeProperty); -} - -void Planner::appendSetRelProperty( - const std::vector& boundInfos, LogicalPlan& plan) { - std::vector> logicalInfos; - logicalInfos.reserve(boundInfos.size()); - for (auto& boundInfo : boundInfos) { - logicalInfos.push_back(createLogicalSetPropertyInfo(boundInfo)); - } - auto setRelProperty = - std::make_shared(std::move(logicalInfos), plan.getLastOperator()); - for (auto i = 0u; i < boundInfos.size(); ++i) { - appendFlattens(setRelProperty->getGroupsPosToFlatten(i), plan); - setRelProperty->setChild(0, plan.getLastOperator()); +void Planner::appendSetProperty(const std::vector& infos, LogicalPlan& plan) { + auto set = std::make_shared(copyVector(infos), plan.getLastOperator()); + for (auto i = 0u; i < set->getInfos().size(); ++i) { + auto groupsPos = set->getGroupsPosToFlatten(i); + appendFlattens(groupsPos, plan); + set->setChild(0, plan.getLastOperator()); } - setRelProperty->computeFactorizedSchema(); - plan.setLastOperator(setRelProperty); + set->computeFactorizedSchema(); + plan.setLastOperator(std::move(set)); } } // namespace planner diff --git a/src/planner/plan/plan_update.cpp b/src/planner/plan/plan_update.cpp index 12dd726ebd9..ce727746e09 100644 --- a/src/planner/plan/plan_update.cpp +++ b/src/planner/plan/plan_update.cpp @@ -78,48 +78,38 @@ void Planner::planMergeClause(const BoundUpdatingClause* updatingClause, Logical auto existenceMark = mergeClause->getExistenceMark(); planOptionalMatch(*mergeClause->getQueryGraphCollection(), predicates, corrExprs, existenceMark, plan); - std::vector logicalInsertNodeInfos; + auto merge = + std::make_shared(existenceMark, distinctMark, plan.getLastOperator()); if (mergeClause->hasInsertNodeInfo()) { - auto boundInsertNodeInfos = mergeClause->getInsertNodeInfos(); - for (auto& info : boundInsertNodeInfos) { - logicalInsertNodeInfos.push_back(createLogicalInsertInfo(info)->copy()); + for (auto& info : mergeClause->getInsertNodeInfos()) { + merge->addInsertNodeInfo(createLogicalInsertInfo(info)->copy()); } } - std::vector logicalInsertRelInfos; if (mergeClause->hasInsertRelInfo()) { for (auto& info : mergeClause->getInsertRelInfos()) { - logicalInsertRelInfos.push_back(createLogicalInsertInfo(info)->copy()); + merge->addInsertRelInfo(createLogicalInsertInfo(info)->copy()); } } - std::vector> logicalOnCreateSetNodeInfos; if (mergeClause->hasOnCreateSetNodeInfo()) { for (auto& info : mergeClause->getOnCreateSetNodeInfos()) { - logicalOnCreateSetNodeInfos.push_back(createLogicalSetPropertyInfo(info)); + merge->addOnCreateSetNodeInfo(info.copy()); } } - std::vector> logicalOnCreateSetRelInfos; if (mergeClause->hasOnCreateSetRelInfo()) { for (auto& info : mergeClause->getOnCreateSetRelInfos()) { - logicalOnCreateSetRelInfos.push_back(createLogicalSetPropertyInfo(info)); + merge->addOnCreateSetRelInfo(info.copy()); } } - std::vector> logicalOnMatchSetNodeInfos; if (mergeClause->hasOnMatchSetNodeInfo()) { for (auto& info : mergeClause->getOnMatchSetNodeInfos()) { - logicalOnMatchSetNodeInfos.push_back(createLogicalSetPropertyInfo(info)); + merge->addOnMatchSetNodeInfo(info.copy()); } } - std::vector> logicalOnMatchSetRelInfos; if (mergeClause->hasOnMatchSetRelInfo()) { for (auto& info : mergeClause->getOnMatchSetRelInfos()) { - logicalOnMatchSetRelInfos.push_back(createLogicalSetPropertyInfo(info)); + merge->addOnMatchSetRelInfo(info.copy()); } } - auto merge = std::make_shared(existenceMark, distinctMark, - std::move(logicalInsertNodeInfos), std::move(logicalInsertRelInfos), - std::move(logicalOnCreateSetNodeInfos), std::move(logicalOnCreateSetRelInfos), - std::move(logicalOnMatchSetNodeInfos), std::move(logicalOnMatchSetRelInfos), - plan.getLastOperator()); appendFlattens(merge->getGroupsPosToFlatten(), plan); merge->setChild(0, plan.getLastOperator()); merge->computeFactorizedSchema(); @@ -128,13 +118,12 @@ void Planner::planMergeClause(const BoundUpdatingClause* updatingClause, Logical void Planner::planSetClause(const BoundUpdatingClause* updatingClause, LogicalPlan& plan) { appendAccumulate(plan); - auto setClause = - ku_dynamic_cast(updatingClause); - if (setClause->hasNodeInfo()) { - appendSetNodeProperty(setClause->getNodeInfos(), plan); + auto& setClause = updatingClause->constCast(); + if (setClause.hasNodeInfo()) { + appendSetProperty(setClause.getNodeInfos(), plan); } - if (setClause->hasRelInfo()) { - appendSetRelProperty(setClause->getRelInfos(), plan); + if (setClause.hasRelInfo()) { + appendSetProperty(setClause.getRelInfos(), plan); } } diff --git a/src/processor/map/map_delete.cpp b/src/processor/map/map_delete.cpp index 43c69c5108c..285fe368f70 100644 --- a/src/processor/map/map_delete.cpp +++ b/src/processor/map/map_delete.cpp @@ -79,7 +79,7 @@ std::unique_ptr PlanMapper::mapDeleteNode(LogicalOperator* log std::unique_ptr PlanMapper::getRelDeleteExecutor(const BoundDeleteInfo& info, const Schema& schema) const { - auto sm = clientContext->getStorageManager(); + auto storageManager = clientContext->getStorageManager(); auto& rel = info.pattern->constCast(); auto srcNodePos = getDataPos(*rel.getSrcNode()->getInternalID(), schema); auto dstNodePos = getDataPos(*rel.getDstNode()->getInternalID(), schema); @@ -87,13 +87,13 @@ std::unique_ptr PlanMapper::getRelDeleteExecutor(const BoundD if (rel.isMultiLabeled()) { common::table_id_map_t tableIDToTableMap; for (auto tableID : rel.getTableIDs()) { - auto table = sm->getTable(tableID)->ptrCast(); + auto table = storageManager->getTable(tableID)->ptrCast(); tableIDToTableMap.insert({tableID, table}); } return std::make_unique(std::move(tableIDToTableMap), srcNodePos, dstNodePos, relIDPos); } - auto table = sm->getTable(rel.getSingleTableID())->ptrCast(); + auto table = storageManager->getTable(rel.getSingleTableID())->ptrCast(); return std::make_unique(table, srcNodePos, dstNodePos, relIDPos); } diff --git a/src/processor/map/map_merge.cpp b/src/processor/map/map_merge.cpp index 11f282466f2..4a20a6d64e2 100644 --- a/src/processor/map/map_merge.cpp +++ b/src/processor/map/map_merge.cpp @@ -18,28 +18,28 @@ std::unique_ptr PlanMapper::mapMerge(planner::LogicalOperator* distinctMarkPos = getDataPos(*logicalMerge->getDistinctMark(), *inSchema); } std::vector nodeInsertExecutors; - for (auto& info : logicalMerge->getInsertNodeInfosRef()) { + for (auto& info : logicalMerge->getInsertNodeInfos()) { nodeInsertExecutors.push_back(getNodeInsertExecutor(&info, *inSchema, *outSchema)->copy()); } std::vector relInsertExecutors; - for (auto& info : logicalMerge->getInsertRelInfosRef()) { + for (auto& info : logicalMerge->getInsertRelInfos()) { relInsertExecutors.push_back(getRelInsertExecutor(&info, *inSchema, *outSchema)->copy()); } std::vector> onCreateNodeSetExecutors; - for (auto& info : logicalMerge->getOnCreateSetNodeInfosRef()) { - onCreateNodeSetExecutors.push_back(getNodeSetExecutor(info.get(), *inSchema)); + for (auto& info : logicalMerge->getOnCreateSetNodeInfos()) { + onCreateNodeSetExecutors.push_back(getNodeSetExecutor(info, *inSchema)); } std::vector> onCreateRelSetExecutors; - for (auto& info : logicalMerge->getOnCreateSetRelInfosRef()) { - onCreateRelSetExecutors.push_back(getRelSetExecutor(info.get(), *inSchema)); + for (auto& info : logicalMerge->getOnCreateSetRelInfos()) { + onCreateRelSetExecutors.push_back(getRelSetExecutor(info, *inSchema)); } std::vector> onMatchNodeSetExecutors; - for (auto& info : logicalMerge->getOnMatchSetNodeInfosRef()) { - onMatchNodeSetExecutors.push_back(getNodeSetExecutor(info.get(), *inSchema)); + for (auto& info : logicalMerge->getOnMatchSetNodeInfos()) { + onMatchNodeSetExecutors.push_back(getNodeSetExecutor(info, *inSchema)); } std::vector> onMatchRelSetExecutors; - for (auto& info : logicalMerge->getOnMatchSetRelInfosRef()) { - onMatchRelSetExecutors.push_back(getRelSetExecutor(info.get(), *inSchema)); + for (auto& info : logicalMerge->getOnMatchSetRelInfos()) { + onMatchRelSetExecutors.push_back(getRelSetExecutor(info, *inSchema)); } return std::make_unique(existenceMarkPos, distinctMarkPos, std::move(nodeInsertExecutors), std::move(relInsertExecutors), diff --git a/src/processor/map/map_set.cpp b/src/processor/map/map_set.cpp index 3870bad1557..6fe5ecafb51 100644 --- a/src/processor/map/map_set.cpp +++ b/src/processor/map/map_set.cpp @@ -16,110 +16,129 @@ using namespace kuzu::storage; namespace kuzu { namespace processor { -std::unique_ptr PlanMapper::getNodeSetExecutor( - planner::LogicalSetPropertyInfo* info, const planner::Schema& inSchema) const { - auto storageManager = clientContext->getStorageManager(); - auto catalog = clientContext->getCatalog(); - auto node = (NodeExpression*)info->nodeOrRel.get(); - auto nodeIDPos = DataPos(inSchema.getExpressionPos(*node->getInternalID())); - auto property = (PropertyExpression*)info->setItem.first.get(); - auto propertyPos = DataPos(INVALID_DATA_CHUNK_POS, INVALID_VALUE_VECTOR_POS); - if (inSchema.isExpressionInScope(*property)) { - propertyPos = DataPos(inSchema.getExpressionPos(*property)); +static ExtraNodeSetInfo getExtraNodeSetInfo(main::ClientContext* context, + common::table_id_t tableID, const PropertyExpression& propertyExpr) { + auto storageManager = context->getStorageManager(); + auto catalog = context->getCatalog(); + auto table = storageManager->getTable(tableID)->ptrCast(); + auto columnID = INVALID_COLUMN_ID; + if (propertyExpr.hasPropertyID(tableID)) { + auto propertyID = propertyExpr.getPropertyID(tableID); + auto entry = catalog->getTableCatalogEntry(context->getTx(), tableID); + columnID = entry->getColumnID(propertyID); + } + return ExtraNodeSetInfo(table, columnID); +} + +std::unique_ptr PlanMapper::getNodeSetExecutor(const BoundSetPropertyInfo& info, + const Schema& schema) const { + auto& node = info.pattern->constCast(); + auto nodeIDPos = getDataPos(*node.getInternalID(), schema); + auto& property = info.setItem.first->constCast(); + auto propertyPos = DataPos::getInvalidPos(); + if (schema.isExpressionInScope(property)) { + propertyPos = getDataPos(property, schema); + } + auto setInfo = NodeSetInfo(nodeIDPos, propertyPos); + if (info.pkExpr != nullptr) { + setInfo.pkPos = getDataPos(*info.pkExpr, schema); } - auto evaluator = ExpressionMapper::getEvaluator(info->setItem.second, &inSchema); - if (node->isMultiLabeled()) { - std::unordered_map tableIDToSetInfo; - for (auto tableID : node->getTableIDs()) { - if (!property->hasPropertyID(tableID)) { + auto evaluator = ExpressionMapper::getEvaluator(info.setItem.second, &schema); + if (node.isMultiLabeled()) { + common::table_id_map_t extraInfos; + for (auto tableID : node.getTableIDs()) { + auto extraInfo = getExtraNodeSetInfo(clientContext, tableID, property); + if (extraInfo.columnID == INVALID_COLUMN_ID) { continue; } - auto propertyID = property->getPropertyID(tableID); - auto table = ku_dynamic_cast(storageManager->getTable(tableID)); - auto columnID = catalog->getTableCatalogEntry(clientContext->getTx(), tableID) - ->getColumnID(propertyID); - tableIDToSetInfo.insert({tableID, NodeSetInfo{table, columnID}}); + extraInfos.insert({tableID, std::move(extraInfo)}); } - return std::make_unique(std::move(tableIDToSetInfo), nodeIDPos, - propertyPos, std::move(evaluator)); - } else { - auto tableID = node->getSingleTableID(); - auto table = ku_dynamic_cast(storageManager->getTable(tableID)); - auto columnID = INVALID_COLUMN_ID; - if (property->hasPropertyID(tableID)) { - auto propertyID = property->getPropertyID(tableID); - columnID = catalog->getTableCatalogEntry(clientContext->getTx(), tableID) - ->getColumnID(propertyID); - } - return std::make_unique(NodeSetInfo{table, columnID}, nodeIDPos, - propertyPos, std::move(evaluator)); + return std::make_unique(std::move(setInfo), std::move(evaluator), + std::move(extraInfos)); + } + auto extraInfo = getExtraNodeSetInfo(clientContext, node.getSingleTableID(), property); + return std::make_unique(std::move(setInfo), std::move(evaluator), + std::move(extraInfo)); +} + +std::unique_ptr PlanMapper::mapSetProperty( + planner::LogicalOperator* logicalOperator) { + auto set = logicalOperator->constPtrCast(); + switch (set->getTableType()) { + case TableType::NODE: { + return mapSetNodeProperty(logicalOperator); + } + case TableType::REL: { + return mapSetRelProperty(logicalOperator); + } + default: + KU_UNREACHABLE; } } std::unique_ptr PlanMapper::mapSetNodeProperty(LogicalOperator* logicalOperator) { - auto& logicalSetNodeProperty = (LogicalSetNodeProperty&)*logicalOperator; - auto inSchema = logicalSetNodeProperty.getChild(0)->getSchema(); + auto set = logicalOperator->constPtrCast(); + auto inSchema = set->getChild(0)->getSchema(); auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); std::vector> executors; - for (auto& info : logicalSetNodeProperty.getInfosRef()) { - executors.push_back(getNodeSetExecutor(info.get(), *inSchema)); + for (auto& info : set->getInfos()) { + executors.push_back(getNodeSetExecutor(info, *inSchema)); } return std::make_unique(std::move(executors), std::move(prevOperator), - getOperatorID(), logicalSetNodeProperty.getExpressionsForPrinting()); + getOperatorID(), set->getExpressionsForPrinting()); } -std::unique_ptr PlanMapper::getRelSetExecutor(planner::LogicalSetPropertyInfo* info, - const planner::Schema& inSchema) const { +std::unique_ptr PlanMapper::getRelSetExecutor(const BoundSetPropertyInfo& info, + const Schema& schema) const { auto storageManager = clientContext->getStorageManager(); auto catalog = clientContext->getCatalog(); - auto rel = (RelExpression*)info->nodeOrRel.get(); - auto srcNodePos = DataPos(inSchema.getExpressionPos(*rel->getSrcNode()->getInternalID())); - auto dstNodePos = DataPos(inSchema.getExpressionPos(*rel->getDstNode()->getInternalID())); - auto relIDPos = DataPos(inSchema.getExpressionPos(*rel->getInternalIDProperty())); - auto property = (PropertyExpression*)info->setItem.first.get(); - auto propertyPos = DataPos(INVALID_DATA_CHUNK_POS, INVALID_VALUE_VECTOR_POS); - if (inSchema.isExpressionInScope(*property)) { - propertyPos = DataPos(inSchema.getExpressionPos(*property)); + auto& rel = info.pattern->constCast(); + auto srcNodePos = getDataPos(*rel.getSrcNode()->getInternalID(), schema); + auto dstNodePos = getDataPos(*rel.getDstNode()->getInternalID(), schema); + auto relIDPos = getDataPos(*rel.getInternalIDProperty(), schema); + auto& property = info.setItem.first->constCast(); + auto propertyPos = DataPos::getInvalidPos(); + if (schema.isExpressionInScope(property)) { + propertyPos = getDataPos(property, schema); } - auto evaluator = ExpressionMapper::getEvaluator(info->setItem.second, &inSchema); - if (rel->isMultiLabeled()) { + auto evaluator = ExpressionMapper::getEvaluator(info.setItem.second, &schema); + if (rel.isMultiLabeled()) { std::unordered_map> tableIDToTableAndColumnID; - for (auto tableID : rel->getTableIDs()) { - if (!property->hasPropertyID(tableID)) { + for (auto tableID : rel.getTableIDs()) { + if (!property.hasPropertyID(tableID)) { continue; } - auto table = ku_dynamic_cast(storageManager->getTable(tableID)); - auto propertyID = property->getPropertyID(tableID); + auto table = storageManager->getTable(tableID)->ptrCast(); + auto propertyID = property.getPropertyID(tableID); auto columnID = catalog->getTableCatalogEntry(clientContext->getTx(), tableID) ->getColumnID(propertyID); tableIDToTableAndColumnID.insert({tableID, std::make_pair(table, columnID)}); } return std::make_unique(std::move(tableIDToTableAndColumnID), srcNodePos, dstNodePos, relIDPos, propertyPos, std::move(evaluator)); - } else { - auto tableID = rel->getSingleTableID(); - auto table = ku_dynamic_cast(storageManager->getTable(tableID)); - auto columnID = common::INVALID_COLUMN_ID; - if (property->hasPropertyID(tableID)) { - auto propertyID = property->getPropertyID(tableID); - columnID = catalog->getTableCatalogEntry(clientContext->getTx(), tableID) - ->getColumnID(propertyID); - } - return std::make_unique(table, columnID, srcNodePos, dstNodePos, - relIDPos, propertyPos, std::move(evaluator)); } + auto tableID = rel.getSingleTableID(); + auto table = storageManager->getTable(tableID)->ptrCast(); + auto columnID = common::INVALID_COLUMN_ID; + if (property.hasPropertyID(tableID)) { + auto propertyID = property.getPropertyID(tableID); + columnID = + catalog->getTableCatalogEntry(clientContext->getTx(), tableID)->getColumnID(propertyID); + } + return std::make_unique(table, columnID, srcNodePos, dstNodePos, + relIDPos, propertyPos, std::move(evaluator)); } std::unique_ptr PlanMapper::mapSetRelProperty(LogicalOperator* logicalOperator) { - auto& logicalSetRelProperty = (LogicalSetRelProperty&)*logicalOperator; - auto inSchema = logicalSetRelProperty.getChild(0)->getSchema(); + auto set = logicalOperator->constPtrCast(); + auto inSchema = set->getChild(0)->getSchema(); auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); std::vector> executors; - for (auto& info : logicalSetRelProperty.getInfosRef()) { - executors.push_back(getRelSetExecutor(info.get(), *inSchema)); + for (auto& info : set->getInfos()) { + executors.push_back(getRelSetExecutor(info, *inSchema)); } - return make_unique(std::move(executors), std::move(prevOperator), - getOperatorID(), logicalSetRelProperty.getExpressionsForPrinting()); + return std::make_unique(std::move(executors), std::move(prevOperator), + getOperatorID(), set->getExpressionsForPrinting()); } } // namespace processor diff --git a/src/processor/map/plan_mapper.cpp b/src/processor/map/plan_mapper.cpp index 54b0e1d8489..53c698311c0 100644 --- a/src/processor/map/plan_mapper.cpp +++ b/src/processor/map/plan_mapper.cpp @@ -115,11 +115,8 @@ std::unique_ptr PlanMapper::mapOperator(LogicalOperator* logic case LogicalOperatorType::INSERT: { physicalOperator = mapInsert(logicalOperator); } break; - case LogicalOperatorType::SET_NODE_PROPERTY: { - physicalOperator = mapSetNodeProperty(logicalOperator); - } break; - case LogicalOperatorType::SET_REL_PROPERTY: { - physicalOperator = mapSetRelProperty(logicalOperator); + case LogicalOperatorType::SET_PROPERTY: { + physicalOperator = mapSetProperty(logicalOperator); } break; case LogicalOperatorType::DELETE: { physicalOperator = mapDelete(logicalOperator); diff --git a/src/processor/operator/persistent/merge.cpp b/src/processor/operator/persistent/merge.cpp index 85bde6cf15a..2ff7c2b7250 100644 --- a/src/processor/operator/persistent/merge.cpp +++ b/src/processor/operator/persistent/merge.cpp @@ -81,5 +81,13 @@ bool Merge::getNextTuplesInternal(ExecutionContext* context) { return true; } +std::unique_ptr Merge::clone() { + return std::make_unique(existenceMark, distinctMark, copyVector(nodeInsertExecutors), + copyVector(relInsertExecutors), NodeSetExecutor::copy(onCreateNodeSetExecutors), + RelSetExecutor::copy(onCreateRelSetExecutors), + NodeSetExecutor::copy(onMatchNodeSetExecutors), + RelSetExecutor::copy(onMatchRelSetExecutors), children[0]->clone(), id, paramsString); +} + } // namespace processor } // namespace kuzu diff --git a/src/processor/operator/persistent/set.cpp b/src/processor/operator/persistent/set.cpp index b54eea2d481..64f00c68051 100644 --- a/src/processor/operator/persistent/set.cpp +++ b/src/processor/operator/persistent/set.cpp @@ -19,6 +19,11 @@ bool SetNodeProperty::getNextTuplesInternal(ExecutionContext* context) { return true; } +std::unique_ptr SetNodeProperty::clone() { + return std::make_unique(NodeSetExecutor::copy(executors), children[0]->clone(), + id, paramsString); +} + void SetRelProperty::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { for (auto& executor : executors) { executor->init(resultSet, context); @@ -35,5 +40,10 @@ bool SetRelProperty::getNextTuplesInternal(ExecutionContext* context) { return true; } +std::unique_ptr SetRelProperty::clone() { + return std::make_unique(RelSetExecutor::copy(executors), children[0]->clone(), + id, paramsString); +} + } // namespace processor } // namespace kuzu diff --git a/src/processor/operator/persistent/set_executor.cpp b/src/processor/operator/persistent/set_executor.cpp index d163c403b64..f21882a407d 100644 --- a/src/processor/operator/persistent/set_executor.cpp +++ b/src/processor/operator/persistent/set_executor.cpp @@ -6,22 +6,24 @@ namespace kuzu { namespace processor { void NodeSetExecutor::init(ResultSet* resultSet, ExecutionContext* context) { - nodeIDVector = resultSet->getValueVector(nodeIDPos).get(); - if (lhsVectorPos.dataChunkPos != INVALID_DATA_CHUNK_POS) { - lhsVector = resultSet->getValueVector(lhsVectorPos).get(); + nodeIDVector = resultSet->getValueVector(info.nodeIDPos).get(); + if (info.lhsPos.isValid()) { + lhsVector = resultSet->getValueVector(info.lhsPos).get(); } evaluator->init(*resultSet, context->clientContext->getMemoryManager()); rhsVector = evaluator->resultVector.get(); + if (info.pkPos.isValid()) { + pkVector = resultSet->getValueVector(info.pkPos).get(); + } } std::vector> NodeSetExecutor::copy( - const std::vector>& executors) { - std::vector> executorsCopy; - executorsCopy.reserve(executors.size()); - for (auto& executor : executors) { - executorsCopy.push_back(executor->copy()); + const std::vector>& others) { + std::vector> result; + for (auto& other : others) { + result.push_back(other->copy()); } - return executorsCopy; + return result; } static void writeToPropertyVector(ValueVector* internalIDVector, ValueVector* propertyVector, @@ -38,7 +40,7 @@ static void writeToPropertyVector(ValueVector* internalIDVector, ValueVector* pr } void SingleLabelNodeSetExecutor::set(ExecutionContext* context) { - if (setInfo.columnID == common::INVALID_COLUMN_ID) { + if (extraInfo.columnID == common::INVALID_COLUMN_ID) { if (lhsVector != nullptr) { for (auto i = 0u; i < nodeIDVector->state->getSelVector().getSelSize(); ++i) { auto lhsPos = nodeIDVector->state->getSelVector()[i]; @@ -51,9 +53,10 @@ void SingleLabelNodeSetExecutor::set(ExecutionContext* context) { KU_ASSERT(nodeIDVector->state->getSelVector().getSelSize() == 1); auto lhsPos = nodeIDVector->state->getSelVector()[0]; auto rhsPos = rhsVector->state->getSelVector()[0]; - auto updateState = std::make_unique(setInfo.columnID, + auto updateState = std::make_unique(extraInfo.columnID, *nodeIDVector, *rhsVector); - setInfo.table->update(context->clientContext->getTx(), *updateState); + updateState->pkVector = pkVector; + extraInfo.table->update(context->clientContext->getTx(), *updateState); if (lhsVector != nullptr) { writeToPropertyVector(nodeIDVector, lhsVector, lhsPos, rhsVector, rhsPos); } @@ -65,17 +68,17 @@ void MultiLabelNodeSetExecutor::set(ExecutionContext* context) { rhsVector->state->getSelVector().getSelSize() == 1); auto lhsPos = nodeIDVector->state->getSelVector()[0]; auto& nodeID = nodeIDVector->getValue(lhsPos); - if (!tableIDToSetInfo.contains(nodeID.tableID)) { + if (!extraInfos.contains(nodeID.tableID)) { if (lhsVector != nullptr) { lhsVector->setNull(lhsPos, true); } return; } auto rhsPos = rhsVector->state->getSelVector()[0]; - auto& setInfo = tableIDToSetInfo.at(nodeID.tableID); - auto updateState = std::make_unique(setInfo.columnID, + auto& extraInfo = extraInfos.at(nodeID.tableID); + auto updateState = std::make_unique(extraInfo.columnID, *nodeIDVector, *rhsVector); - setInfo.table->update(context->clientContext->getTx(), *updateState); + extraInfo.table->update(context->clientContext->getTx(), *updateState); if (lhsVector != nullptr) { KU_ASSERT(lhsVector->state->getSelVector().getSelSize() == 1); writeToPropertyVector(nodeIDVector, lhsVector, lhsPos, rhsVector, rhsPos); diff --git a/src/processor/operator/physical_operator.cpp b/src/processor/operator/physical_operator.cpp index 57d69481f9b..381e25fbd75 100644 --- a/src/processor/operator/physical_operator.cpp +++ b/src/processor/operator/physical_operator.cpp @@ -102,10 +102,8 @@ std::string PhysicalOperatorUtils::operatorTypeToString(PhysicalOperatorType ope return "SCAN_REL_TABLE"; case PhysicalOperatorType::SEMI_MASKER: return "SEMI_MASKER"; - case PhysicalOperatorType::SET_NODE_PROPERTY: - return "SET_NODE_PROPERTY"; - case PhysicalOperatorType::SET_REL_PROPERTY: - return "SET_REL_PROPERTY"; + case PhysicalOperatorType::SET_PROPERTY: + return "SET_PROPERTY"; case PhysicalOperatorType::SKIP: return "SKIP"; case PhysicalOperatorType::STANDALONE_CALL: diff --git a/src/storage/store/column.cpp b/src/storage/store/column.cpp index 4ec615d9786..3cb06d64fed 100644 --- a/src/storage/store/column.cpp +++ b/src/storage/store/column.cpp @@ -20,7 +20,7 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { -static bool isPageIdxValid(page_idx_t pageIdx, const ColumnChunkMetadata& metadata) { // NOLINT +static bool isPageIdxValid(page_idx_t pageIdx, const ColumnChunkMetadata& metadata) { return (metadata.pageIdx <= pageIdx && pageIdx < metadata.pageIdx + metadata.numPages) || (pageIdx == INVALID_PAGE_IDX && metadata.compMeta.isConstant()); } @@ -252,6 +252,7 @@ void Column::batchLookup(Transaction* transaction, const offset_t* nodeOffsets, initChunkState(transaction, nodeGroupIdx, state); auto cursor = getPageCursorForOffsetInGroup(offsetInChunk, state); auto chunkMeta = metadataDA->get(nodeGroupIdx, transaction->getType()); + (void)isPageIdxValid; KU_ASSERT(isPageIdxValid(cursor.pageIdx, chunkMeta)); readFromPage(transaction, cursor.pageIdx, [&](uint8_t* frame) -> void { batchLookupFunc(frame, cursor, result, i, 1, chunkMeta.compMeta); @@ -467,8 +468,7 @@ void Column::readFromPage(Transaction* transaction, page_idx_t pageIdx, bufferManager->optimisticRead(*fileHandleToPin, pageIdxToPin, func); } -static bool sanityCheckForWrites(const ColumnChunkMetadata& metadata, - const LogicalType& dataType) { // NOLINT +static bool sanityCheckForWrites(const ColumnChunkMetadata& metadata, const LogicalType& dataType) { if (metadata.compMeta.compression == CompressionType::CONSTANT) { return metadata.numPages == 0; } @@ -489,6 +489,7 @@ void Column::append(ColumnChunk* columnChunk, ChunkState& state) { auto preScanMetadata = columnChunk->getMetadataToFlush(); auto startPageIdx = dataFH->addNewPages(preScanMetadata.numPages); state.metadata = columnChunk->flushBuffer(dataFH, startPageIdx, preScanMetadata); + (void)sanityCheckForWrites; KU_ASSERT(sanityCheckForWrites(state.metadata, dataType)); metadataDA->resize(state.nodeGroupIdx + 1); metadataDA->update(state.nodeGroupIdx, state.metadata); diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 95aae38013d..b12797bc309 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -167,8 +167,8 @@ void NodeTable::update(Transaction* transaction, TableUpdateState& updateState) KU_ASSERT(nodeUpdateState.nodeIDVector.state->getSelVector().getSelSize() == 1 && nodeUpdateState.propertyVector.state->getSelVector().getSelSize() == 1); if (nodeUpdateState.columnID == pkColumnID && pkIndex) { - updatePK(transaction, updateState.columnID, nodeUpdateState.nodeIDVector, - updateState.propertyVector); + pkIndex->delete_(nodeUpdateState.pkVector); + insertPK(nodeUpdateState.nodeIDVector, nodeUpdateState.propertyVector); } const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID, LocalStorage::NotExistAction::CREATE); @@ -250,34 +250,6 @@ void NodeTable::rollbackInMemory() { } } -void NodeTable::updatePK(Transaction* transaction, column_id_t columnID, ValueVector& nodeIDVector, - const ValueVector& payloadVector) { - const auto pkVector = - std::make_unique(getColumn(pkColumnID)->getDataType(), memoryManager); - pkVector->state = nodeIDVector.state; - auto outputVectors = std::vector{pkVector.get()}; - auto columnIDs = {columnID}; - const auto readState = - std::make_unique(&nodeIDVector, columnIDs, outputVectors); - const auto pos = nodeIDVector.state->getSelVector()[0]; - const auto nodeOffset = nodeIDVector.readNodeOffset(pos); - readState->nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeOffset); - // TODO(Xiyang): This logic should be handled in the front-end, so when we update pk - // property, we need to scan out the original pk property first. try scan from committed - // data. - readState->source = TableScanSource::COMMITTED; - initializeScanState(transaction, *readState); - scan(transaction, *readState); - if (pkVector->state->getSelVector().getSelSize() == 0) { - // try scan from uncommitted data. - readState->source = TableScanSource::UNCOMMITTED; - initializeScanState(transaction, *readState); - scan(transaction, *readState); - } - pkIndex->delete_(pkVector.get()); - insertPK(nodeIDVector, payloadVector); -} - void NodeTable::insertPK(const ValueVector& nodeIDVector, const ValueVector& pkVector) const { for (auto i = 0u; i < nodeIDVector.state->getSelVector().getSelSize(); i++) { const auto nodeIDPos = nodeIDVector.state->getSelVector()[0]; diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index b88d7efb280..f33650c2105 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -24,6 +24,7 @@ void UndoBufferIterator::iterate(F&& callback) { UndoBuffer::UndoEntryType entryType = *reinterpret_cast(current); // Only support catalog for now. + (void)entryType; KU_ASSERT(entryType == UndoBuffer::UndoEntryType::CATALOG_ENTRY); current += sizeof(UndoBuffer::UndoEntryType); auto entrySize = *reinterpret_cast(current); @@ -49,6 +50,7 @@ void UndoBufferIterator::reverseIterate(F&& callback) { UndoBuffer::UndoEntryType entryType = *reinterpret_cast(current); // Only support catalog for now. + (void)entryType; KU_ASSERT(entryType == UndoBuffer::UndoEntryType::CATALOG_ENTRY); current += sizeof(UndoBuffer::UndoEntryType); auto entrySize = *reinterpret_cast(current); diff --git a/test/test_files/update_node/set_tinysnb.test b/test/test_files/update_node/set_tinysnb.test index 2477e3cc408..007bee6e124 100644 --- a/test/test_files/update_node/set_tinysnb.test +++ b/test/test_files/update_node/set_tinysnb.test @@ -2,6 +2,15 @@ -- +-CASE UpdatingSerialPrimaryKey +-STATEMENT CREATE NODE TABLE A(ID SERIAL, PRIMARY KEY(ID)); +---- ok +-STATEMENT CREATE (a:A); +---- ok +-STATEMENT MATCH (a:A) WHERE a.ID = 0 SET a.ID = 1; +---- error +Binder exception: Updating SERIAL primary key is not supported. + -CASE SetNodeInt64PropTest -STATEMENT MATCH (a:person) WHERE a.ID=0 SET a.age=20 + 50 ---- ok