From 3bb399757151bb7026f895d9a1d784705449dd3f Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Thu, 26 May 2022 10:05:34 -0400 Subject: [PATCH] Fix semi join with extra filter --- velox/exec/HashBuild.cpp | 10 +++--- velox/exec/HashProbe.cpp | 36 ++++++++++++++++----- velox/exec/HashProbe.h | 37 +++++++++++++++++++++ velox/exec/tests/HashJoinTest.cpp | 53 +++++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 13 deletions(-) diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index 81c9956bdb61..fa71304f63bb 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -107,15 +107,15 @@ HashBuild::HashBuild( true, // hasProbedFlag mappedMemory_); } else { - // Semi and anti join only needs to know whether there is a match. Hence, no - // need to store entries with duplicate keys. - const bool allowDuplicates = - !joinNode->isSemiJoin() && !joinNode->isAntiJoin(); + // Semi and anti join with no extra filter only needs to know whether there + // is a match. Hence, no need to store entries with duplicate keys. 
+ const bool dropDuplicates = !joinNode->filter() && + (joinNode->isSemiJoin() || joinNode->isAntiJoin()); table_ = HashTable::createForJoin( std::move(keyHashers), dependentTypes, - allowDuplicates, + !dropDuplicates, // allowDuplicates false, // hasProbedFlag mappedMemory_); } diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index 7b4c7d3f0e39..900e8fe009f8 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/HashProbe.h" +#include #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/ControlExpr.h" @@ -64,6 +65,10 @@ HashProbe::HashProbe( joinType_{joinNode->joinType()}, filterResult_(1), outputRows_(outputBatchSize_) { + if (joinNode->isAntiJoin() && joinNode->filter()) { + VELOX_UNSUPPORTED("Anti join with a filter not supported yet."); + } + auto probeType = joinNode->sources()[0]->outputType(); auto numKeys = joinNode->leftKeys().size(); keyChannels_.reserve(numKeys); @@ -421,16 +426,17 @@ RowVectorPtr HashProbe::getOutput() { return output; } - const bool isSemiOrAntiJoin = - core::isSemiJoin(joinType_) || core::isAntiJoin(joinType_); + const bool isSemiOrAntiJoinNoFilter = + !filter_ && (core::isSemiJoin(joinType_) || core::isAntiJoin(joinType_)); const bool emptyBuildSide = (table_->numDistinct() == 0); // Semi and anti joins are always cardinality reducing, e.g. for a given row - // of input they produce zero or 1 row of output. Therefore, we can process - // each batch of input in one go. - auto outputBatchSize = - (isSemiOrAntiJoin || emptyBuildSide) ? inputSize : outputBatchSize_; + // of input they produce zero or 1 row of output. Therefore, if there is + // no extra filter we can process each batch of input in one go. + auto outputBatchSize = (isSemiOrAntiJoinNoFilter || emptyBuildSide) + ? 
inputSize + : outputBatchSize_; auto mapping = initializeRowNumberMapping(rowNumberMapping_, outputBatchSize, pool()); outputRows_.resize(outputBatchSize); @@ -470,7 +476,7 @@ RowVectorPtr HashProbe::getOutput() { numOut = evalFilter(numOut); if (!numOut) { // The filter was false on all rows. - if (isSemiOrAntiJoin) { + if (isSemiOrAntiJoinNoFilter) { input_ = nullptr; return nullptr; } @@ -484,7 +490,7 @@ RowVectorPtr HashProbe::getOutput() { fillOutput(numOut); - if (isSemiOrAntiJoin || emptyBuildSide) { + if (isSemiOrAntiJoinNoFilter || emptyBuildSide) { input_ = nullptr; } return output_; @@ -545,6 +551,20 @@ int32_t HashProbe::evalFilter(int32_t numRows) { if (results_.atEnd()) { leftJoinTracker_.finish(addMiss); } + } else if (isSemiJoin(joinType_)) { + auto addLastMatch = [&](auto row) { + outputRows_[numPassed] = nullptr; + rawMapping[numPassed++] = row; + }; + for (auto i = 0; i < numRows; ++i) { + if (!decodedFilterResult_.isNullAt(i) && + decodedFilterResult_.valueAt(i)) { + semiJoinTracker_.advance(rawMapping[i], addLastMatch); + } + } + if (results_.atEnd()) { + semiJoinTracker_.finish(addLastMatch); + } } else { for (auto i = 0; i < numRows; ++i) { if (!decodedFilterResult_.isNullAt(i) && diff --git a/velox/exec/HashProbe.h b/velox/exec/HashProbe.h index 99bbe4ce29b2..6fb8e42c89e4 100644 --- a/velox/exec/HashProbe.h +++ b/velox/exec/HashProbe.h @@ -182,6 +182,39 @@ class HashProbe : public Operator { bool currentRowPassed{false}; }; + // For semi join with extra filter, de-duplicates probe side rows with + // multiple matches. + class SemiJoinTracker { + public: + // Called for each row that the filter passes. Expects that probe + // side rows with multiple matches are next to each other. Calls onLastMatch + // just once for each probe side row with at least one match. 
+ template <typename TOnLastMatch> + void advance(vector_size_t row, TOnLastMatch onLastMatch) { + if (currentRow != row) { + if (currentRow != -1) { + onLastMatch(currentRow); + } + currentRow = row; + } + } + + // Called when all rows from the current input batch were processed. Calls + // onLastMatch for the last probe row with at least one match. + template <typename TOnLastMatch> + void finish(TOnLastMatch onLastMatch) { + if (currentRow != -1) { + onLastMatch(currentRow); + } + + currentRow = -1; + } + + private: + // The last row number passed to advance for the current input batch. + vector_size_t currentRow{-1}; + }; + /// True if this is the last HashProbe operator in the pipeline. It is /// responsible for producing non-matching build-side rows for the right join. bool lastRightJoinProbe_{false}; @@ -192,6 +225,10 @@ class HashProbe : public Operator { /// side but didn't pass the filter. LeftJoinTracker leftJoinTracker_; + /// For semi join with filter, de-duplicates probe side rows with multiple + /// matches. + SemiJoinTracker semiJoinTracker_; + // Keeps track of returned results between successive batches of // output for a batch of input. 
BaseHashTable::JoinResultIterator results_; diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 7b38c528baa3..702d281582b5 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -534,6 +534,59 @@ TEST_F(HashJoinTest, semiJoin) { op, "SELECT t.c1 FROM t WHERE t.c0 IN (SELECT c0 FROM u WHERE c0 < 0)"); } +TEST_F(HashJoinTest, semiJoinWithFilter) { + auto leftVectors = makeRowVector( + {"t0", "t1"}, + { + makeFlatVector(1'000, [](auto row) { return row % 11; }), + makeFlatVector(1'000, [](auto row) { return row; }), + }); + + auto rightVectors = makeRowVector( + {"u0", "u1"}, + { + makeFlatVector(1'234, [](auto row) { return row % 5; }), + makeFlatVector(1'234, [](auto row) { return row; }), + }); + + createDuckDbTable("t", {leftVectors}); + createDuckDbTable("u", {rightVectors}); + + auto planNodeIdGenerator = std::make_shared(); + auto op = PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .planNode(), + "", + {"t0", "t1"}, + core::JoinType::kSemi) + .planNode(); + + assertQuery( + op, "SELECT t.* FROM t WHERE EXISTS (SELECT c0 FROM u WHERE t0 = u0)"); + + op = PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .planNode(), + "t1 != u1", + {"t0", "t1"}, + core::JoinType::kSemi) + .planNode(); + + assertQuery( + op, + "SELECT t.* FROM t WHERE EXISTS (SELECT c0, c1 FROM u WHERE t0 = u0 AND t1 <> u1)"); +} + TEST_F(HashJoinTest, antiJoin) { auto leftVectors = makeRowVector({ makeFlatVector(