Spark sql avg agg function support decimal #6020
Conversation
Force-pushed 9bc7c32 to f203376.
@rui-mo Could you help review this PR?
Spark SQL's avg aggregate function signature is different from Presto's, so we need to support Spark SQL separately.
Can we also describe how their signatures are different, and paste this part into the PR description? We can also add the link to Spark's implementation there.
try {
  rawValues[i] = computeAvg(accumulator);
} catch (const VeloxException& err) {
  if (err.message().find("overflow") != std::string::npos ||
I will change this part to use an overflow flag.
Hi @mbasmanova, Spark SQL needs to return null when an overflow occurs, but Velox's behavior is to throw an exception, so we previously returned a null value by catching the exception. Do you have any better suggestions for this part? For example, the framework could support configuring whether to throw an exception or return a null value.
We tried modifying some methods to return std::nullopt to indicate overflow, but many methods are shared with Presto, such as decimal multiplication; modifying them would affect Presto's behavior.
cc @rui-mo
@liujiayi771 Let's start by describing the semantics of Spark's avg function and how it differs from Presto. Make sure to include this in the PR description. Once we have that information, we can brainstorm on how best to implement it.
> @liujiayi771 Let's start by describing the semantics of Spark's avg function and how it differs from Presto. Make sure to include this in the PR description. Once we have that information, we can brainstorm on how best to implement it.
The PR description has been updated.
@mbasmanova When calculating the avg of decimal, checkedMultiply will be called. This method may overflow. Can we add a nullOnFailure parameter to these methods, as in the try cast PR? The default would be false, which throws an exception.
I will introduce multiply in this PR: https://github.com/facebookincubator/velox/pull/4613/files#diff-4b21d26191cbbdb9bc90cafd83aed144a4483b1c594f224cf3bbe67f401f5f83R87
> I will introduce multiply in this PR: https://github.com/facebookincubator/velox/pull/4613/files#diff-4b21d26191cbbdb9bc90cafd83aed144a4483b1c594f224cf3bbe67f401f5f83R87
OK. I will use Spark decimal multiply after this PR is merged.
> When calculating the avg of decimal, checkedMultiply will be called. This method may overflow. Can we add a nullOnFailure parameter to these methods, as in the try cast PR? The default would be false, which throws an exception.
@liujiayi771 That's a good question. Let's create a GitHub issue to discuss how to go about it. This particular function is used in hot paths, so we may need to make nullOnFailure a template parameter or introduce a different function.
CC: @laithsakka
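A minimal sketch of the template-parameter idea; the checkedMultiply wrapper and the nullOnOverflow parameter here are illustrative assumptions, not Velox's actual API:

```cpp
#include <optional>
#include <stdexcept>

// Hypothetical sketch: choose overflow behavior at compile time so the hot
// path does not pay for a runtime branch on the policy.
template <bool nullOnOverflow>
std::optional<__int128> checkedMultiply(__int128 a, __int128 b) {
  __int128 result;
  if (__builtin_mul_overflow(a, b, &result)) {
    if constexpr (nullOnOverflow) {
      return std::nullopt; // Spark semantics: overflow produces null.
    } else {
      throw std::overflow_error("decimal multiplication overflow");
    }
  }
  return result;
}
```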
auto countDecimal = accumulator->count;
int128_t avg = 0;
DecimalUtil::divideWithRoundUp<int128_t, int128_t, int128_t>(
Maybe you can wait for this PR to add the overflow flag: https://github.com/facebookincubator/velox/pull/4613/files#diff-4b21d26191cbbdb9bc90cafd83aed144a4483b1c594f224cf3bbe67f401f5f83R86
DecimalUtil::divideWithRoundUp<int128_t, int128_t, int128_t>(
    avg, sum, countDecimal, false, sumRescale, 0);
DecimalUtil::valueInRange(avg);
This check only verifies whether the value exceeds the long or short decimal type threshold, but not whether a value with this precision can actually be represented. Maybe we should check as in https://github.com/facebookincubator/velox/blob/main/velox/type/DecimalUtil.h#L118
I created a new valueInRangeWithPrecision function in the decimal sum PR: https://github.com/facebookincubator/velox/pull/5372/files
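A hedged sketch of what such a precision-aware check could look like, following the linked DecimalUtil.h logic (the function name mirrors the valueInRangeWithPrecision mentioned above; treat the body as illustrative):

```cpp
#include <cstdint>

// A value fits DECIMAL(precision, scale) iff |value| < 10^precision.
// DecimalUtil::kPowersOfTen is Velox's lookup table of powers of ten.
inline bool valueInRangeWithPrecision(__int128 value, uint8_t precision) {
  const auto bound = DecimalUtil::kPowersOfTen[precision];
  return value > -bound && value < bound;
}
```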
exec::AggregateFunctionSignatureBuilder()
    .integerVariable("a_precision")
    .integerVariable("a_scale")
    .integerVariable("r_precision", "min(38, a_precision + 4)")
The result precision differs from the actual result type; maybe you should wait for https://github.com/facebookincubator/velox/pull/4613/files#diff-4baee9f82f9c347f753972415d345941d2c2a110b608d480b090650171da096bR399 to get the right signature.
Force-pushed 14f584f to 65cb72a.
bool /* mayPushdown */) override {
  decodedPartial_.decode(*args[0], rows);
  auto baseRowVector = dynamic_cast<const RowVector*>(decodedPartial_.base());
  auto sumVector = baseRowVector->childAt(0)->as<SimpleVector<int128_t>>();
Is it possible for arg0 to be int64_t type?
It will always be int128_t, according to https://github.com/apache/spark/blob/13732922cca4d03de216d0ad2264ec9212fb63b7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L2105-L2137.
In Average, when the precision of the decimal is less than or equal to 11, the decimal is converted to long for calculation, so the sum decimal type in this function (input's precision + 10) must be a long decimal.
Maybe use VELOX_USER_CHECK_NOT_NULL to ensure the sum vector is valid.
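For instance, a small sketch of that check using the names from the snippet above (VELOX_USER_CHECK_NOT_NULL is Velox's user-error assertion; the message text is illustrative):

```cpp
auto sumVector = baseRowVector->childAt(0)->as<SimpleVector<int128_t>>();
// Fail with a clear user error instead of dereferencing a null pointer
// later if the intermediate row's first child is not a long decimal.
VELOX_USER_CHECK_NOT_NULL(
    sumVector, "Sum child of intermediate row must be a long decimal");
```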
// Spark uses DECIMAL(20, 0) to represent a long value.
const uint8_t countPrecision = 20, countScale = 0;
auto [sumPrecision, sumScale] =
    getDecimalPrecisionScale(*this->sumType_.get());
@rui-mo sumType needs to be calculated from inputType. For the raw input step, sumType's precision = min(38, inputType's precision + 10). For the intermediate input step, sumType = inputType.childAt(0).
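A hedged sketch of that derivation; computeSumType is an illustrative helper, not code from the PR, while getDecimalPrecisionScale and DECIMAL are Velox's type utilities:

```cpp
#include <algorithm>

TypePtr computeSumType(const TypePtr& inputType, bool isRawInput) {
  if (isRawInput) {
    // Raw input step: Spark's sum type is DECIMAL(min(38, p + 10), s).
    auto [precision, scale] = getDecimalPrecisionScale(*inputType);
    return DECIMAL(std::min(38, precision + 10), scale);
  }
  // Intermediate input step: input is ROW(sum DECIMAL, count BIGINT).
  return inputType->childAt(0);
}
```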
Thanks for the fix!
Can we add a similar test to TestOperator.scala#L457-L460 to ensure the functionality? Thanks.
> Can we add a similar test to TestOperator.scala#L457-L460 to ensure the functionality? Thanks.
This relies on being able to define variable inputs in Velox's function signature. The function signature issue mentioned in another comment also relies on this. I remember @jinchengchenghh was making this modification; it may need to be adapted later.
int128_t avg = 0;
DecimalUtil::divideWithRoundUp<int128_t, int128_t, int128_t>(
    avg, validSum.value(), countDecimal, false, sumRescale, 0);
The 0 corresponds to the parameter bRescale.
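For readability, the same call with the arguments labeled; the parameter roles are inferred from this discussion and should be treated as assumptions:

```cpp
DecimalUtil::divideWithRoundUp<int128_t, int128_t, int128_t>(
    avg,              // output: the computed average
    validSum.value(), // dividend: the accumulated sum
    countDecimal,     // divisor: the row count as a decimal
    false,            // flag argument (meaning assumed from context)
    sumRescale,       // aRescale: rescale applied to the sum
    0);               // bRescale: the count needs no rescaling
```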
return rScale - fromScale + toScale;
}

inline static std::pair<uint8_t, uint8_t> computeResultPrecisionScale(
This part looks similar to the precision and scale calculation of decimal divide. Maybe we can add helper functions to avoid duplication. cc @jinchengchenghh
We can put these helper functions into DecimalUtils.h, or have a new DecimalUtils used by Spark.
return adjustPrecisionScale(precision, scale);
}

inline static std::pair<uint8_t, uint8_t> adjustPrecisionScale(
Same.
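For context, a hedged sketch of the adjustment rule under discussion, ported from Spark's DecimalType.adjustPrecisionScale (assuming the default allowPrecisionLoss behavior, which keeps at least 6 fractional digits):

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>

// If the unbounded result precision exceeds 38, shrink the scale first,
// but never below min(scale, 6), then cap the precision at 38.
inline std::pair<uint8_t, uint8_t> adjustPrecisionScale(
    int precision, int scale) {
  if (precision <= 38) {
    return {static_cast<uint8_t>(precision), static_cast<uint8_t>(scale)};
  }
  const int intDigits = precision - scale;
  const int minScale = std::min(scale, 6);
  const int adjustedScale = std::max(38 - intDigits, minScale);
  return {38, static_cast<uint8_t>(adjustedScale)};
}
```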
.argumentType("DECIMAL(a_precision, a_scale)")
.intermediateType("ROW(DECIMAL(38, a_scale), BIGINT)")
.returnType("DECIMAL(r_precision, r_scale)")
.build());
The extra registration below is needed for the partial merge stage of average. In this case, the intermediate type is ROW(DECIMAL(a_precision, a_scale), BIGINT), and the precision is not restricted to 38.
signatures.push_back(
exec::AggregateFunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.argumentType("DECIMAL(a_precision, a_scale)")
.intermediateType("ROW(DECIMAL(a_precision, a_scale), BIGINT)")
.returnType("DECIMAL(a_precision, a_scale)")
.build());
Force-pushed 8a4ba83 to 834291c.
velox/type/DecimalUtil.h (outdated)

@@ -173,11 +173,7 @@ class DecimalUtil {
}
// Check overflow.
if (!valueInPrecisionRange(rescaledValue, toPrecision) || isOverflow) {
  VELOX_USER_FAIL(
Wait for #5307 to be merged.
Can we remove the modification here and mark the PR as draft until 5307 merges?
Gluten has a conflict between 6020 and 5307, which breaks the automatic PR pick now; we have to resolve it manually each time.
@FelixYBW But there is one problem with this. The decimal avg function requires the modified rescaleWithRoundUp method, which needs to be called with the "overflow" parameter for validation; otherwise, there might be issues with Spark's unit tests. But if I pass these unmodified parameters in my PR, it will not compile successfully.
Eventually the code should look like below, right?

if (!valueInPrecisionRange(rescaledValue, toPrecision) || isOverflow) {
  if (throwIfOverflow) {
    VELOX_USER_FAIL(...);
  } else {
    isOverflow = true;
    return std::nullopt;
    // After both 5307 and 6020 are merged, should we still return nullopt
    // even though isOverflow already passes the info to the caller?
  }
}
It's OK if the PR can't compile while it waits for 5307 to be merged.
@FelixYBW Done.
Revert "Fix decimal agg signature on partial companion function (oap-project#465)" This reverts commit 336d61f.
Revert "Fix decimal agg signature on partial companion function (oap-project#465)" This reverts commit 336d61f.
Revert "Fix decimal agg signature on partial companion function (oap-project#465)" This reverts commit 336d61f.
Revert "Fix decimal agg signature on partial companion function (oap-project#465)" This reverts commit 336d61f.
Revert "Fix decimal agg signature on partial companion function (oap-project#465)" This reverts commit 336d61f.
Revert "Fix decimal agg signature on partial companion function (oap-project#465)" This reverts commit 336d61f.
Revert "Fix decimal agg signature on partial companion function (oap-project#465)" This reverts commit 336d61f.
This pull request has been automatically marked as stale because it has not had recent activity. If you'd still like this PR merged, please comment on the PR, make sure you've addressed reviewer comments, and rebase on the latest main. Thank you for your contributions!
@liujiayi771 The PR is used by OAP/Velox. Can you update? Looks like #5307 is merged already.
@FelixYBW I have updated the branch for this PR, but this PR can no longer be reopened. In fact, I have been following up on decimal avg, and Meta wants new aggregates to be developed using the simple aggregate function interface. I have also completed decimal avg using the simple interface, but it requires some preceding PRs to be merged into the community. The development of this part is waiting for the community to provide a prototype; you can track this in #9167.
Resolves #5315
Spark's avg: https://github.com/apache/spark/blob/cf64008fce77b38d1237874b04f5ac124b01b3a8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
Presto's decimal avg uses VARBINARY as the intermediate type; it serializes a LongDecimalWithOverflowState to VARBINARY. In Spark, however, the intermediate data of decimal avg consists of sum (a decimal type) and count (a long type). The return types also differ: Presto's input and output types are consistent, while in Spark an input of type decimal(p, s) produces an output of type decimal(p + 4, s + 4), and the sum stored in the intermediate data has type decimal(p + 10, s). Therefore, we need to add handling of decimal scale and precision to Spark's decimal avg.
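A hedged sketch of the raw-input signature implied by these rules, assembled from the builder snippets quoted in the conversation above (the r_scale formula is an assumption following decimal(p + 4, s + 4)):

```cpp
auto signature =
    exec::AggregateFunctionSignatureBuilder()
        .integerVariable("a_precision")
        .integerVariable("a_scale")
        // Result type follows Spark: decimal(p + 4, s + 4), capped at 38.
        .integerVariable("r_precision", "min(38, a_precision + 4)")
        .integerVariable("r_scale", "min(38, a_scale + 4)")
        .argumentType("DECIMAL(a_precision, a_scale)")
        // Intermediate sum is stored as a long decimal with the input scale.
        .intermediateType("ROW(DECIMAL(38, a_scale), BIGINT)")
        .returnType("DECIMAL(r_precision, r_scale)")
        .build();
```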
In addition, Spark and Presto differ in their handling of overflow. Presto always throws an exception when overflow occurs, while Spark returns null by default. We need to deal with partial avg overflow, final avg overflow, and so on. Spark uses a null sum to indicate that the partial avg has overflowed; when this is encountered, the final result for the current group can be set directly to null.
However, many decimal computation methods in the current Velox code throw exceptions when overflow occurs. For now, the implementation of Spark decimal avg needs to catch these exceptions to determine whether overflow has occurred.
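A minimal sketch of that catch-and-null pattern, using the names from the snippet quoted earlier in this conversation (the result vector and row index names are illustrative):

```cpp
try {
  rawValues[i] = computeAvg(accumulator);
} catch (const VeloxException& err) {
  // Shared decimal paths throw on overflow; Spark's avg maps overflow
  // to a null result instead of an error.
  if (err.message().find("overflow") != std::string::npos) {
    resultVector->setNull(i, true);
  } else {
    throw; // Preserve unrelated failures.
  }
}
```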