Skip to content

Commit

Permalink
[facebookincubator#6020 ] Spark sql avg agg function support decimal
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored and zhztheplayer committed Jul 25, 2024
1 parent 9941efe commit 3e49521
Show file tree
Hide file tree
Showing 4 changed files with 504 additions and 25 deletions.
6 changes: 4 additions & 2 deletions velox/functions/lib/aggregates/AverageAggregateBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ namespace facebook::velox::functions::aggregate {
void checkAvgIntermediateType(const TypePtr& type) {
VELOX_USER_CHECK(
type->isRow() || type->isVarbinary(),
"Input type for final average must be row type or varbinary type.");
"Input type for final average must be row type or varbinary type, find {}",
type->toString());
if (type->kind() == TypeKind::VARBINARY) {
return;
}
VELOX_USER_CHECK(
type->childAt(0)->kind() == TypeKind::DOUBLE ||
type->childAt(0)->isLongDecimal(),
"Input type for sum in final average must be double or long decimal type.")
"Input type for sum in final average must be double or long decimal type, find {}",
type->childAt(0)->toString());
VELOX_USER_CHECK_EQ(
type->childAt(1)->kind(),
TypeKind::BIGINT,
Expand Down
8 changes: 5 additions & 3 deletions velox/functions/lib/aggregates/DecimalAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class DecimalAggregate : public exec::Aggregate {
}

int32_t accumulatorAlignmentSize() const override {
return static_cast<int32_t>(sizeof(int128_t));
return alignof(LongDecimalWithOverflowState);
}

void addRawInput(
Expand Down Expand Up @@ -275,7 +275,9 @@ class DecimalAggregate : public exec::Aggregate {
}

virtual TResultType computeFinalValue(
LongDecimalWithOverflowState* accumulator) = 0;
LongDecimalWithOverflowState* accumulator) {
return 0;
};

void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
override {
Expand Down Expand Up @@ -327,11 +329,11 @@ class DecimalAggregate : public exec::Aggregate {
}
}

private:
inline LongDecimalWithOverflowState* decimalAccumulator(char* group) {
return exec::Aggregate::value<LongDecimalWithOverflowState>(group);
}

private:
DecodedVector decodedRaw_;
DecodedVector decodedPartial_;
};
Expand Down
Loading

0 comments on commit 3e49521

Please sign in to comment.