Skip to content

Commit

Permalink
Fix decimal agg signature on partial companion function (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo authored and marin-ma committed Jan 2, 2024
1 parent 7d74919 commit ead41fb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
4 changes: 3 additions & 1 deletion velox/functions/sparksql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,9 @@ exec::AggregateRegistrationResult registerAverage(
auto inputScale = inputType->asShortDecimal().scale();
auto sumType =
DECIMAL(std::min(38, inputPrecision + 10), inputScale);
if (exec::isPartialOutput(step)) {
if (exec::isPartialOutput(step) ||
(step == core::AggregationNode::Step::kSingle &&
resultType->isRow())) {
return std::make_unique<
DecimalAverageAggregate<int64_t, int64_t>>(
resultType, sumType);
Expand Down
8 changes: 7 additions & 1 deletion velox/functions/sparksql/aggregates/SumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ using SumAggregate = SumAggregateBase<TInput, TAccumulator, ResultType, true>;
TypePtr getDecimalSumType(
const TypePtr& resultType,
core::AggregationNode::Step step) {
return exec::isPartialOutput(step) ? resultType->childAt(0) : resultType;
if (exec::isPartialOutput(step)) {
return resultType->childAt(0);
}
if (step == core::AggregationNode::Step::kSingle && resultType->isRow()) {
return resultType->childAt(0);
}
return resultType;
}
} // namespace

Expand Down

0 comments on commit ead41fb

Please sign in to comment.