Skip to content

Commit

Permalink
[OPPRO-115] Fix semi join and anti join with filter (facebookincubato…
Browse files Browse the repository at this point in the history
…r#15)

* Fix semi join with extra filter (facebookincubator#1726)

Summary:
CC: rui-mo

Pull Request resolved: facebookincubator#1726

Reviewed By: oerling

Differential Revision: D36703403

Pulled By: mbasmanova

fbshipit-source-id: a39788e451b544a8830c328950dd87a8bf600847

* Add support for anti join with filter (facebookincubator#1728)

Summary: Pull Request resolved: facebookincubator#1728

Reviewed By: oerling

Differential Revision: D36706888

Pulled By: mbasmanova

fbshipit-source-id: b7d57c9ef42ee674ce26600ddbac93f45a0a4dcc

Co-authored-by: Masha Basmanova <mbasmanova@fb.com>
  • Loading branch information
rui-mo and mbasmanova authored May 28, 2022
1 parent 9bb71bf commit 98ae47b
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 24 deletions.
10 changes: 5 additions & 5 deletions velox/exec/HashBuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@ HashBuild::HashBuild(
true, // hasProbedFlag
mappedMemory_);
} else {
// Semi and anti join only needs to know whether there is a match. Hence, no
// need to store entries with duplicate keys.
const bool allowDuplicates =
!joinNode->isSemiJoin() && !joinNode->isAntiJoin();
// Semi and anti join with no extra filter only needs to know whether there
// is a match. Hence, no need to store entries with duplicate keys.
const bool dropDuplicates = !joinNode->filter() &&
(joinNode->isSemiJoin() || joinNode->isAntiJoin());

table_ = HashTable<true>::createForJoin(
std::move(keyHashers),
dependentTypes,
allowDuplicates,
!dropDuplicates, // allowDuplicates
false, // hasProbedFlag
mappedMemory_);
}
Expand Down
61 changes: 46 additions & 15 deletions velox/exec/HashProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ void HashProbe::addInput(RowVectorPtr input) {
return;
}
passingInputRowsInitialized_ = false;
if (isLeftJoin(joinType_) || isFullJoin(joinType_)) {
if (isLeftJoin(joinType_) || isFullJoin(joinType_) ||
(isAntiJoin(joinType_) && filter_)) {
// Make sure to allocate an entry in 'hits' for every input row to allow for
// including rows without a match in the output. Also, make sure to
// initialize all 'hits' to nullptr as HashTable::joinProbe will only
Expand Down Expand Up @@ -421,16 +422,17 @@ RowVectorPtr HashProbe::getOutput() {
return output;
}

const bool isSemiOrAntiJoin =
core::isSemiJoin(joinType_) || core::isAntiJoin(joinType_);
const bool isSemiOrAntiJoinNoFilter =
!filter_ && (core::isSemiJoin(joinType_) || core::isAntiJoin(joinType_));

const bool emptyBuildSide = (table_->numDistinct() == 0);

// Semi and anti joins are always cardinality reducing, e.g. for a given row
// of input they produce zero or 1 row of output. Therefore, we can process
// each batch of input in one go.
auto outputBatchSize =
(isSemiOrAntiJoin || emptyBuildSide) ? inputSize : outputBatchSize_;
// of input they produce zero or 1 row of output. Therefore, if there is
// no extra filter we can process each batch of input in one go.
auto outputBatchSize = (isSemiOrAntiJoinNoFilter || emptyBuildSide)
? inputSize
: outputBatchSize_;
auto mapping =
initializeRowNumberMapping(rowNumberMapping_, outputBatchSize, pool());
outputRows_.resize(outputBatchSize);
Expand All @@ -443,9 +445,9 @@ RowVectorPtr HashProbe::getOutput() {
// rows, including ones with null join keys.
std::iota(mapping.begin(), mapping.end(), 0);
numOut = inputSize;
} else if (isAntiJoin(joinType_)) {
// When build side is not empty, anti join returns probe rows with no
// nulls in the join key and no match in the build side.
} else if (isAntiJoin(joinType_) && !filter_) {
// When build side is not empty, anti join without a filter returns probe
// rows with no nulls in the join key and no match in the build side.
for (auto i = 0; i < inputSize; i++) {
if (nonNullRows_.isValid(i) &&
(!activeRows_.isValid(i) || !lookup_->hits[i])) {
Expand All @@ -456,7 +458,8 @@ RowVectorPtr HashProbe::getOutput() {
} else {
numOut = table_->listJoinResults(
results_,
isLeftJoin(joinType_) || isFullJoin(joinType_),
isLeftJoin(joinType_) || isFullJoin(joinType_) ||
isAntiJoin(joinType_),
mapping,
folly::Range(outputRows_.data(), outputRows_.size()));
}
Expand All @@ -470,7 +473,7 @@ RowVectorPtr HashProbe::getOutput() {
numOut = evalFilter(numOut);
if (!numOut) {
// The filter was false on all rows.
if (isSemiOrAntiJoin) {
if (isSemiOrAntiJoinNoFilter) {
input_ = nullptr;
return nullptr;
}
Expand All @@ -484,7 +487,7 @@ RowVectorPtr HashProbe::getOutput() {

fillOutput(numOut);

if (isSemiOrAntiJoin || emptyBuildSide) {
if (isSemiOrAntiJoinNoFilter || emptyBuildSide) {
input_ = nullptr;
}
return output_;
Expand Down Expand Up @@ -536,14 +539,42 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
for (auto i = 0; i < numRows; ++i) {
const bool passed = !decodedFilterResult_.isNullAt(i) &&
decodedFilterResult_.valueAt<bool>(i);
leftJoinTracker_.advance(rawMapping[i], passed, addMiss);
noMatchDetector_.advance(rawMapping[i], passed, addMiss);
if (passed) {
outputRows_[numPassed] = outputRows_[i];
rawMapping[numPassed++] = rawMapping[i];
}
}
if (results_.atEnd()) {
leftJoinTracker_.finish(addMiss);
noMatchDetector_.finish(addMiss);
}
} else if (isSemiJoin(joinType_)) {
auto addLastMatch = [&](auto row) {
outputRows_[numPassed] = nullptr;
rawMapping[numPassed++] = row;
};
for (auto i = 0; i < numRows; ++i) {
if (!decodedFilterResult_.isNullAt(i) &&
decodedFilterResult_.valueAt<bool>(i)) {
semiJoinTracker_.advance(rawMapping[i], addLastMatch);
}
}
if (results_.atEnd()) {
semiJoinTracker_.finish(addLastMatch);
}
} else if (isAntiJoin(joinType_)) {
// Identify probe rows with no matches.
auto addMiss = [&](auto row) {
outputRows_[numPassed] = nullptr;
rawMapping[numPassed++] = row;
};
for (auto i = 0; i < numRows; ++i) {
const bool passed = !decodedFilterResult_.isNullAt(i) &&
decodedFilterResult_.valueAt<bool>(i);
noMatchDetector_.advance(rawMapping[i], passed, addMiss);
}
if (results_.atEnd()) {
noMatchDetector_.finish(addMiss);
}
} else {
for (auto i = 0; i < numRows; ++i) {
Expand Down
45 changes: 41 additions & 4 deletions velox/exec/HashProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class HashProbe : public Operator {

// Tracks probe side rows which had one or more matches on the build side, but
// didn't pass the filter.
class LeftJoinTracker {
class NoMatchDetector {
public:
// Called for each row that the filter was evaluated on. Expects that probe
// side rows with multiple matches on the build side are next to each other.
Expand Down Expand Up @@ -182,15 +182,52 @@ class HashProbe : public Operator {
bool currentRowPassed{false};
};

// For semi join with extra filter, de-duplicates probe side rows with
// multiple matches.
class SemiJoinTracker {
public:
// Called for each row that the filter passes. Expects that probe
// side rows with multiple matches are next to each other. Calls onLastMatch
// just once for each probe side row with at least one match.
template <typename TOnLastMatch>
void advance(vector_size_t row, TOnLastMatch onLastMatch) {
if (currentRow != row) {
if (currentRow != -1) {
onLastMatch(currentRow);
}
currentRow = row;
}
}

// Called when all rows from the current input batch were processed. Calls
// onLastMatch for the last probe row with at least one match.
template <typename TOnLastMatch>
void finish(TOnLastMatch onLastMatch) {
if (currentRow != -1) {
onLastMatch(currentRow);
}

currentRow = -1;
}

private:
// The last row number passed to advance for the current input batch.
vector_size_t currentRow{-1};
};

/// True if this is the last HashProbe operator in the pipeline. It is
/// responsible for producing non-matching build-side rows for the right join.
bool lastRightJoinProbe_{false};

BaseHashTable::NotProbedRowsIterator rightJoinIterator_;

/// For left join, tracks the probe side rows which had matches on the build
/// side but didn't pass the filter.
LeftJoinTracker leftJoinTracker_;
/// For left and anti join with filter, tracks the probe side rows which had
/// matches on the build side but didn't pass the filter.
NoMatchDetector noMatchDetector_;

/// For semi join with filter, de-duplicates probe side rows with multiple
/// matches.
SemiJoinTracker semiJoinTracker_;

// Keeps track of returned results between successive batches of
// output for a batch of input.
Expand Down
106 changes: 106 additions & 0 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,59 @@ TEST_F(HashJoinTest, semiJoin) {
op, "SELECT t.c1 FROM t WHERE t.c0 IN (SELECT c0 FROM u WHERE c0 < 0)");
}

TEST_F(HashJoinTest, semiJoinWithFilter) {
auto leftVectors = makeRowVector(
{"t0", "t1"},
{
makeFlatVector<int32_t>(1'000, [](auto row) { return row % 11; }),
makeFlatVector<int32_t>(1'000, [](auto row) { return row; }),
});

auto rightVectors = makeRowVector(
{"u0", "u1"},
{
makeFlatVector<int32_t>(1'234, [](auto row) { return row % 5; }),
makeFlatVector<int32_t>(1'234, [](auto row) { return row; }),
});

createDuckDbTable("t", {leftVectors});
createDuckDbTable("u", {rightVectors});

auto planNodeIdGenerator = std::make_shared<PlanNodeIdGenerator>();
auto op = PlanBuilder(planNodeIdGenerator)
.values({leftVectors})
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values({rightVectors})
.planNode(),
"",
{"t0", "t1"},
core::JoinType::kSemi)
.planNode();

assertQuery(
op, "SELECT t.* FROM t WHERE EXISTS (SELECT u0 FROM u WHERE t0 = u0)");

op = PlanBuilder(planNodeIdGenerator)
.values({leftVectors})
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values({rightVectors})
.planNode(),
"t1 != u1",
{"t0", "t1"},
core::JoinType::kSemi)
.planNode();

assertQuery(
op,
"SELECT t.* FROM t WHERE EXISTS (SELECT u0, u1 FROM u WHERE t0 = u0 AND t1 <> u1)");
}

TEST_F(HashJoinTest, antiJoin) {
auto leftVectors = makeRowVector({
makeFlatVector<int32_t>(
Expand Down Expand Up @@ -604,6 +657,59 @@ TEST_F(HashJoinTest, antiJoin) {
assertQuery(op, "SELECT t.c1 FROM t WHERE t.c0 NOT IN (SELECT c0 FROM u)");
}

TEST_F(HashJoinTest, antiJoinWithFilter) {
auto leftVectors = makeRowVector(
{"t0", "t1"},
{
makeFlatVector<int32_t>(1'000, [](auto row) { return row % 11; }),
makeFlatVector<int32_t>(1'000, [](auto row) { return row; }),
});

auto rightVectors = makeRowVector(
{"u0", "u1"},
{
makeFlatVector<int32_t>(1'234, [](auto row) { return row % 5; }),
makeFlatVector<int32_t>(1'234, [](auto row) { return row; }),
});

createDuckDbTable("t", {leftVectors});
createDuckDbTable("u", {rightVectors});

auto planNodeIdGenerator = std::make_shared<PlanNodeIdGenerator>();
auto op = PlanBuilder(planNodeIdGenerator)
.values({leftVectors})
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values({rightVectors})
.planNode(),
"",
{"t0", "t1"},
core::JoinType::kAnti)
.planNode();

assertQuery(
op, "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0)");

op = PlanBuilder(planNodeIdGenerator)
.values({leftVectors})
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values({rightVectors})
.planNode(),
"t1 != u1",
{"t0", "t1"},
core::JoinType::kAnti)
.planNode();

assertQuery(
op,
"SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND t1 <> u1)");
}

TEST_F(HashJoinTest, dynamicFilters) {
const int32_t numSplits = 20;
const int32_t numRowsProbe = 1024;
Expand Down

0 comments on commit 98ae47b

Please sign in to comment.