Skip to content

Commit

Permalink
Spark sql avg agg function support decimal (facebookincubator#6020)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored and marin-ma committed Jan 5, 2024
1 parent bb5d255 commit 421454a
Show file tree
Hide file tree
Showing 5 changed files with 527 additions and 36 deletions.
6 changes: 4 additions & 2 deletions velox/functions/lib/aggregates/AverageAggregateBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ namespace facebook::velox::functions::aggregate {
void checkAvgIntermediateType(const TypePtr& type) {
VELOX_USER_CHECK(
type->isRow() || type->isVarbinary(),
"Input type for final average must be row type or varbinary type.");
"Input type for final average must be row type or varbinary type, find {}",
type->toString());
if (type->kind() == TypeKind::VARBINARY) {
return;
}
VELOX_USER_CHECK(
type->childAt(0)->kind() == TypeKind::DOUBLE ||
type->childAt(0)->isLongDecimal(),
"Input type for sum in final average must be double or long decimal type.")
"Input type for sum in final average must be double or long decimal type, find {}",
type->childAt(0)->toString());
VELOX_USER_CHECK_EQ(
type->childAt(1)->kind(),
TypeKind::BIGINT,
Expand Down
11 changes: 7 additions & 4 deletions velox/functions/lib/aggregates/DecimalAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ class DecimalAggregate : public exec::Aggregate {
explicit DecimalAggregate(TypePtr resultType) : exec::Aggregate(resultType) {}

int32_t accumulatorFixedWidthSize() const override {
return sizeof(DecimalAggregate);
return sizeof(LongDecimalWithOverflowState);
}

int32_t accumulatorAlignmentSize() const override {
return static_cast<int32_t>(sizeof(int128_t));
return alignof(LongDecimalWithOverflowState);
}

void initializeNewGroups(
Expand Down Expand Up @@ -287,7 +287,9 @@ class DecimalAggregate : public exec::Aggregate {
}

virtual TResultType computeFinalValue(
LongDecimalWithOverflowState* accumulator) = 0;
LongDecimalWithOverflowState* accumulator) {
return 0;
};

void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
override {
Expand Down Expand Up @@ -329,11 +331,12 @@ class DecimalAggregate : public exec::Aggregate {
accumulator->count += 1;
}

private:
protected:
inline LongDecimalWithOverflowState* decimalAccumulator(char* group) {
return exec::Aggregate::value<LongDecimalWithOverflowState>(group);
}

private:
DecodedVector decodedRaw_;
DecodedVector decodedPartial_;
};
Expand Down
Loading

0 comments on commit 421454a

Please sign in to comment.