diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp index 494d681dbd88..2a5412776480 100644 --- a/velox/core/PlanNode.cpp +++ b/velox/core/PlanNode.cpp @@ -244,6 +244,10 @@ bool AggregationNode::canSpill(const QueryConfig& queryConfig) const { void AggregationNode::addDetails(std::stringstream& stream) const { stream << stepName(step_) << " "; + if (isPreGrouped()) { + stream << "STREAMING "; + } + if (!groupingKeys_.empty()) { stream << "["; addFields(stream, groupingKeys_); diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index 4e04b4bcae16..7c46401d267a 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -573,6 +573,19 @@ class AggregationNode : public PlanNode { return preGroupedKeys_; } + bool isPreGrouped() const { + return !preGroupedKeys_.empty() && + std::equal( + preGroupedKeys_.begin(), + preGroupedKeys_.end(), + groupingKeys_.begin(), + groupingKeys_.end(), + [](const FieldAccessTypedExprPtr& x, + const FieldAccessTypedExprPtr& y) -> bool { + return (*x == *y); + }); + } + const std::vector& aggregateNames() const { return aggregateNames_; } diff --git a/velox/exec/LocalPlanner.cpp b/velox/exec/LocalPlanner.cpp index d50a700474fe..f035333e895e 100644 --- a/velox/exec/LocalPlanner.cpp +++ b/velox/exec/LocalPlanner.cpp @@ -488,9 +488,7 @@ std::shared_ptr DriverFactory::createDriver( } else if ( auto aggregationNode = std::dynamic_pointer_cast(planNode)) { - if (!aggregationNode->preGroupedKeys().empty() && - aggregationNode->preGroupedKeys().size() == - aggregationNode->groupingKeys().size()) { + if (aggregationNode->isPreGrouped()) { operators.push_back(std::make_unique( id, ctx.get(), aggregationNode)); } else { diff --git a/velox/exec/tests/PlanNodeToStringTest.cpp b/velox/exec/tests/PlanNodeToStringTest.cpp index 3a344d2a9fb7..65b05715cd3c 100644 --- a/velox/exec/tests/PlanNodeToStringTest.cpp +++ b/velox/exec/tests/PlanNodeToStringTest.cpp @@ -271,6 +271,14 @@ TEST_F(PlanNodeToStringTest, aggregation) { ASSERT_EQ( "-- Aggregation[SINGLE [c0, group_id] sum_c1 := sum(ROW[\"c1\"]) global group IDs: [ 1, 2 ] Group Id key: group_id] -> c0:SMALLINT, group_id:BIGINT, sum_c1:BIGINT\n", plan->toString(true, false)); + + plan = PlanBuilder() + .values({data_}) + .partialStreamingAggregation({"c0"}, {"sum(c1) AS a"}) + .planNode(); + ASSERT_EQ( + "-- Aggregation[PARTIAL STREAMING [c0] a := sum(ROW[\"c1\"])] -> c0:SMALLINT, a:BIGINT\n", + plan->toString(true, false)); } TEST_F(PlanNodeToStringTest, groupId) {