Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for anti join with filter #1728

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions velox/exec/HashProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ HashProbe::HashProbe(
joinType_{joinNode->joinType()},
filterResult_(1),
outputRows_(outputBatchSize_) {
if (joinNode->isAntiJoin() && joinNode->filter()) {
VELOX_UNSUPPORTED("Anti join with a filter not supported yet.");
}

auto probeType = joinNode->sources()[0]->outputType();
auto numKeys = joinNode->leftKeys().size();
keyChannels_.reserve(numKeys);
Expand Down Expand Up @@ -264,7 +260,8 @@ void HashProbe::addInput(RowVectorPtr input) {
return;
}
passingInputRowsInitialized_ = false;
if (isLeftJoin(joinType_) || isFullJoin(joinType_)) {
if (isLeftJoin(joinType_) || isFullJoin(joinType_) ||
(isAntiJoin(joinType_) && filter_)) {
// Make sure to allocate an entry in 'hits' for every input row to allow for
// including rows without a match in the output. Also, make sure to
// initialize all 'hits' to nullptr as HashTable::joinProbe will only
Expand Down Expand Up @@ -448,9 +445,9 @@ RowVectorPtr HashProbe::getOutput() {
// rows, including ones with null join keys.
std::iota(mapping.begin(), mapping.end(), 0);
numOut = inputSize;
} else if (isAntiJoin(joinType_)) {
// When build side is not empty, anti join returns probe rows with no
// nulls in the join key and no match in the build side.
} else if (isAntiJoin(joinType_) && !filter_) {
// When build side is not empty, anti join without a filter returns probe
// rows with no nulls in the join key and no match in the build side.
for (auto i = 0; i < inputSize; i++) {
if (nonNullRows_.isValid(i) &&
(!activeRows_.isValid(i) || !lookup_->hits[i])) {
Expand All @@ -461,7 +458,8 @@ RowVectorPtr HashProbe::getOutput() {
} else {
numOut = table_->listJoinResults(
results_,
isLeftJoin(joinType_) || isFullJoin(joinType_),
isLeftJoin(joinType_) || isFullJoin(joinType_) ||
isAntiJoin(joinType_),
mapping,
folly::Range(outputRows_.data(), outputRows_.size()));
}
Expand Down Expand Up @@ -541,14 +539,14 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
for (auto i = 0; i < numRows; ++i) {
const bool passed = !decodedFilterResult_.isNullAt(i) &&
decodedFilterResult_.valueAt<bool>(i);
leftJoinTracker_.advance(rawMapping[i], passed, addMiss);
noMatchDetector_.advance(rawMapping[i], passed, addMiss);
if (passed) {
outputRows_[numPassed] = outputRows_[i];
rawMapping[numPassed++] = rawMapping[i];
}
}
if (results_.atEnd()) {
leftJoinTracker_.finish(addMiss);
noMatchDetector_.finish(addMiss);
}
} else if (isSemiJoin(joinType_)) {
auto addLastMatch = [&](auto row) {
Expand All @@ -564,6 +562,20 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
if (results_.atEnd()) {
semiJoinTracker_.finish(addLastMatch);
}
} else if (isAntiJoin(joinType_)) {
// Identify probe rows with no matches.
auto addMiss = [&](auto row) {
outputRows_[numPassed] = nullptr;
rawMapping[numPassed++] = row;
};
for (auto i = 0; i < numRows; ++i) {
const bool passed = !decodedFilterResult_.isNullAt(i) &&
decodedFilterResult_.valueAt<bool>(i);
noMatchDetector_.advance(rawMapping[i], passed, addMiss);
}
if (results_.atEnd()) {
noMatchDetector_.finish(addMiss);
}
} else {
for (auto i = 0; i < numRows; ++i) {
if (!decodedFilterResult_.isNullAt(i) &&
Expand Down
8 changes: 4 additions & 4 deletions velox/exec/HashProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class HashProbe : public Operator {

// Tracks probe side rows which had one or more matches on the build side, but
// didn't pass the filter.
class LeftJoinTracker {
class NoMatchDetector {
public:
// Called for each row that the filter was evaluated on. Expects that probe
// side rows with multiple matches on the build side are next to each other.
Expand Down Expand Up @@ -221,9 +221,9 @@ class HashProbe : public Operator {

BaseHashTable::NotProbedRowsIterator rightJoinIterator_;

/// For left join, tracks the probe side rows which had matches on the build
/// side but didn't pass the filter.
LeftJoinTracker leftJoinTracker_;
/// For left and anti join with filter, tracks the probe side rows which had
/// matches on the build side but didn't pass the filter.
NoMatchDetector noMatchDetector_;

/// For semi join with filter, de-duplicates probe side rows with multiple
/// matches.
Expand Down
53 changes: 53 additions & 0 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,59 @@ TEST_F(HashJoinTest, antiJoin) {
assertQuery(op, "SELECT t.c1 FROM t WHERE t.c0 NOT IN (SELECT c0 FROM u)");
}

TEST_F(HashJoinTest, antiJoinWithFilter) {
auto leftVectors = makeRowVector(
{"t0", "t1"},
{
makeFlatVector<int32_t>(1'000, [](auto row) { return row % 11; }),
makeFlatVector<int32_t>(1'000, [](auto row) { return row; }),
});

auto rightVectors = makeRowVector(
{"u0", "u1"},
{
makeFlatVector<int32_t>(1'234, [](auto row) { return row % 5; }),
makeFlatVector<int32_t>(1'234, [](auto row) { return row; }),
});

createDuckDbTable("t", {leftVectors});
createDuckDbTable("u", {rightVectors});

auto planNodeIdGenerator = std::make_shared<PlanNodeIdGenerator>();
auto op = PlanBuilder(planNodeIdGenerator)
.values({leftVectors})
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values({rightVectors})
.planNode(),
"",
{"t0", "t1"},
core::JoinType::kAnti)
.planNode();

assertQuery(
op, "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0)");

op = PlanBuilder(planNodeIdGenerator)
.values({leftVectors})
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values({rightVectors})
.planNode(),
"t1 != u1",
{"t0", "t1"},
core::JoinType::kAnti)
.planNode();

assertQuery(
op,
"SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND t1 <> u1)");
}

TEST_F(HashJoinTest, dynamicFilters) {
const int32_t numSplits = 20;
const int32_t numRowsProbe = 1024;
Expand Down