Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scan primary key column before updating #3542

Merged
merged 2 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,18 +279,19 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindSetClause(const UpdatingClause&

BoundSetPropertyInfo Binder::bindSetPropertyInfo(parser::ParsedExpression* lhs,
parser::ParsedExpression* rhs) {
auto pattern = expressionBinder.bindExpression(*lhs->getChild(0));
auto isNode = ExpressionUtil::isNodePattern(*pattern);
auto isRel = ExpressionUtil::isRelPattern(*pattern);
auto expr = expressionBinder.bindExpression(*lhs->getChild(0));
auto isNode = ExpressionUtil::isNodePattern(*expr);
auto isRel = ExpressionUtil::isRelPattern(*expr);
if (!isNode && !isRel) {
throw BinderException(
stringFormat("Cannot set expression {} with type {}. Expect node or rel pattern.",
pattern->toString(), expressionTypeToString(pattern->expressionType)));
expr->toString(), expressionTypeToString(expr->expressionType)));
}
auto patternExpr = ku_dynamic_cast<Expression*, NodeOrRelExpression*>(pattern.get());
auto& patternExpr = expr->constCast<NodeOrRelExpression>();
auto boundSetItem = bindSetItem(lhs, rhs);
// Validate not updating tables belong to RDFGraph.
auto catalog = clientContext->getCatalog();
for (auto tableID : patternExpr->getTableIDs()) {
for (auto tableID : patternExpr.getTableIDs()) {
auto tableName = catalog->getTableCatalogEntry(clientContext->getTx(), tableID)->getName();
for (auto& rdfGraphEntry : catalog->getRdfGraphEntries(clientContext->getTx())) {
if (rdfGraphEntry->isParent(tableID)) {
Expand All @@ -301,8 +302,20 @@ BoundSetPropertyInfo Binder::bindSetPropertyInfo(parser::ParsedExpression* lhs,
}
}
}
return BoundSetPropertyInfo(isNode ? UpdateTableType::NODE : UpdateTableType::REL, pattern,
std::move(boundSetItem));
if (isNode) {
auto info = BoundSetPropertyInfo(TableType::NODE, expr, boundSetItem);
auto& property = boundSetItem.first->constCast<PropertyExpression>();
for (auto id : patternExpr.getTableIDs()) {
if (property.isPrimaryKey(id)) {
info.pkExpr = boundSetItem.first;
if (info.pkExpr->dataType.getLogicalTypeID() == LogicalTypeID::SERIAL) {
throw BinderException("Updating SERIAL primary key is not supported.");
}
}
}
return info;
}
return BoundSetPropertyInfo(TableType::REL, expr, std::move(boundSetItem));
}

expression_pair Binder::bindSetItem(parser::ParsedExpression* lhs, parser::ParsedExpression* rhs) {
Expand Down
12 changes: 6 additions & 6 deletions src/binder/query/bound_merge_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ bool BoundMergeClause::hasOnMatchSetInfo(
return false;
}

std::vector<const BoundSetPropertyInfo*> BoundMergeClause::getOnMatchSetInfos(
std::vector<BoundSetPropertyInfo> BoundMergeClause::getOnMatchSetInfos(
const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
std::vector<const BoundSetPropertyInfo*> result;
std::vector<BoundSetPropertyInfo> result;
for (auto& info : onMatchSetPropertyInfos) {
if (check(info)) {
result.push_back(&info);
result.push_back(info.copy());
}
}
return result;
Expand All @@ -57,12 +57,12 @@ bool BoundMergeClause::hasOnCreateSetInfo(
return false;
}

std::vector<const BoundSetPropertyInfo*> BoundMergeClause::getOnCreateSetInfos(
std::vector<BoundSetPropertyInfo> BoundMergeClause::getOnCreateSetInfos(
const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
std::vector<const BoundSetPropertyInfo*> result;
std::vector<BoundSetPropertyInfo> result;
for (auto& info : onCreateSetPropertyInfos) {
if (check(info)) {
result.push_back(&info);
result.push_back(info.copy());
}
}
return result;
Expand Down
6 changes: 3 additions & 3 deletions src/binder/query/bound_set_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ bool BoundSetClause::hasInfo(const std::function<bool(const BoundSetPropertyInfo
return false;
}

std::vector<const BoundSetPropertyInfo*> BoundSetClause::getInfos(
std::vector<BoundSetPropertyInfo> BoundSetClause::getInfos(
const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
std::vector<const BoundSetPropertyInfo*> result;
std::vector<BoundSetPropertyInfo> result;
for (auto& info : infos) {
if (check(info)) {
result.push_back(&info);
result.push_back(info.copy());
}
}
return result;
Expand Down
7 changes: 5 additions & 2 deletions src/binder/visitor/property_collector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ void PropertyCollector::visitTableFunctionCall(const BoundReadingClause& reading
}

void PropertyCollector::visitSet(const BoundUpdatingClause& updatingClause) {
auto& boundSetClause = (BoundSetClause&)updatingClause;
for (auto& info : boundSetClause.getInfosRef()) {
auto& boundSetClause = updatingClause.constCast<BoundSetClause>();
for (auto& info : boundSetClause.getInfos()) {
if (info.pkExpr != nullptr) {
properties.insert(info.pkExpr);
}
collectPropertyExpressions(info.setItem.second);
}
}
Expand Down
28 changes: 14 additions & 14 deletions src/include/binder/query/updating_clause/bound_merge_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,43 +52,43 @@ class BoundMergeClause : public BoundUpdatingClause {

bool hasOnMatchSetNodeInfo() const {
return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
return info.tableType == common::TableType::NODE;
});
}
std::vector<const BoundSetPropertyInfo*> getOnMatchSetNodeInfos() const {
std::vector<BoundSetPropertyInfo> getOnMatchSetNodeInfos() const {
return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
return info.tableType == common::TableType::NODE;
});
}
bool hasOnMatchSetRelInfo() const {
return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
return info.tableType == common::TableType::REL;
});
}
std::vector<const BoundSetPropertyInfo*> getOnMatchSetRelInfos() const {
std::vector<BoundSetPropertyInfo> getOnMatchSetRelInfos() const {
return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
return info.tableType == common::TableType::REL;
});
}

bool hasOnCreateSetNodeInfo() const {
return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
return info.tableType == common::TableType::NODE;
});
}
std::vector<const BoundSetPropertyInfo*> getOnCreateSetNodeInfos() const {
std::vector<BoundSetPropertyInfo> getOnCreateSetNodeInfos() const {
return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
return info.tableType == common::TableType::NODE;
});
}
bool hasOnCreateSetRelInfo() const {
return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
return info.tableType == common::TableType::REL;
});
}
std::vector<const BoundSetPropertyInfo*> getOnCreateSetRelInfos() const {
std::vector<BoundSetPropertyInfo> getOnCreateSetRelInfos() const {
return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
return info.tableType == common::TableType::REL;
});
}

