Skip to content

Commit

Permalink
Fix semi join with extra filter
Browse files Browse the repository at this point in the history
  • Loading branch information
mbasmanova committed May 26, 2022
1 parent abc44dc commit 3bb3997
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 13 deletions.
10 changes: 5 additions & 5 deletions velox/exec/HashBuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@ HashBuild::HashBuild(
true, // hasProbedFlag
mappedMemory_);
} else {
// Semi and anti join only needs to know whether there is a match. Hence, no
// need to store entries with duplicate keys.
const bool allowDuplicates =
!joinNode->isSemiJoin() && !joinNode->isAntiJoin();
// Semi and anti join with no extra filter only needs to know whether there
// is a match. Hence, no need to store entries with duplicate keys.
const bool dropDuplicates = !joinNode->filter() &&
(joinNode->isSemiJoin() || joinNode->isAntiJoin());

table_ = HashTable<true>::createForJoin(
std::move(keyHashers),
dependentTypes,
allowDuplicates,
!dropDuplicates, // allowDuplicates
false, // hasProbedFlag
mappedMemory_);
}
Expand Down
36 changes: 28 additions & 8 deletions velox/exec/HashProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "velox/exec/HashProbe.h"
#include <iostream>
#include "velox/exec/OperatorUtils.h"
#include "velox/exec/Task.h"
#include "velox/expression/ControlExpr.h"
Expand Down Expand Up @@ -64,6 +65,10 @@ HashProbe::HashProbe(
joinType_{joinNode->joinType()},
filterResult_(1),
outputRows_(outputBatchSize_) {
if (joinNode->isAntiJoin() && joinNode->filter()) {
VELOX_UNSUPPORTED("Anti join with a filter not supported yet.");
}

auto probeType = joinNode->sources()[0]->outputType();
auto numKeys = joinNode->leftKeys().size();
keyChannels_.reserve(numKeys);
Expand Down Expand Up @@ -421,16 +426,17 @@ RowVectorPtr HashProbe::getOutput() {
return output;
}

const bool isSemiOrAntiJoin =
core::isSemiJoin(joinType_) || core::isAntiJoin(joinType_);
const bool isSemiOrAntiJoinNoFilter =
!filter_ && (core::isSemiJoin(joinType_) || core::isAntiJoin(joinType_));

const bool emptyBuildSide = (table_->numDistinct() == 0);

// Semi and anti joins are always cardinality reducing, e.g. for a given row
// of input they produce zero or 1 row of output. Therefore, we can process
// each batch of input in one go.
auto outputBatchSize =
(isSemiOrAntiJoin || emptyBuildSide) ? inputSize : outputBatchSize_;
// of input they produce zero or 1 row of output. Therefore, if there is
// no extra filter we can process each batch of input in one go.
auto outputBatchSize = (isSemiOrAntiJoinNoFilter || emptyBuildSide)
? inputSize
: outputBatchSize_;
auto mapping =
initializeRowNumberMapping(rowNumberMapping_, outputBatchSize, pool());
outputRows_.resize(outputBatchSize);
Expand Down Expand Up @@ -470,7 +476,7 @@ RowVectorPtr HashProbe::getOutput() {
numOut = evalFilter(numOut);
if (!numOut) {
// The filter was false on all rows.
if (isSemiOrAntiJoin) {
if (isSemiOrAntiJoinNoFilter) {
input_ = nullptr;
return nullptr;
}
Expand All @@ -484,7 +490,7 @@ RowVectorPtr HashProbe::getOutput() {

fillOutput(numOut);

if (isSemiOrAntiJoin || emptyBuildSide) {
if (isSemiOrAntiJoinNoFilter || emptyBuildSide) {
input_ = nullptr;
}
return output_;
Expand Down Expand Up @@ -545,6 +551,20 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
if (results_.atEnd()) {
leftJoinTracker_.finish(addMiss);
}
} else if (isSemiJoin(joinType_)) {
auto addLastMatch = [&](auto row) {
outputRows_[numPassed] = nullptr;
rawMapping[numPassed++] = row;
};
for (auto i = 0; i < numRows; ++i) {
if (!decodedFilterResult_.isNullAt(i) &&
decodedFilterResult_.valueAt<bool>(i)) {
semiJoinTracker_.advance(rawMapping[i], addLastMatch);
}
}
if (results_.atEnd()) {
semiJoinTracker_.finish(addLastMatch);
}
} else {
for (auto i = 0; i < numRows; ++i) {
if (!decodedFilterResult_.isNullAt(i) &&
Expand Down
37 changes: 37 additions & 0 deletions velox/exec/HashProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,39 @@ class HashProbe : public Operator {
bool currentRowPassed{false};
};

// For semi join with extra filter, de-duplicates probe side rows with
// multiple matches.
class SemiJoinTracker {
public:
// Called for each row that the filter passes. Expects that probe
// side rows with multiple matches are next to each other. Calls onLastMatch
// just once for each probe side row with at least one match.
template <typename TOnLastMatch>
void advance(vector_size_t row, TOnLastMatch onLastMatch) {
if (currentRow != row) {
if (currentRow != -1) {
onLastMatch(currentRow);
}
currentRow = row;
}
}

// Called when all rows from the current input batch were processed. Calls
// onLastMatch for the last probe row with at least one match.
template <typename TOnLastMatch>
void finish(TOnLastMatch onLastMatch) {
if (currentRow != -1) {
onLastMatch(currentRow);
}

currentRow = -1;
}

private:
// The last row number passed to advance for the current input batch.
vector_size_t currentRow{-1};
};

/// True if this is the last HashProbe operator in the pipeline. It is
/// responsible for producing non-matching build-side rows for the right join.
bool lastRightJoinProbe_{false};
Expand All @@ -192,6 +225,10 @@ class HashProbe : public Operator {
/// side but didn't pass the filter.
LeftJoinTracker leftJoinTracker_;

/// For semi join with filter, de-duplicates probe side rows with multiple
/// matches.
SemiJoinTracker semiJoinTracker_;

// Keeps track of returned results between successive batches of
// output for a batch of input.
BaseHashTable::JoinResultIterator results_;
Expand Down
53 changes: 53 additions & 0 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,59 @@ TEST_F(HashJoinTest, semiJoin) {
op, "SELECT t.c1 FROM t WHERE t.c0 IN (SELECT c0 FROM u WHERE c0 < 0)");
}

TEST_F(HashJoinTest, semiJoinWithFilter) {
auto leftVectors = makeRowVector(
{"t0", "t1"},
{
makeFlatVector<int32_t>(1'000, [](auto row) { return row % 11; }),
makeFlatVector<int32_t>(1'000, [](auto row) { return row; }),
});

auto rightVectors = makeRowVector(
{"u0", "u1"},
{
makeFlatVector<int32_t>(1'234, [](auto row) { return row % 5; }),
makeFlatVector<int32_t>(1'234, [](auto row) { return row; }),
});

createDuckDbTable("t", {leftVectors});
createDuckDbTable("u", {rightVectors});

auto planNodeIdGenerator = std::make_shared<PlanNodeIdGenerator>();
auto op = PlanBuilder(planNodeIdGenerator)
.values({leftVectors})
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values({rightVectors})
.planNode(),
"",
{"t0", "t1"},
core::JoinType::kSemi)
.planNode();

assertQuery(
op, "SELECT t.* FROM t WHERE EXISTS (SELECT c0 FROM u WHERE t0 = u0)");

op = PlanBuilder(planNodeIdGenerator)
.values({leftVectors})
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values({rightVectors})
.planNode(),
"t1 != u1",
{"t0", "t1"},
core::JoinType::kSemi)
.planNode();

assertQuery(
op,
"SELECT t.* FROM t WHERE EXISTS (SELECT c0, c1 FROM u WHERE t0 = u0 AND t1 <> u1)");
}

TEST_F(HashJoinTest, antiJoin) {
auto leftVectors = makeRowVector({
makeFlatVector<int32_t>(
Expand Down

0 comments on commit 3bb3997

Please sign in to comment.