Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhli1142015 committed Sep 19, 2024
1 parent 86e5481 commit ee55c3a
Showing 1 changed file with 113 additions and 32 deletions.
145 changes: 113 additions & 32 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8094,7 +8094,12 @@ TEST_F(HashJoinTest, nanKeys) {
}

TEST_F(HashJoinTest, combineSmallVectorsAfterFilter) {
// Verify low selectivity / small vectors are combined to 1 vector.
// Verify low selectivity / small vectors are combined.
// Three build vectors and one probe vector are created. The duplication rate
// of keys in the build vector is 5. Half of the rows in the build vector can
// find a matching key row in the probe vector.
// The filter condition '(t1 + u1) % 3 = 0' can filter out most of the
// matching rows.
auto probeVectors = makeBatches(1, [&](auto /*unused*/) {
return makeRowVector(
{"t0", "t1"},
Expand Down Expand Up @@ -8123,54 +8128,130 @@ TEST_F(HashJoinTest, combineSmallVectorsAfterFilter) {

createDuckDbTable("t", probeVectors);
createDuckDbTable("u", buildVectors);

core::PlanNodeId probeScanId;
core::PlanNodeId buildScanId;
core::PlanNodeId joinNodeId;
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto makePlan = [&](bool nullAware) {
return PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(probeVectors[0]->type()))
.capturePlanNodeId(probeScanId)
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(buildVectors[0]->type()))
.capturePlanNodeId(buildScanId)
.planNode(),
"(t1 + u1) % 3 = 0",
{"t0", "t1", "match"},
core::JoinType::kLeftSemiProject,
nullAware)
.planNode();
};
for (auto nullAware : {false, true}) {
auto plan = makePlan(nullAware);

auto verifyJoinOutputVectorCount = [&](int expectedVectorCount,
core::JoinType joinType,
const std::string& refQuery,
bool nullAware = false,
bool flipJoinSide = false) {
std::vector<std::string> output = {"t0", "t1"};
if (joinType == core::JoinType::kLeftSemiProject) {
output.emplace_back("match");
}
auto plan = PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(probeVectors[0]->type()))
.capturePlanNodeId(probeScanId)
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(buildVectors[0]->type()))
.capturePlanNodeId(buildScanId)
.planNode(),
"(t1 + u1) % 3 = 0",
output,
joinType,
nullAware)
.capturePlanNodeId(joinNodeId)
.planNode();

SplitInput splitInput = {
{probeScanId,
{exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}},
{buildScanId,
{exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}},
};

if (flipJoinSide) {
plan = flipJoinSides(plan);
}
HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
.planNode(plan)
.inputSplits(splitInput)
.checkSpillStats(false)
.referenceQuery(
"SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t")
.referenceQuery(refQuery)
.verifier([&](const std::shared_ptr<Task>& task, bool /*unused*/) {
auto stats = task->taskStats();
for (auto& pipeline : stats.pipelineStats) {
for (auto op : pipeline.operatorStats) {
if (op.operatorType == "HashProbe") {
ASSERT_EQ(op.outputVectors, 1);
}
}
}
ASSERT_EQ(
toPlanStats(task->taskStats()).at(joinNodeId).outputVectors,
expectedVectorCount);
})
.run();
};
{
SCOPED_TRACE("inner join");
verifyJoinOutputVectorCount(
1, // 2 output vectors are merged to 1 vector.
core::JoinType::kInner,
"SELECT t0, t1, FROM t, u WHERE t0 = u0 AND (t1 + u1) % 3 = 0");
}
{
SCOPED_TRACE("full join");
verifyJoinOutputVectorCount(
5, // 6 output vectors are merged to 5 vector.
core::JoinType::kFull,
"SELECT t0, t1, FROM t FULL OUTER JOIN u ON t0 = u0 AND (t1 + u1) % 3 = 0");
}
{
SCOPED_TRACE("left join");
verifyJoinOutputVectorCount(
2, // 3 output vectors are merged to 2 vectors.
core::JoinType::kLeft,
"SELECT t0, t1 FROM t LEFT JOIN u ON t0 = u0 AND (t1 + u1) % 3 = 0");
}
{
SCOPED_TRACE("right join");
verifyJoinOutputVectorCount(
2, // 3 output vectors are merged to 2 vectors.
core::JoinType::kLeft, // Flip join side.
"SELECT t0, t1 FROM t LEFT JOIN u ON t0 = u0 AND (t1 + u1) % 3 = 0",
false,
true);
}
{
SCOPED_TRACE("semi project join");
verifyJoinOutputVectorCount(
1, // 3 output vectors are merged to 1 vector.
core::JoinType::kLeftSemiProject,
"SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t");
verifyJoinOutputVectorCount(
1, // 3 output vectors are merged to 1 vector.
core::JoinType::kLeftSemiProject,
"SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t",
true);
verifyJoinOutputVectorCount(
1, // 3 output vectors are merged to 1 vector.
core::JoinType::kLeftSemiProject, // Flip join side.
"SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t",
false,
true);
}
{
SCOPED_TRACE("semi filter join");
verifyJoinOutputVectorCount(
1, // 2 output vectors are merged to 1 vector.
core::JoinType::kLeftSemiFilter,
"SELECT t0, t1, FROM t WHERE EXISTS (SELECT u0 FROM u WHERE t0 = u0 AND (t1 + u1) % 3 = 0)");
verifyJoinOutputVectorCount(
1, // 2 output vectors are merged to 1 vector.
core::JoinType::kLeftSemiFilter, // Flip join side.
"SELECT t0, t1, FROM t WHERE EXISTS (SELECT u0 FROM u WHERE t0 = u0 AND (t1 + u1) % 3 = 0)",
false,
true);
}
{
SCOPED_TRACE("anti join");
verifyJoinOutputVectorCount(
1, // 3 output vectors are merged to 1 vector.
core::JoinType::kAnti,
"SELECT t0, t1, FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND (t1 + u1) % 3 = 0)");
verifyJoinOutputVectorCount(
1, // 2 output vectors are merged to 1 vector.
core::JoinType::kAnti,
"SELECT t0, t1, FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND (t1 + u1) % 3 = 0)",
true);
}
}
} // namespace

0 comments on commit ee55c3a

Please sign in to comment.