From 98ae47b1191bc37513d3e577f97b5912a78eb77e Mon Sep 17 00:00:00 2001 From: Rui Mo Date: Sat, 28 May 2022 11:14:33 +0800 Subject: [PATCH] [OPPRO-115] Fix semi join and anti join with filter (#15) * Fix semi join with extra filter (#1726) Summary: CC: rui-mo Pull Request resolved: https://github.com/facebookincubator/velox/pull/1726 Reviewed By: oerling Differential Revision: D36703403 Pulled By: mbasmanova fbshipit-source-id: a39788e451b544a8830c328950dd87a8bf600847 * Add support for anti join with filter (#1728) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/1728 Reviewed By: oerling Differential Revision: D36706888 Pulled By: mbasmanova fbshipit-source-id: b7d57c9ef42ee674ce26600ddbac93f45a0a4dcc Co-authored-by: Masha Basmanova --- velox/exec/HashBuild.cpp | 10 +-- velox/exec/HashProbe.cpp | 61 ++++++++++++----- velox/exec/HashProbe.h | 45 +++++++++++-- velox/exec/tests/HashJoinTest.cpp | 106 ++++++++++++++++++++++++++++++ 4 files changed, 198 insertions(+), 24 deletions(-) diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index 81c9956bdb61..fa71304f63bb 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -107,15 +107,15 @@ HashBuild::HashBuild( true, // hasProbedFlag mappedMemory_); } else { - // Semi and anti join only needs to know whether there is a match. Hence, no - // need to store entries with duplicate keys. - const bool allowDuplicates = - !joinNode->isSemiJoin() && !joinNode->isAntiJoin(); + // Semi and anti join with no extra filter only needs to know whether there + // is a match. Hence, no need to store entries with duplicate keys. 
+ const bool dropDuplicates = !joinNode->filter() && + (joinNode->isSemiJoin() || joinNode->isAntiJoin()); table_ = HashTable::createForJoin( std::move(keyHashers), dependentTypes, - allowDuplicates, + !dropDuplicates, // allowDuplicates false, // hasProbedFlag mappedMemory_); } diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index 7b4c7d3f0e39..bcda93e74c42 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -260,7 +260,8 @@ void HashProbe::addInput(RowVectorPtr input) { return; } passingInputRowsInitialized_ = false; - if (isLeftJoin(joinType_) || isFullJoin(joinType_)) { + if (isLeftJoin(joinType_) || isFullJoin(joinType_) || + (isAntiJoin(joinType_) && filter_)) { // Make sure to allocate an entry in 'hits' for every input row to allow for // including rows without a match in the output. Also, make sure to // initialize all 'hits' to nullptr as HashTable::joinProbe will only @@ -421,16 +422,17 @@ RowVectorPtr HashProbe::getOutput() { return output; } - const bool isSemiOrAntiJoin = - core::isSemiJoin(joinType_) || core::isAntiJoin(joinType_); + const bool isSemiOrAntiJoinNoFilter = + !filter_ && (core::isSemiJoin(joinType_) || core::isAntiJoin(joinType_)); const bool emptyBuildSide = (table_->numDistinct() == 0); // Semi and anti joins are always cardinality reducing, e.g. for a given row - // of input they produce zero or 1 row of output. Therefore, we can process - // each batch of input in one go. - auto outputBatchSize = - (isSemiOrAntiJoin || emptyBuildSide) ? inputSize : outputBatchSize_; + // of input they produce zero or 1 row of output. Therefore, if there is + // no extra filter we can process each batch of input in one go. + auto outputBatchSize = (isSemiOrAntiJoinNoFilter || emptyBuildSide) + ? 
inputSize + : outputBatchSize_; auto mapping = initializeRowNumberMapping(rowNumberMapping_, outputBatchSize, pool()); outputRows_.resize(outputBatchSize); @@ -443,9 +445,9 @@ RowVectorPtr HashProbe::getOutput() { // rows, including ones with null join keys. std::iota(mapping.begin(), mapping.end(), 0); numOut = inputSize; - } else if (isAntiJoin(joinType_)) { - // When build side is not empty, anti join returns probe rows with no - // nulls in the join key and no match in the build side. + } else if (isAntiJoin(joinType_) && !filter_) { + // When build side is not empty, anti join without a filter returns probe + // rows with no nulls in the join key and no match in the build side. for (auto i = 0; i < inputSize; i++) { if (nonNullRows_.isValid(i) && (!activeRows_.isValid(i) || !lookup_->hits[i])) { @@ -456,7 +458,8 @@ RowVectorPtr HashProbe::getOutput() { } else { numOut = table_->listJoinResults( results_, - isLeftJoin(joinType_) || isFullJoin(joinType_), + isLeftJoin(joinType_) || isFullJoin(joinType_) || + isAntiJoin(joinType_), mapping, folly::Range(outputRows_.data(), outputRows_.size())); } @@ -470,7 +473,7 @@ RowVectorPtr HashProbe::getOutput() { numOut = evalFilter(numOut); if (!numOut) { // The filter was false on all rows. 
- if (isSemiOrAntiJoin) { + if (isSemiOrAntiJoinNoFilter) { input_ = nullptr; return nullptr; } @@ -484,7 +487,7 @@ RowVectorPtr HashProbe::getOutput() { fillOutput(numOut); - if (isSemiOrAntiJoin || emptyBuildSide) { + if (isSemiOrAntiJoinNoFilter || emptyBuildSide) { input_ = nullptr; } return output_; @@ -536,14 +539,42 @@ int32_t HashProbe::evalFilter(int32_t numRows) { for (auto i = 0; i < numRows; ++i) { const bool passed = !decodedFilterResult_.isNullAt(i) && decodedFilterResult_.valueAt(i); - leftJoinTracker_.advance(rawMapping[i], passed, addMiss); + noMatchDetector_.advance(rawMapping[i], passed, addMiss); if (passed) { outputRows_[numPassed] = outputRows_[i]; rawMapping[numPassed++] = rawMapping[i]; } } if (results_.atEnd()) { - leftJoinTracker_.finish(addMiss); + noMatchDetector_.finish(addMiss); + } + } else if (isSemiJoin(joinType_)) { + auto addLastMatch = [&](auto row) { + outputRows_[numPassed] = nullptr; + rawMapping[numPassed++] = row; + }; + for (auto i = 0; i < numRows; ++i) { + if (!decodedFilterResult_.isNullAt(i) && + decodedFilterResult_.valueAt(i)) { + semiJoinTracker_.advance(rawMapping[i], addLastMatch); + } + } + if (results_.atEnd()) { + semiJoinTracker_.finish(addLastMatch); + } + } else if (isAntiJoin(joinType_)) { + // Identify probe rows with no matches. 
+ auto addMiss = [&](auto row) { + outputRows_[numPassed] = nullptr; + rawMapping[numPassed++] = row; + }; + for (auto i = 0; i < numRows; ++i) { + const bool passed = !decodedFilterResult_.isNullAt(i) && + decodedFilterResult_.valueAt(i); + noMatchDetector_.advance(rawMapping[i], passed, addMiss); + } + if (results_.atEnd()) { + noMatchDetector_.finish(addMiss); } } else { for (auto i = 0; i < numRows; ++i) { diff --git a/velox/exec/HashProbe.h b/velox/exec/HashProbe.h index 99bbe4ce29b2..dcb5c50883ba 100644 --- a/velox/exec/HashProbe.h +++ b/velox/exec/HashProbe.h @@ -145,7 +145,7 @@ class HashProbe : public Operator { // Tracks probe side rows which had one or more matches on the build side, but // didn't pass the filter. - class LeftJoinTracker { + class NoMatchDetector { public: // Called for each row that the filter was evaluated on. Expects that probe // side rows with multiple matches on the build side are next to each other. @@ -182,15 +182,52 @@ class HashProbe : public Operator { bool currentRowPassed{false}; }; + // For semi join with extra filter, de-duplicates probe side rows with + // multiple matches. + class SemiJoinTracker { + public: + // Called for each row that the filter passes. Expects that probe + // side rows with multiple matches are next to each other. Calls onLastMatch + // just once for each probe side row with at least one match. + template <typename TOnLastMatch> + void advance(vector_size_t row, TOnLastMatch onLastMatch) { + if (currentRow != row) { + if (currentRow != -1) { + onLastMatch(currentRow); + } + currentRow = row; + } + } + + // Called when all rows from the current input batch were processed. Calls + // onLastMatch for the last probe row with at least one match. + template <typename TOnLastMatch> + void finish(TOnLastMatch onLastMatch) { + if (currentRow != -1) { + onLastMatch(currentRow); + } + + currentRow = -1; + } + + private: + // The last row number passed to advance for the current input batch. 
+ vector_size_t currentRow{-1}; + }; + /// True if this is the last HashProbe operator in the pipeline. It is /// responsible for producing non-matching build-side rows for the right join. bool lastRightJoinProbe_{false}; BaseHashTable::NotProbedRowsIterator rightJoinIterator_; - /// For left join, tracks the probe side rows which had matches on the build - /// side but didn't pass the filter. - LeftJoinTracker leftJoinTracker_; + /// For left and anti join with filter, tracks the probe side rows which had + /// matches on the build side but didn't pass the filter. + NoMatchDetector noMatchDetector_; + + /// For semi join with filter, de-duplicates probe side rows with multiple + /// matches. + SemiJoinTracker semiJoinTracker_; // Keeps track of returned results between successive batches of // output for a batch of input. diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 98a25b797a39..e43aa143250d 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -532,6 +532,59 @@ TEST_F(HashJoinTest, semiJoin) { op, "SELECT t.c1 FROM t WHERE t.c0 IN (SELECT c0 FROM u WHERE c0 < 0)"); } +TEST_F(HashJoinTest, semiJoinWithFilter) { + auto leftVectors = makeRowVector( + {"t0", "t1"}, + { + makeFlatVector(1'000, [](auto row) { return row % 11; }), + makeFlatVector(1'000, [](auto row) { return row; }), + }); + + auto rightVectors = makeRowVector( + {"u0", "u1"}, + { + makeFlatVector(1'234, [](auto row) { return row % 5; }), + makeFlatVector(1'234, [](auto row) { return row; }), + }); + + createDuckDbTable("t", {leftVectors}); + createDuckDbTable("u", {rightVectors}); + + auto planNodeIdGenerator = std::make_shared(); + auto op = PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .planNode(), + "", + {"t0", "t1"}, + core::JoinType::kSemi) + .planNode(); + + assertQuery( + op, "SELECT t.* FROM t WHERE EXISTS 
(SELECT u0 FROM u WHERE t0 = u0)"); + + op = PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .planNode(), + "t1 != u1", + {"t0", "t1"}, + core::JoinType::kSemi) + .planNode(); + + assertQuery( + op, + "SELECT t.* FROM t WHERE EXISTS (SELECT u0, u1 FROM u WHERE t0 = u0 AND t1 <> u1)"); +} + TEST_F(HashJoinTest, antiJoin) { auto leftVectors = makeRowVector({ makeFlatVector( @@ -604,6 +657,59 @@ TEST_F(HashJoinTest, antiJoin) { assertQuery(op, "SELECT t.c1 FROM t WHERE t.c0 NOT IN (SELECT c0 FROM u)"); } +TEST_F(HashJoinTest, antiJoinWithFilter) { + auto leftVectors = makeRowVector( + {"t0", "t1"}, + { + makeFlatVector(1'000, [](auto row) { return row % 11; }), + makeFlatVector(1'000, [](auto row) { return row; }), + }); + + auto rightVectors = makeRowVector( + {"u0", "u1"}, + { + makeFlatVector(1'234, [](auto row) { return row % 5; }), + makeFlatVector(1'234, [](auto row) { return row; }), + }); + + createDuckDbTable("t", {leftVectors}); + createDuckDbTable("u", {rightVectors}); + + auto planNodeIdGenerator = std::make_shared(); + auto op = PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .planNode(), + "", + {"t0", "t1"}, + core::JoinType::kAnti) + .planNode(); + + assertQuery( + op, "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0)"); + + op = PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .planNode(), + "t1 != u1", + {"t0", "t1"}, + core::JoinType::kAnti) + .planNode(); + + assertQuery( + op, + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND t1 <> u1)"); +} + TEST_F(HashJoinTest, dynamicFilters) { const int32_t numSplits = 20; const int32_t numRowsProbe = 1024;