Expand All @@ -106,12 +106,12 @@ class BoundMergeClause : public BoundUpdatingClause {

bool hasOnMatchSetInfo(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;
std::vector<const BoundSetPropertyInfo*> getOnMatchSetInfos(
std::vector<BoundSetPropertyInfo> getOnMatchSetInfos(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;

bool hasOnCreateSetInfo(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;
std::vector<const BoundSetPropertyInfo*> getOnCreateSetInfos(
std::vector<BoundSetPropertyInfo> getOnCreateSetInfos(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;

private:
Expand Down
22 changes: 11 additions & 11 deletions src/include/binder/query/updating_clause/bound_set_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,33 @@ class BoundSetClause : public BoundUpdatingClause {
public:
BoundSetClause() : BoundUpdatingClause{common::ClauseType::SET} {}

inline void addInfo(BoundSetPropertyInfo info) { infos.push_back(std::move(info)); }
inline const std::vector<BoundSetPropertyInfo>& getInfosRef() { return infos; }
void addInfo(BoundSetPropertyInfo info) { infos.push_back(std::move(info)); }
const std::vector<BoundSetPropertyInfo>& getInfos() const { return infos; }

inline bool hasNodeInfo() const {
bool hasNodeInfo() const {
return hasInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
return info.tableType == common::TableType::NODE;
});
}
inline std::vector<const BoundSetPropertyInfo*> getNodeInfos() const {
std::vector<BoundSetPropertyInfo> getNodeInfos() const {
return getInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
return info.tableType == common::TableType::NODE;
});
}
inline bool hasRelInfo() const {
bool hasRelInfo() const {
return hasInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
return info.tableType == common::TableType::REL;
});
}
inline std::vector<const BoundSetPropertyInfo*> getRelInfos() const {
std::vector<BoundSetPropertyInfo> getRelInfos() const {
return getInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
return info.tableType == common::TableType::REL;
});
}

private:
bool hasInfo(const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;
std::vector<const BoundSetPropertyInfo*> getInfos(
std::vector<BoundSetPropertyInfo> getInfos(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;

private:
Expand Down
16 changes: 8 additions & 8 deletions src/include/binder/query/updating_clause/bound_set_info.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
#pragma once

#include "binder/expression/expression.h"
#include "update_table_type.h"
#include "common/enums/table_type.h"

namespace kuzu {
namespace binder {

struct BoundSetPropertyInfo {
UpdateTableType updateTableType;
std::shared_ptr<Expression> nodeOrRel;
common::TableType tableType;
std::shared_ptr<Expression> pattern;
expression_pair setItem;
std::shared_ptr<Expression> pkExpr = nullptr;

BoundSetPropertyInfo(UpdateTableType updateTableType, std::shared_ptr<Expression> nodeOrRel,
BoundSetPropertyInfo(common::TableType tableType, std::shared_ptr<Expression> pattern,
expression_pair setItem)
: updateTableType{updateTableType}, nodeOrRel{std::move(nodeOrRel)},
setItem{std::move(setItem)} {}
: tableType{tableType}, pattern{std::move(pattern)}, setItem{std::move(setItem)} {}
EXPLICIT_COPY_DEFAULT_MOVE(BoundSetPropertyInfo);

private:
BoundSetPropertyInfo(const BoundSetPropertyInfo& other)
: updateTableType{other.updateTableType}, nodeOrRel{other.nodeOrRel},
setItem{other.setItem} {}
: tableType{other.tableType}, pattern{other.pattern}, setItem{other.setItem},
pkExpr{other.pkExpr} {}
};

} // namespace binder
Expand Down
14 changes: 0 additions & 14 deletions src/include/binder/query/updating_clause/update_table_type.h

This file was deleted.

3 changes: 1 addition & 2 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ class FactorizationRewriter final : public LogicalOperatorVisitor {
void visitUnwind(planner::LogicalOperator* op) override;
void visitUnion(planner::LogicalOperator* op) override;
void visitFilter(planner::LogicalOperator* op) override;
void visitSetNodeProperty(planner::LogicalOperator* op) override;
void visitSetRelProperty(planner::LogicalOperator* op) override;
void visitSetProperty(planner::LogicalOperator* op) override;
void visitDelete(planner::LogicalOperator* op) override;
void visitInsert(planner::LogicalOperator* op) override;
void visitMerge(planner::LogicalOperator* op) override;
Expand Down
10 changes: 2 additions & 8 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,8 @@ class LogicalOperatorVisitor {
return op;
}

virtual void visitSetNodeProperty(planner::LogicalOperator* /*op*/) {}
virtual std::shared_ptr<planner::LogicalOperator> visitSetNodePropertyReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitSetRelProperty(planner::LogicalOperator* /*op*/) {}
virtual std::shared_ptr<planner::LogicalOperator> visitSetRelPropertyReplace(
virtual void visitSetProperty(planner::LogicalOperator*) {}
virtual std::shared_ptr<planner::LogicalOperator> visitSetPropertyReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}
Expand Down
11 changes: 7 additions & 4 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ namespace kuzu {
namespace main {
class ClientContext;
}

namespace binder {
struct BoundSetPropertyInfo;
}
namespace planner {
struct LogicalInsertInfo;
}
Expand Down Expand Up @@ -35,14 +37,15 @@ class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {
void visitProjection(planner::LogicalOperator* op) override;
void visitOrderBy(planner::LogicalOperator* op) override;
void visitUnwind(planner::LogicalOperator* op) override;
void visitSetNodeProperty(planner::LogicalOperator* op) override;
void visitSetRelProperty(planner::LogicalOperator* op) override;
void visitSetProperty(planner::LogicalOperator* op) override;
void visitInsert(planner::LogicalOperator* op) override;
void visitInsertInfo(const planner::LogicalInsertInfo* info);
void visitDelete(planner::LogicalOperator* op) override;
void visitMerge(planner::LogicalOperator* op) override;
void visitCopyFrom(planner::LogicalOperator* op) override;

void visitSetInfo(const binder::BoundSetPropertyInfo& info);
void visitInsertInfo(const planner::LogicalInsertInfo& info);

void collectExpressionsInUse(std::shared_ptr<binder::Expression> expression);

binder::expression_vector pruneExpressions(const binder::expression_vector& expressions);
Expand Down
3 changes: 1 addition & 2 deletions src/include/planner/operator/logical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ enum class LogicalOperatorType : uint8_t {
SCAN_FRONTIER,
SCAN_NODE_TABLE,
SEMI_MASKER,
SET_NODE_PROPERTY,
SET_REL_PROPERTY,
SET_PROPERTY,
STANDALONE_CALL,
TABLE_FUNCTION_CALL,
TRANSACTION,
Expand Down
Loading