Skip to content

Commit

Permalink
fix ut
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Aug 8, 2023
1 parent a2aaee6 commit 9bc7c32
Showing 1 changed file with 18 additions and 29 deletions.
47 changes: 18 additions & 29 deletions velox/functions/sparksql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,15 @@ class DecimalAverageAggregate : public DecimalAggregate<TInputType> {
rowVector->resize(numGroups);
sumVector->resize(numGroups);
countVector->resize(numGroups);

uint64_t* rawNulls = this->getRawNulls(rowVector);
rowVector->clearAllNulls();

int64_t* rawCounts = countVector->mutableRawValues();
int128_t* rawSums = sumVector->mutableRawValues();

for (auto i = 0; i < numGroups; ++i) {
char* group = groups[i];
if (this->isNull(group)) {
rowVector->setNull(i, true);
} else {
this->clearNull(rawNulls, i);
auto* accumulator = this->decimalAccumulator(group);
rawCounts[i] = accumulator->count;
rawSums[i] = accumulator->sum;
}
auto* accumulator = this->decimalAccumulator(group);
rawCounts[i] = accumulator->count;
rawSums[i] = accumulator->sum;
}
}

Expand All @@ -152,26 +145,22 @@ class DecimalAverageAggregate : public DecimalAggregate<TInputType> {
for (int32_t i = 0; i < numGroups; ++i) {
char* group = groups[i];
auto accumulator = this->decimalAccumulator(group);
if (this->isNull(group) || accumulator->count == 0) {
if (accumulator->count == 0) {
// In Spark, if all inputs are null, count will be 0,
// and the result of final avg will be null.
vector->setNull(i, true);
} else {
this->clearNull(rawNulls, i);
if (accumulator->overflow > 0) {
// Spark does not support ansi mode yet,
// and needs to return null when overflow
vector->setNull(i, true);
} else {
try {
rawValues[i] = computeAvg(accumulator);
} catch (const VeloxException& err) {
if (err.message().find("overflow") != std::string::npos ||
err.message().find("is not in the range of Decimal Type") !=
std::string::npos) {
// find overflow or out of long decimal range in computation
vector->setNull(i, true);
} else {
VELOX_FAIL("compute average failed");
}
try {
rawValues[i] = computeAvg(accumulator);
} catch (const VeloxException& err) {
if (err.message().find("overflow") != std::string::npos ||
err.message().find("is not in the range of Decimal Type") !=
std::string::npos) {
// find overflow or out of long decimal range in computation
vector->setNull(i, true);
} else {
VELOX_FAIL("compute average failed");
}
}
}
Expand Down Expand Up @@ -498,7 +487,7 @@ exec::AggregateRegistrationResult registerAverage(const std::string& name) {
resultType);
case TypeKind::DOUBLE:
case TypeKind::ROW:
if (resultType->childAt(0)->isLongDecimal()) {
if (inputType->childAt(0)->isLongDecimal()) {
return std::make_unique<
DecimalAverageAggregate<int128_t, int128_t>>(resultType);
}
Expand Down

0 comments on commit 9bc7c32

Please sign in to comment.