Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix disabled test for 3524 #3546

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/include/function/udf_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ struct UDF {
(void*)(udfFunc); // Disable compiler warnings.
return [udfFunc](const std::vector<std::shared_ptr<common::ValueVector>>& params,
common::ValueVector& result, void* /*dataPtr*/ = nullptr) -> void {
(void)params;
KU_ASSERT(params.size() == 0);
auto& resultSelVector = result.state->getSelVector();
for (auto i = 0u; i < resultSelVector.getSelSize(); ++i) {
Expand Down
50 changes: 12 additions & 38 deletions src/include/planner/operator/extend/logical_recursive_extend.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class LogicalRecursiveExtend : public BaseLogicalExtend {
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline void setJoinType(RecursiveJoinType joinType_) { joinType = joinType_; }
inline RecursiveJoinType getJoinType() const { return joinType; }
inline std::shared_ptr<LogicalOperator> getRecursiveChild() const { return recursiveChild; }
void setJoinType(RecursiveJoinType joinType_) { joinType = joinType_; }
RecursiveJoinType getJoinType() const { return joinType; }
std::shared_ptr<LogicalOperator> getRecursiveChild() const { return recursiveChild; }

inline std::unique_ptr<LogicalOperator> copy() override {
std::unique_ptr<LogicalOperator> copy() override {
return std::make_unique<LogicalRecursiveExtend>(boundNode, nbrNode, rel, direction,
joinType, children[0]->copy(), recursiveChild->copy());
}
Expand All @@ -50,17 +50,17 @@ class LogicalPathPropertyProbe : public LogicalOperator {

std::string getExpressionsForPrinting() const override { return recursiveRel->toString(); }

inline std::shared_ptr<binder::RelExpression> getRel() const { return recursiveRel; }
std::shared_ptr<binder::RelExpression> getRel() const { return recursiveRel; }

inline void setJoinType(RecursiveJoinType joinType_) { joinType = joinType_; }
inline RecursiveJoinType getJoinType() const { return joinType; }
void setJoinType(RecursiveJoinType joinType_) { joinType = joinType_; }
RecursiveJoinType getJoinType() const { return joinType; }

inline void setSIP(SidewaysInfoPassing sip_) { sip = sip_; }
inline SidewaysInfoPassing getSIP() const { return sip; }
inline std::shared_ptr<LogicalOperator> getNodeChild() const { return nodeChild; }
inline std::shared_ptr<LogicalOperator> getRelChild() const { return relChild; }
void setSIP(SidewaysInfoPassing sip_) { sip = sip_; }
SidewaysInfoPassing getSIP() const { return sip; }
std::shared_ptr<LogicalOperator> getNodeChild() const { return nodeChild; }
std::shared_ptr<LogicalOperator> getRelChild() const { return relChild; }

inline std::unique_ptr<LogicalOperator> copy() override {
std::unique_ptr<LogicalOperator> copy() override {
auto nodeChildCopy = nodeChild == nullptr ? nullptr : nodeChild->copy();
auto relChildCopy = relChild == nullptr ? nullptr : relChild->copy();
return std::make_unique<LogicalPathPropertyProbe>(recursiveRel, children[0]->copy(),
Expand All @@ -75,31 +75,5 @@ class LogicalPathPropertyProbe : public LogicalOperator {
SidewaysInfoPassing sip;
};

class LogicalScanFrontier : public LogicalOperator {
public:
LogicalScanFrontier(std::shared_ptr<binder::Expression> nodeID,
std::shared_ptr<binder::Expression> nodePredicateExecFlag)
: LogicalOperator{LogicalOperatorType::SCAN_FRONTIER}, nodeID{std::move(nodeID)},
nodePredicateExecFlag{std::move(nodePredicateExecFlag)} {}

void computeFactorizedSchema() final;
void computeFlatSchema() final;

std::string getExpressionsForPrinting() const override { return std::string(); }

inline std::shared_ptr<binder::Expression> getNodeID() const { return nodeID; }
inline std::shared_ptr<binder::Expression> getNodePredicateExecutionFlag() const {
return nodePredicateExecFlag;
}

inline std::unique_ptr<LogicalOperator> copy() final {
return std::make_unique<LogicalScanFrontier>(nodeID, nodePredicateExecFlag);
}

private:
std::shared_ptr<binder::Expression> nodeID;
std::shared_ptr<binder::Expression> nodePredicateExecFlag;
};

} // namespace planner
} // namespace kuzu
5 changes: 4 additions & 1 deletion src/include/planner/operator/logical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ enum class LogicalOperatorType : uint8_t {
PROJECTION,
RECURSIVE_EXTEND,
SCAN_FILE,
SCAN_FRONTIER,
SCAN_NODE_TABLE,
SEMI_MASKER,
SET_PROPERTY,
Expand Down Expand Up @@ -111,6 +110,10 @@ class LogicalOperator {
return common::ku_dynamic_cast<const LogicalOperator&, const TARGET&>(*this);
}
template<class TARGET>
TARGET& cast() {
return common::ku_dynamic_cast<LogicalOperator&, TARGET&>(*this);
}
template<class TARGET>
const TARGET* constPtrCast() const {
return common::ku_dynamic_cast<const LogicalOperator*, const TARGET*>(this);
}
Expand Down
37 changes: 34 additions & 3 deletions src/include/planner/operator/scan/logical_scan_node_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,35 @@
namespace kuzu {
namespace planner {

enum class LogicalScanNodeTableType : uint8_t {
SCAN = 0,
OFFSET_LOOK_UP = 1,
};

// LogicalScanNodeTable now is also the source for recursive plan. Recursive plan node predicate
// need additional variable to evaluate. I cannot think of other operator that can put it into
// recursive plan schema.
struct RecursiveJoinScanInfo {
std::shared_ptr<binder::Expression> nodePredicateExecFlag;

explicit RecursiveJoinScanInfo(std::shared_ptr<binder::Expression> expr)
: nodePredicateExecFlag{std::move(expr)} {}

std::unique_ptr<RecursiveJoinScanInfo> copy() const {
return std::make_unique<RecursiveJoinScanInfo>(nodePredicateExecFlag);
}
};

class LogicalScanNodeTable final : public LogicalOperator {
static constexpr LogicalOperatorType type_ = LogicalOperatorType::SCAN_NODE_TABLE;
static constexpr LogicalScanNodeTableType defaultScanType = LogicalScanNodeTableType::SCAN;

public:
LogicalScanNodeTable(std::shared_ptr<binder::Expression> nodeID,
std::vector<common::table_id_t> nodeTableIDs, binder::expression_vector properties)
: LogicalOperator{LogicalOperatorType::SCAN_NODE_TABLE}, nodeID{std::move(nodeID)},
: LogicalOperator{type_}, scanType{defaultScanType}, nodeID{std::move(nodeID)},
nodeTableIDs{std::move(nodeTableIDs)}, properties{std::move(properties)} {}
LogicalScanNodeTable(const LogicalScanNodeTable& other);

void computeFactorizedSchema() override;
void computeFlatSchema() override;
Expand All @@ -20,18 +43,26 @@ class LogicalScanNodeTable final : public LogicalOperator {
return binder::ExpressionUtil::toString(properties);
}

void setScanType(LogicalScanNodeTableType scanType_) { scanType = scanType_; }
LogicalScanNodeTableType getScanType() const { return scanType; }

std::shared_ptr<binder::Expression> getNodeID() const { return nodeID; }
std::vector<common::table_id_t> getTableIDs() const { return nodeTableIDs; }
binder::expression_vector getProperties() const { return properties; }

std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalScanNodeTable>(nodeID, nodeTableIDs, properties);
void setRecursiveJoinScanInfo(std::unique_ptr<RecursiveJoinScanInfo> info) {
recursiveJoinScanInfo = std::move(info);
}
bool hasRecursiveJoinScanInfo() const { return recursiveJoinScanInfo != nullptr; }

std::unique_ptr<LogicalOperator> copy() override;

private:
LogicalScanNodeTableType scanType;
std::shared_ptr<binder::Expression> nodeID;
std::vector<common::table_id_t> nodeTableIDs;
binder::expression_vector properties;
std::unique_ptr<RecursiveJoinScanInfo> recursiveJoinScanInfo;
};

} // namespace planner
Expand Down
38 changes: 18 additions & 20 deletions src/include/planner/operator/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,28 @@ class FactorizationGroup {
cardinalityMultiplier{other.cardinalityMultiplier}, expressions{other.expressions},
expressionNameToPos{other.expressionNameToPos} {}

inline void setFlat() {
void setFlat() {
KU_ASSERT(!flat);
flat = true;
}
inline bool isFlat() const { return flat; }
inline void setSingleState() {
bool isFlat() const { return flat; }
void setSingleState() {
KU_ASSERT(!singleState);
singleState = true;
setFlat();
}
inline bool isSingleState() const { return singleState; }
bool isSingleState() const { return singleState; }

inline void setMultiplier(double multiplier) { cardinalityMultiplier = multiplier; }
inline double getMultiplier() const { return cardinalityMultiplier; }
void setMultiplier(double multiplier) { cardinalityMultiplier = multiplier; }
double getMultiplier() const { return cardinalityMultiplier; }

inline void insertExpression(const std::shared_ptr<binder::Expression>& expression) {
void insertExpression(const std::shared_ptr<binder::Expression>& expression) {
KU_ASSERT(!expressionNameToPos.contains(expression->getUniqueName()));
expressionNameToPos.insert({expression->getUniqueName(), expressions.size()});
expressions.push_back(expression);
}
inline binder::expression_vector getExpressions() const { return expressions; }
inline uint32_t getExpressionPos(const binder::Expression& expression) {
binder::expression_vector getExpressions() const { return expressions; }
uint32_t getExpressionPos(const binder::Expression& expression) {
KU_ASSERT(expressionNameToPos.contains(expression.getUniqueName()));
return expressionNameToPos.at(expression.getUniqueName());
}
Expand All @@ -58,20 +58,19 @@ class FactorizationGroup {

class Schema {
public:
inline size_t getNumGroups() const { return groups.size(); }
inline size_t getNumFlatGroups() const { return getNumGroups(true /* isFlat */); }
inline size_t getNumUnFlatGroups() const { return getNumGroups(false /* isFlat */); }
size_t getNumGroups() const { return groups.size(); }
size_t getNumFlatGroups() const { return getNumGroups(true /* isFlat */); }
size_t getNumUnFlatGroups() const { return getNumGroups(false /* isFlat */); }

inline FactorizationGroup* getGroup(
const std::shared_ptr<binder::Expression>& expression) const {
FactorizationGroup* getGroup(const std::shared_ptr<binder::Expression>& expression) const {
return getGroup(getGroupPos(expression->getUniqueName()));
}

inline FactorizationGroup* getGroup(const std::string& expressionName) const {
FactorizationGroup* getGroup(const std::string& expressionName) const {
return getGroup(getGroupPos(expressionName));
}

inline FactorizationGroup* getGroup(uint32_t pos) const { return groups[pos].get(); }
FactorizationGroup* getGroup(uint32_t pos) const { return groups[pos].get(); }

f_group_pos createGroup();

Expand All @@ -87,14 +86,13 @@ class Schema {

void insertToGroupAndScope(const binder::expression_vector& expressions, uint32_t groupPos);

inline f_group_pos getGroupPos(const binder::Expression& expression) const {
f_group_pos getGroupPos(const binder::Expression& expression) const {
return getGroupPos(expression.getUniqueName());
}

f_group_pos getGroupPos(const std::string& expressionName) const;

inline std::pair<f_group_pos, uint32_t> getExpressionPos(
const binder::Expression& expression) const {
std::pair<f_group_pos, uint32_t> getExpressionPos(const binder::Expression& expression) const {
auto groupPos = getGroupPos(expression);
return std::make_pair(groupPos, groups[groupPos]->getExpressionPos(expression));
}
Expand All @@ -115,7 +113,7 @@ class Schema {
std::unordered_set<f_group_pos> getDependentGroupsPos(
const std::shared_ptr<binder::Expression>& expression);

inline void clearExpressionsInScope() {
void clearExpressionsInScope() {
expressionNameToGroupPos.clear();
expressionsInScope.clear();
}
Expand Down
3 changes: 1 addition & 2 deletions src/include/processor/operator/physical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum class PhysicalOperatorType : uint8_t {
INSTALL_EXTENSION,
LIMIT,
LOAD_EXTENSION,
LOOK_UP_NODE_TABLE,
MERGE,
MULTIPLICITY_REDUCER,
PARTITIONER,
Expand All @@ -49,9 +50,7 @@ enum class PhysicalOperatorType : uint8_t {
RENAME_PROPERTY,
RENAME_TABLE,
RESULT_COLLECTOR,
SCAN_FRONTIER,
SCAN_NODE_TABLE,
SCAN_MULTI_REL_TABLES,
SCAN_REL_TABLE,
SEMI_MASKER,
SET_PROPERTY,
Expand Down
10 changes: 8 additions & 2 deletions src/include/processor/operator/recursive_extend/recursive_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
namespace kuzu {
namespace processor {

class ScanFrontier;
class LookupNodeTable;

struct RecursiveJoinSharedState {
std::vector<std::unique_ptr<NodeOffsetSemiMask>> semiMasks;
Expand All @@ -29,6 +29,8 @@ struct RecursiveJoinDataInfo {
DataPos pathLengthPos;
// Recursive join info.
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor;
DataPos recursiveSrcNodeIDPos;
DataPos recursiveNodePredicateExecFlagPos;
DataPos recursiveDstNodeIDPos;
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs;
DataPos recursiveEdgeIDPos;
Expand All @@ -46,6 +48,8 @@ struct RecursiveJoinDataInfo {
dstNodeTableIDs = other.dstNodeTableIDs;
pathLengthPos = other.pathLengthPos;
localResultSetDescriptor = other.localResultSetDescriptor->copy();
recursiveSrcNodeIDPos = other.recursiveSrcNodeIDPos;
recursiveNodePredicateExecFlagPos = other.recursiveNodePredicateExecFlagPos;
recursiveDstNodeIDPos = other.recursiveDstNodeIDPos;
recursiveDstNodeTableIDs = other.recursiveDstNodeTableIDs;
recursiveEdgeIDPos = other.recursiveEdgeIDPos;
Expand All @@ -69,7 +73,9 @@ struct RecursiveJoinVectors {
common::ValueVector* pathRelsLabelDataVector = nullptr; // STRING

common::ValueVector* recursiveEdgeIDVector = nullptr;
common::ValueVector* recursiveSrcNodeIDVector = nullptr;
common::ValueVector* recursiveDstNodeIDVector = nullptr;
common::ValueVector* recursiveNodePredicateExecFlagVector = nullptr;
};

struct RecursiveJoinInfo {
Expand Down Expand Up @@ -134,7 +140,7 @@ class RecursiveJoin : public PhysicalOperator {
// Local recursive plan
std::unique_ptr<ResultSet> localResultSet;
std::unique_ptr<PhysicalOperator> recursiveRoot;
ScanFrontier* scanFrontier;
LookupNodeTable* recursiveSource;

std::unique_ptr<RecursiveJoinVectors> vectors;
std::unique_ptr<BaseBFSState> bfsState;
Expand Down
46 changes: 0 additions & 46 deletions src/include/processor/operator/recursive_extend/scan_frontier.h

This file was deleted.

Loading