From ee55c3a7b47d8deeb7cb0b80ce05a7ce75d3fa5f Mon Sep 17 00:00:00 2001 From: zhli1142015 Date: Sat, 14 Sep 2024 15:15:13 +0800 Subject: [PATCH] address comments --- velox/exec/tests/HashJoinTest.cpp | 145 +++++++++++++++++++++++------- 1 file changed, 113 insertions(+), 32 deletions(-) diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index bacf8e04b498..d9ef48abeb5a 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -8094,7 +8094,12 @@ TEST_F(HashJoinTest, nanKeys) { } TEST_F(HashJoinTest, combineSmallVectorsAfterFilter) { - // Verify low selectivity / small vectors are combined to 1 vector. + // Verify low selectivity / small vectors are combined. + // Three build vectors and one probe vector are created. The duplication rate + // of keys in the build vector is 5. Half of the rows in the build vector can + // find a matching key row in the probe vector. + // The filter condition '(t1 + u1) % 3 = 0' can filter out most of the + // matching rows. 
auto probeVectors = makeBatches(1, [&](auto /*unused*/) { return makeRowVector( {"t0", "t1"}, @@ -8123,29 +8128,36 @@ TEST_F(HashJoinTest, combineSmallVectorsAfterFilter) { createDuckDbTable("t", probeVectors); createDuckDbTable("u", buildVectors); - core::PlanNodeId probeScanId; core::PlanNodeId buildScanId; + core::PlanNodeId joinNodeId; auto planNodeIdGenerator = std::make_shared(); - auto makePlan = [&](bool nullAware) { - return PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(probeVectors[0]->type())) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(buildVectors[0]->type())) - .capturePlanNodeId(buildScanId) - .planNode(), - "(t1 + u1) % 3 = 0", - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject, - nullAware) - .planNode(); - }; - for (auto nullAware : {false, true}) { - auto plan = makePlan(nullAware); + + auto verifyJoinOutputVectorCount = [&](int expectedVectorCount, + core::JoinType joinType, + const std::string& refQuery, + bool nullAware = false, + bool flipJoinSide = false) { + std::vector output = {"t0", "t1"}; + if (joinType == core::JoinType::kLeftSemiProject) { + output.emplace_back("match"); + } + auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(probeVectors[0]->type())) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(buildVectors[0]->type())) + .capturePlanNodeId(buildScanId) + .planNode(), + "(t1 + u1) % 3 = 0", + output, + joinType, + nullAware) + .capturePlanNodeId(joinNodeId) + .planNode(); SplitInput splitInput = { {probeScanId, @@ -8153,24 +8165,93 @@ TEST_F(HashJoinTest, combineSmallVectorsAfterFilter) { {buildScanId, {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, }; - + if (flipJoinSide) { + plan = flipJoinSides(plan); + } HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(plan) .inputSplits(splitInput) 
.checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") + .referenceQuery(refQuery) .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto stats = task->taskStats(); - for (auto& pipeline : stats.pipelineStats) { - for (auto op : pipeline.operatorStats) { - if (op.operatorType == "HashProbe") { - ASSERT_EQ(op.outputVectors, 1); - } - } - } + ASSERT_EQ( + toPlanStats(task->taskStats()).at(joinNodeId).outputVectors, + expectedVectorCount); }) .run(); + }; + { + SCOPED_TRACE("inner join"); + verifyJoinOutputVectorCount( + 1, // 2 output vectors are merged to 1 vector. + core::JoinType::kInner, + "SELECT t0, t1, FROM t, u WHERE t0 = u0 AND (t1 + u1) % 3 = 0"); + } + { + SCOPED_TRACE("full join"); + verifyJoinOutputVectorCount( + 5, // 6 output vectors are merged to 5 vectors. + core::JoinType::kFull, + "SELECT t0, t1, FROM t FULL OUTER JOIN u ON t0 = u0 AND (t1 + u1) % 3 = 0"); + } + { + SCOPED_TRACE("left join"); + verifyJoinOutputVectorCount( + 2, // 3 output vectors are merged to 2 vectors. + core::JoinType::kLeft, + "SELECT t0, t1 FROM t LEFT JOIN u ON t0 = u0 AND (t1 + u1) % 3 = 0"); + } + { + SCOPED_TRACE("right join"); + verifyJoinOutputVectorCount( + 2, // 3 output vectors are merged to 2 vectors. + core::JoinType::kLeft, // Flip join side. + "SELECT t0, t1 FROM t LEFT JOIN u ON t0 = u0 AND (t1 + u1) % 3 = 0", + false, + true); + } + { + SCOPED_TRACE("semi project join"); + verifyJoinOutputVectorCount( + 1, // 3 output vectors are merged to 1 vector. + core::JoinType::kLeftSemiProject, + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t"); + verifyJoinOutputVectorCount( + 1, // 3 output vectors are merged to 1 vector. + core::JoinType::kLeftSemiProject, + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t", + true); + verifyJoinOutputVectorCount( + 1, // 3 output vectors are merged to 1 vector. 
+ core::JoinType::kLeftSemiProject, // Flip join side. + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t", + false, + true); + } + { + SCOPED_TRACE("semi filter join"); + verifyJoinOutputVectorCount( + 1, // 2 output vectors are merged to 1 vector. + core::JoinType::kLeftSemiFilter, + "SELECT t0, t1, FROM t WHERE EXISTS (SELECT u0 FROM u WHERE t0 = u0 AND (t1 + u1) % 3 = 0)"); + verifyJoinOutputVectorCount( + 1, // 2 output vectors are merged to 1 vector. + core::JoinType::kLeftSemiFilter, // Flip join side. + "SELECT t0, t1, FROM t WHERE EXISTS (SELECT u0 FROM u WHERE t0 = u0 AND (t1 + u1) % 3 = 0)", + false, + true); + } + { + SCOPED_TRACE("anti join"); + verifyJoinOutputVectorCount( + 1, // 3 output vectors are merged to 1 vector. + core::JoinType::kAnti, + "SELECT t0, t1, FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND (t1 + u1) % 3 = 0)"); + verifyJoinOutputVectorCount( + 1, // 2 output vectors are merged to 1 vector. + core::JoinType::kAnti, + "SELECT t0, t1, FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND (t1 + u1) % 3 = 0)", + true); } } } // namespace