Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhli1142015 committed Sep 19, 2024
1 parent e1e89f4 commit 86e5481
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 65 deletions.
34 changes: 21 additions & 13 deletions velox/exec/HashProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -998,14 +998,21 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
? inputSize
: outputBatchSize_;
outputTableRowsCapacity_ = outputBatchSize;
if (filter_ &&
(isLeftJoin(joinType_) || isFullJoin(joinType_) ||
isAntiJoin(joinType_))) {
// If we need non-matching probe side row, there is a possibility that such
// row exists at end of an input batch and being carried over in the next
// output batch, so we need to make extra room of one row in output.
++outputTableRowsCapacity_;
if (filter_) {
if (isLeftJoin(joinType_) || isFullJoin(joinType_) ||
isAntiJoin(joinType_)) {
// If we need non-matching probe side row, there is a possibility that
// such row exists at end of an input batch and being carried over in the
// next output batch, so we need to make extra room of one row in output.
++outputTableRowsCapacity_;
}

// Intialize 'leftSemiProjectIsNull_' for null aware lft semi join.
if (isLeftSemiProjectJoin(joinType_) && nullAware_) {
leftSemiProjectIsNull_.clearAll();
}
}

auto mapping = initializeRowNumberMapping(
outputRowMapping_, outputTableRowsCapacity_, pool());
auto* outputTableRows =
Expand Down Expand Up @@ -1088,11 +1095,15 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
}

if (accumulatedNumOutput > 0 ||
(filter_ != nullptr && numOut < numOutputBeforeFilter / 10 &&
(filter_ != nullptr && numOut < numOutputBeforeFilter / 2 &&
!emptyBuildSide)) {
accumulatedNumOutput += numOut;

// To avoid generating low seletivity / small vectors, continue the
// current loop to populate more results to 'outputRowMapping_' and
// 'outputTableRows_' until all rows in the current input are processed or
// the preferred number of rows has been produced.
if (!resultIter_->atEnd() &&
accumulatedNumOutput < outputTableRowsCapacity_ &&
accumulatedNumOutput < operatorCtx_->driverCtx()
->queryConfig()
.preferredOutputBatchBytes() &&
Expand All @@ -1103,7 +1114,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
outputTableRowsCapacity_ - accumulatedNumOutput);
outputTableRows =
outputTableRows_->asMutable<char*>() + accumulatedNumOutput;
outputBatchSize = outputBatchSize - numOut;
outputBatchSize -= numOut;
continue;
}
}
Expand Down Expand Up @@ -1473,9 +1484,6 @@ int32_t HashProbe::evalFilter(int32_t numRows, int32_t offset) {

if (nullAware_) {
leftSemiProjectIsNull_.resize(numRows + offset, false);
if (offset == 0) {
leftSemiProjectIsNull_.clearAll();
}

auto addLast = [&](auto row, std::optional<bool> passed) {
if (passed.has_value()) {
Expand Down
3 changes: 2 additions & 1 deletion velox/exec/HashProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ class HashProbe : public Operator {
// for right join and full join.
RowVectorPtr getBuildSideOutput();

// Applies 'filter_' to 'outputTableRows_' and updates 'outputRowMapping_'.
// Applies 'filter_' to 'outputTableRows_' starting at 'offset' with 'size'
// length and updates 'outputRowMapping_'.
// Returns the number of passing rows.
vector_size_t evalFilter(vector_size_t numRows, vector_size_t offset);

Expand Down
92 changes: 41 additions & 51 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8127,60 +8127,50 @@ TEST_F(HashJoinTest, combineSmallVectorsAfterFilter) {
core::PlanNodeId probeScanId;
core::PlanNodeId buildScanId;
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan = PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(probeVectors[0]->type()))
.capturePlanNodeId(probeScanId)
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(buildVectors[0]->type()))
.capturePlanNodeId(buildScanId)
.planNode(),
"(t1 + u1) % 3 = 0",
{"t0", "t1", "match"},
core::JoinType::kLeftSemiProject)
.planNode();
SplitInput splitInput = {
{probeScanId,
{exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}},
{buildScanId,
{exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}},
auto makePlan = [&](bool nullAware) {
return PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(probeVectors[0]->type()))
.capturePlanNodeId(probeScanId)
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(buildVectors[0]->type()))
.capturePlanNodeId(buildScanId)
.planNode(),
"(t1 + u1) % 3 = 0",
{"t0", "t1", "match"},
core::JoinType::kLeftSemiProject,
nullAware)
.planNode();
};
HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
.planNode(plan)
.inputSplits(splitInput)
.checkSpillStats(false)
.referenceQuery(
"SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t")
.verifier([&](const std::shared_ptr<Task>& task, bool /*unused*/) {
auto stats = task->taskStats();
for (auto& pipeline : stats.pipelineStats) {
for (auto op : pipeline.operatorStats) {
if (op.operatorType == "HashProbe") {
ASSERT_EQ(op.outputVectors, 1);
}
}
}
})
.run();
for (auto nullAware : {false, true}) {
auto plan = makePlan(nullAware);

HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
.planNode(flipJoinSides(plan))
.inputSplits(splitInput)
.checkSpillStats(false)
.referenceQuery(
"SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t")
.verifier([&](const std::shared_ptr<Task>& task, bool /*unused*/) {
auto stats = task->taskStats();
for (auto& pipeline : stats.pipelineStats) {
for (auto op : pipeline.operatorStats) {
if (op.operatorType == "HashProbe") {
ASSERT_EQ(op.outputVectors, 1);
SplitInput splitInput = {
{probeScanId,
{exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}},
{buildScanId,
{exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}},
};

HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
.planNode(plan)
.inputSplits(splitInput)
.checkSpillStats(false)
.referenceQuery(
"SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t")
.verifier([&](const std::shared_ptr<Task>& task, bool /*unused*/) {
auto stats = task->taskStats();
for (auto& pipeline : stats.pipelineStats) {
for (auto op : pipeline.operatorStats) {
if (op.operatorType == "HashProbe") {
ASSERT_EQ(op.outputVectors, 1);
}
}
}
}
})
.run();
})
.run();
}
}
} // namespace

0 comments on commit 86e5481

Please sign in to comment.