Skip to content

Commit

Permalink
Enable right join in smj (#10148)
Browse files Browse the repository at this point in the history
Summary:
The semantics of the right join are similar to  the left join, so we referenced the implementation of the left join to achieve the implementation of the right join.

Pull Request resolved: #10148

Reviewed By: bikramSingh91

Differential Revision: D59176120

Pulled By: pedroerp

fbshipit-source-id: 95184725dfa5fea9317c822d7761507bc49fca9b
  • Loading branch information
JkSelf authored and facebook-github-bot committed Jun 29, 2024
1 parent c54e59d commit 0ef0ac8
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 60 deletions.
177 changes: 143 additions & 34 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@

namespace facebook::velox::exec {

namespace {
bool supportsMergeJoin(std::shared_ptr<const core::MergeJoinNode> joinNode) {
return joinNode->isInnerJoin() || joinNode->isLeftJoin() ||
joinNode->isLeftSemiFilterJoin() || joinNode->isRightSemiFilterJoin() ||
joinNode->isAntiJoin() || joinNode->isRightJoin();
}
} // namespace
MergeJoin::MergeJoin(
int32_t operatorId,
DriverCtx* driverCtx,
Expand All @@ -35,10 +42,9 @@ MergeJoin::MergeJoin(
numKeys_{joinNode->leftKeys().size()},
joinNode_(joinNode) {
VELOX_USER_CHECK(
joinNode_->isInnerJoin() || joinNode_->isLeftJoin() ||
joinNode_->isLeftSemiFilterJoin() ||
joinNode_->isRightSemiFilterJoin() || joinNode_->isAntiJoin(),
"Merge join supports only inner, left and left semi joins. Other join types are not supported yet.");
supportsMergeJoin(joinNode_),
"The join type is not supported by merge join: ",
joinTypeName(joinNode_->joinType()));
}

void MergeJoin::initialize() {
Expand Down Expand Up @@ -89,13 +95,14 @@ void MergeJoin::initialize() {
if (joinNode_->filter()) {
initializeFilter(joinNode_->filter(), leftType, rightType);

if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin()) {
leftJoinTracker_ = LeftJoinTracker(outputBatchSize_, pool());
if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() ||
joinNode_->isRightJoin()) {
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}
} else if (joinNode_->isAntiJoin()) {
// Anti join needs to track the left side rows that have no match on the
// right.
leftJoinTracker_ = LeftJoinTracker(outputBatchSize_, pool());
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}

joinNode_.reset();
Expand Down Expand Up @@ -183,15 +190,18 @@ BlockingReason MergeJoin::isBlocked(ContinueFuture* future) {
}

bool MergeJoin::needsInput() const {
if (isRightJoin(joinType_)) {
return (input_ == nullptr || rightInput_ == nullptr);
}
return input_ == nullptr;
}

void MergeJoin::addInput(RowVectorPtr input) {
input_ = std::move(input);
index_ = 0;

if (leftJoinTracker_) {
leftJoinTracker_->resetLastVector();
if (joinTracker_) {
joinTracker_->resetLastVector();
}
}

Expand Down Expand Up @@ -269,16 +279,36 @@ void copyRow(
void MergeJoin::addOutputRowForLeftJoin(
const RowVectorPtr& left,
vector_size_t leftIndex) {
VELOX_USER_CHECK(isLeftJoin(joinType_) || isAntiJoin(joinType_));
rawLeftIndices_[outputSize_] = leftIndex;

for (const auto& projection : rightProjections_) {
const auto& target = output_->childAt(projection.outputChannel);
target->setNull(outputSize_, true);
}

if (leftJoinTracker_) {
if (joinTracker_) {
// Record left-side row with no match on the right side.
leftJoinTracker_->addMiss(outputSize_);
joinTracker_->addMiss(outputSize_);
}

++outputSize_;
}

void MergeJoin::addOutputRowForRightJoin(
const RowVectorPtr& right,
vector_size_t rightIndex) {
VELOX_USER_CHECK(isRightJoin(joinType_));
rawRightIndices_[outputSize_] = rightIndex;

for (const auto& projection : leftProjections_) {
const auto& target = output_->childAt(projection.outputChannel);
target->setNull(outputSize_, true);
}

if (joinTracker_) {
// Record right-side row with no match on the left side.
joinTracker_->addMiss(outputSize_);
}

++outputSize_;
Expand Down Expand Up @@ -320,18 +350,23 @@ void MergeJoin::addOutputRow(
copyRow(left, leftIndex, filterInput_, outputSize_, filterLeftInputs_);
copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_);

if (leftJoinTracker_) {
// Record left-side row with a match on the right-side.
leftJoinTracker_->addMatch(left, leftIndex, outputSize_);
if (joinTracker_) {
if (isRightJoin(joinType_)) {
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(right, rightIndex, outputSize_);
} else {
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(left, leftIndex, outputSize_);
}
}
}

// Anti join needs to track the left side rows that have no match on the
// right.
if (isAntiJoin(joinType_)) {
VELOX_CHECK(leftJoinTracker_);
VELOX_CHECK(joinTracker_);
// Record left-side row with a match on the right-side.
leftJoinTracker_->addMatch(left, leftIndex, outputSize_);
joinTracker_->addMatch(left, leftIndex, outputSize_);
}

++outputSize_;
Expand All @@ -348,6 +383,10 @@ bool MergeJoin::prepareOutput(
return true;
}

if (isRightJoin(joinType_) && right != currentRight_) {
return true;
}

// If there is a new right, we need to flatten the dictionary.
if (!isRightFlattened_ && right && currentRight_ != right) {
flattenRightProjections();
Expand All @@ -363,14 +402,23 @@ bool MergeJoin::prepareOutput(
rightIndices_ = allocateIndices(outputBatchSize_, pool());
rawRightIndices_ = rightIndices_->asMutable<vector_size_t>();

// Create output dictionary vectors for left projections.
// Create left side projection outputs.
std::vector<VectorPtr> localColumns(outputType_->size());
for (const auto& projection : leftProjections_) {
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
{},
leftIndices_,
outputBatchSize_,
newLeft->childAt(projection.inputChannel));
if (newLeft == nullptr) {
for (const auto& projection : leftProjections_) {
localColumns[projection.outputChannel] = BaseVector::create(
outputType_->childAt(projection.outputChannel),
outputBatchSize_,
operatorCtx_->pool());
}
} else {
for (const auto& projection : leftProjections_) {
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
{},
leftIndices_,
outputBatchSize_,
newLeft->childAt(projection.inputChannel));
}
}
currentLeft_ = newLeft;

Expand Down Expand Up @@ -556,7 +604,7 @@ vector_size_t firstNonNull(

RowVectorPtr MergeJoin::filterOutputForAntiJoin(const RowVectorPtr& output) {
auto numRows = output->size();
const auto& filterRows = leftJoinTracker_->matchingRows(numRows);
const auto& filterRows = joinTracker_->matchingRows(numRows);
auto numPassed = 0;

BufferPtr indices = allocateIndices(numRows, pool());
Expand Down Expand Up @@ -738,6 +786,35 @@ RowVectorPtr MergeJoin::doGetOutput() {
output_->resize(outputSize_);
return std::move(output_);
}
} else if (isRightJoin(joinType_)) {
if (rightInput_ && noMoreInput_) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(nullptr, rightInput_)) {
output_->resize(outputSize_);
return std::move(output_);
}

while (true) {
if (outputSize_ == outputBatchSize_) {
return std::move(output_);
}

addOutputRowForRightJoin(rightInput_, rightIndex_);

++rightIndex_;
if (rightIndex_ == rightInput_->size()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
return nullptr;
}
}
}

if (noMoreRightInput_ && output_) {
output_->resize(outputSize_);
return std::move(output_);
}
} else {
if (noMoreInput_ || noMoreRightInput_) {
if (output_) {
Expand Down Expand Up @@ -770,9 +847,11 @@ RowVectorPtr MergeJoin::doGetOutput() {
return std::move(output_);
}
addOutputRowForLeftJoin(input_, index_);
++index_;
} else {
index_ = firstNonNull(input_, leftKeys_, index_ + 1);
}

++index_;
if (index_ == input_->size()) {
// Ran out of rows on the left side.
input_ = nullptr;
Expand All @@ -783,7 +862,24 @@ RowVectorPtr MergeJoin::doGetOutput() {

// Catch up rightInput_ with input_.
while (compareResult > 0) {
rightIndex_ = firstNonNull(rightInput_, rightKeys_, rightIndex_ + 1);
if (isRightJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(nullptr, rightInput_)) {
output_->resize(outputSize_);
return std::move(output_);
}

if (outputSize_ == outputBatchSize_) {
return std::move(output_);
}

addOutputRowForRightJoin(rightInput_, rightIndex_);
++rightIndex_;
} else {
rightIndex_ = firstNonNull(rightInput_, rightKeys_, rightIndex_ + 1);
}

if (rightIndex_ == rightInput_->size()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
Expand Down Expand Up @@ -862,8 +958,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
auto rawIndices = indices->asMutable<vector_size_t>();
vector_size_t numPassed = 0;

if (leftJoinTracker_) {
const auto& filterRows = leftJoinTracker_->matchingRows(numRows);
if (joinTracker_) {
const auto& filterRows = joinTracker_->matchingRows(numRows);

if (!filterRows.hasSelections()) {
// No matches in the output, no need to evaluate the filter.
Expand All @@ -878,9 +974,16 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
if (!isAntiJoin(joinType_)) {
rawIndices[numPassed++] = row;

for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
if (!isRightJoin(joinType_)) {
for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
}
} else {
for (auto& projection : leftProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
}
}
}
};
Expand All @@ -890,7 +993,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
const bool passed = !decodedFilterResult_.isNullAt(i) &&
decodedFilterResult_.valueAt<bool>(i);

leftJoinTracker_->processFilterResult(i, passed, onMiss);
joinTracker_->processFilterResult(i, passed, onMiss);

if (isAntiJoin(joinType_)) {
if (!passed) {
Expand Down Expand Up @@ -927,8 +1030,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
// 2. leftMatch_ may not be nullopt, but may be related to a different
// (subsequent) left key. So we check if the last row in the batch has the
// same left row number as the last key match.
if (!leftMatch_ || !leftJoinTracker_->isCurrentLeftMatch(numRows - 1)) {
leftJoinTracker_->noMoreFilterResults(onMiss);
if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch(numRows - 1)) {
joinTracker_->noMoreFilterResults(onMiss);
}
} else {
filterRows_.resize(numRows);
Expand Down Expand Up @@ -966,6 +1069,12 @@ void MergeJoin::evaluateFilter(const SelectivityVector& rows) {
}

bool MergeJoin::isFinished() {
if (isRightJoin(joinType_)) {
// If all rows on both the left and right sides match, we must also verify
// the 'noMoreInput_' on the left side to ensure that all results are
// complete.
return noMoreInput_ && noMoreRightInput_ && rightInput_ == nullptr;
}
return noMoreInput_ && input_ == nullptr;
}

Expand Down
18 changes: 13 additions & 5 deletions velox/exec/MergeJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ class MergeJoin : public Operator {
const RowVectorPtr& left,
vector_size_t leftIndex);

/// Adds one row of output for a right-side row with no left-side match.
/// Copies values from the 'rightIndex' row of 'right' and fills in nulls
/// for columns that correspond to the right side.
void addOutputRowForRightJoin(
const RowVectorPtr& right,
vector_size_t rightIndex);

/// Evaluates join filter on 'filterInput_' and returns 'output' that contains
/// a subset of rows on which the filter passed. Returns nullptr if no rows
/// passed the filter.
Expand All @@ -231,9 +238,9 @@ class MergeJoin : public Operator {
/// rows from the left side that have a match on the right.
RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output);

/// As we populate the results of the left join, we track whether a given
/// As we populate the results of the join, we track whether a given
/// output row is a result of a match between left and right sides or a miss.
/// We use LeftJoinTracker::addMatch and addMiss methods for that.
/// We use JoinTracker::addMatch and addMiss methods for that.
///
/// The semantic of the filter is to include at least one left side row in the
/// output after filters are applied. Therefore:
Expand All @@ -256,8 +263,8 @@ class MergeJoin : public Operator {
/// block, we keep the subset of passing rows. However, if the filter failed
/// on all rows in such a block, we add one of these rows back and update
/// build-side columns to null.
struct LeftJoinTracker {
LeftJoinTracker(vector_size_t numRows, memory::MemoryPool* pool)
struct JoinTracker {
JoinTracker(vector_size_t numRows, memory::MemoryPool* pool)
: matchingRows_{numRows, false} {
leftRowNumbers_ = AlignedBuffer::allocate<vector_size_t>(numRows, pool);
rawLeftRowNumbers_ = leftRowNumbers_->asMutable<vector_size_t>();
Expand Down Expand Up @@ -391,7 +398,8 @@ class MergeJoin : public Operator {
bool currentRowPassed_{false};
};

std::optional<LeftJoinTracker> leftJoinTracker_{std::nullopt};
/// Used to record both left and right join.
std::optional<JoinTracker> joinTracker_{std::nullopt};

// Indices buffer used by the output dictionaries. All projection from the
// left share `leftIndices_`, and projections in the right share
Expand Down
2 changes: 1 addition & 1 deletion velox/exec/fuzzer/JoinFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ void JoinFuzzer::makeAlternativePlans(
// Use OrderBy + MergeJoin
if (joinNode->isInnerJoin() || joinNode->isLeftJoin() ||
joinNode->isLeftSemiFilterJoin() || joinNode->isRightSemiFilterJoin() ||
joinNode->isAntiJoin()) {
joinNode->isAntiJoin() || joinNode->isRightJoin()) {
auto planWithSplits = makeMergeJoinPlan(
joinType, probeKeys, buildKeys, probeInput, buildInput, outputColumns);
plans.push_back(planWithSplits);
Expand Down
Loading

0 comments on commit 0ef0ac8

Please sign in to comment.