diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index edd6b78d833f..6c3abe193406 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -998,14 +998,21 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { ? inputSize : outputBatchSize_; outputTableRowsCapacity_ = outputBatchSize; - if (filter_ && - (isLeftJoin(joinType_) || isFullJoin(joinType_) || - isAntiJoin(joinType_))) { - // If we need non-matching probe side row, there is a possibility that such - // row exists at end of an input batch and being carried over in the next - // output batch, so we need to make extra room of one row in output. - ++outputTableRowsCapacity_; + if (filter_) { + if (isLeftJoin(joinType_) || isFullJoin(joinType_) || + isAntiJoin(joinType_)) { + // If we need non-matching probe side row, there is a possibility that + // such row exists at end of an input batch and being carried over in the + // next output batch, so we need to make extra room of one row in output. + ++outputTableRowsCapacity_; + } + + // Initialize 'leftSemiProjectIsNull_' for null aware left semi project join. + if (isLeftSemiProjectJoin(joinType_) && nullAware_) { + leftSemiProjectIsNull_.clearAll(); + } } + auto mapping = initializeRowNumberMapping( outputRowMapping_, outputTableRowsCapacity_, pool()); auto* outputTableRows = @@ -1088,11 +1095,15 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { } if (accumulatedNumOutput > 0 || - (filter_ != nullptr && numOut < numOutputBeforeFilter / 10 && + (filter_ != nullptr && numOut < numOutputBeforeFilter / 2 && !emptyBuildSide)) { accumulatedNumOutput += numOut; + + // To avoid generating low selectivity / small vectors, continue the + // current loop to populate more results to 'outputRowMapping_' and + // 'outputTableRows_' until all rows in the current input are processed or + // the preferred number of rows has been produced. 
if (!resultIter_->atEnd() && - accumulatedNumOutput < outputTableRowsCapacity_ && accumulatedNumOutput < operatorCtx_->driverCtx() ->queryConfig() .preferredOutputBatchBytes() && @@ -1103,7 +1114,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { outputTableRowsCapacity_ - accumulatedNumOutput); outputTableRows = outputTableRows_->asMutable() + accumulatedNumOutput; - outputBatchSize = outputBatchSize - numOut; + outputBatchSize -= numOut; continue; } } @@ -1473,9 +1484,6 @@ int32_t HashProbe::evalFilter(int32_t numRows, int32_t offset) { if (nullAware_) { leftSemiProjectIsNull_.resize(numRows + offset, false); - if (offset == 0) { - leftSemiProjectIsNull_.clearAll(); - } auto addLast = [&](auto row, std::optional passed) { if (passed.has_value()) { diff --git a/velox/exec/HashProbe.h b/velox/exec/HashProbe.h index 7c2d377aa85d..4680c8dcfdd8 100644 --- a/velox/exec/HashProbe.h +++ b/velox/exec/HashProbe.h @@ -128,7 +128,8 @@ class HashProbe : public Operator { // for right join and full join. RowVectorPtr getBuildSideOutput(); - // Applies 'filter_' to 'outputTableRows_' and updates 'outputRowMapping_'. + // Applies 'filter_' to 'outputTableRows_' starting at 'offset' with 'numRows' + // length and updates 'outputRowMapping_'. // Returns the number of passing rows. 
vector_size_t evalFilter(vector_size_t numRows, vector_size_t offset); diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index bceecd72e242..bacf8e04b498 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -8127,60 +8127,50 @@ TEST_F(HashJoinTest, combineSmallVectorsAfterFilter) { core::PlanNodeId probeScanId; core::PlanNodeId buildScanId; auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(probeVectors[0]->type())) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(buildVectors[0]->type())) - .capturePlanNodeId(buildScanId) - .planNode(), - "(t1 + u1) % 3 = 0", - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject) - .planNode(); - SplitInput splitInput = { - {probeScanId, - {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}}, - {buildScanId, - {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, + auto makePlan = [&](bool nullAware) { + return PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(probeVectors[0]->type())) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(buildVectors[0]->type())) + .capturePlanNodeId(buildScanId) + .planNode(), + "(t1 + u1) % 3 = 0", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject, + nullAware) + .planNode(); }; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto stats = task->taskStats(); - for (auto& pipeline : stats.pipelineStats) { - for (auto op : pipeline.operatorStats) { - if (op.operatorType == "HashProbe") { - ASSERT_EQ(op.outputVectors, 1); - } - } - } - 
}) - .run(); + for (auto nullAware : {false, true}) { + auto plan = makePlan(nullAware); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto stats = task->taskStats(); - for (auto& pipeline : stats.pipelineStats) { - for (auto op : pipeline.operatorStats) { - if (op.operatorType == "HashProbe") { - ASSERT_EQ(op.outputVectors, 1); + SplitInput splitInput = { + {probeScanId, + {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}}, + {buildScanId, + {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, + }; + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .inputSplits(splitInput) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto stats = task->taskStats(); + for (auto& pipeline : stats.pipelineStats) { + for (auto op : pipeline.operatorStats) { + if (op.operatorType == "HashProbe") { + ASSERT_EQ(op.outputVectors, 1); + } } } - } - }) - .run(); + }) + .run(); + } } } // namespace