Skip to content

Commit

Permalink
Rewrite some common reduce patterns with array_sum and transform (#10406)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: #10406

Lambda expressions in Velox are interpreted rather than compiled as they are in
Presto Java, which costs us performance in places like `reduce`.  By
rewriting some common `reduce` patterns, we improve performance
significantly (more than 45x on one query).

Also fix a bug in `TryExpr` where we did not check that the result vector has
enough space for the result nulls.

bypass-github-export-checks

Reviewed By: oerling

Differential Revision: D59406443

fbshipit-source-id: db47dcc3ab29b5a61fca83054cf60f6897b3f7a0
  • Loading branch information
Yuhta authored and facebook-github-bot committed Jul 9, 2024
1 parent 3d19d04 commit 19be763
Show file tree
Hide file tree
Showing 9 changed files with 378 additions and 41 deletions.
4 changes: 3 additions & 1 deletion velox/expression/TryExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ void TryExpr::nullOutErrors(
// Wrap in dictionary indices all pointing to index 0.
auto indices = allocateIndices(size, context.pool());
result = BaseVector::wrapInDictionary(nulls, indices, size, result);
} else if (result.unique() && result->isNullsWritable()) {
} else if (
result.unique() && result->isNullsWritable() &&
result->size() >= rows.end()) {
auto* rawNulls = result->mutableRawNulls();
rows.applyToSelected([&](auto row) {
if (errors->hasErrorAt(row)) {
Expand Down
24 changes: 24 additions & 0 deletions velox/expression/tests/ExprCompilerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,30 @@ TEST_F(ExprCompilerTest, rewrites) {
makeTypedExpr("c1 + 5", rowType),
},
execCtx_.get()));

auto exprSet = compile(makeTypedExpr(
"reduce(c0, 1, (s, x) -> s + x * 2, s -> s)",
ROW({"c0"}, {ARRAY(BIGINT())})));
ASSERT_EQ(exprSet->size(), 1);
ASSERT_EQ(
exprSet->expr(0)->toString(),
"plus(1:BIGINT, array_sum_propagate_element_null(transform(c0, (x) -> multiply(x, 2:BIGINT))))");

exprSet = compile(makeTypedExpr(
"reduce(c0, 1, (s, x) -> (s + 2) - x, s -> s)",
ROW({"c0"}, {ARRAY(BIGINT())})));
ASSERT_EQ(exprSet->size(), 1);
ASSERT_EQ(
exprSet->expr(0)->toString(),
"plus(1:BIGINT, array_sum_propagate_element_null(transform(c0, (x) -> minus(2:BIGINT, x))))");

exprSet = compile(makeTypedExpr(
"reduce(c0, 1, (s, x) -> if(x % 2 = 0, s + 3, s), s -> s)",
ROW({"c0"}, {ARRAY(BIGINT())})));
ASSERT_EQ(exprSet->size(), 1);
ASSERT_EQ(
exprSet->expr(0)->toString(),
"plus(1:BIGINT, array_sum_propagate_element_null(transform(c0, (x) -> switch(eq(mod(x, 2:BIGINT), 0:BIGINT), 3:BIGINT, 0:BIGINT))))");
}

TEST_F(ExprCompilerTest, eliminateUnnecessaryCast) {
Expand Down
72 changes: 53 additions & 19 deletions velox/functions/prestosql/ArraySum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ namespace {
/// See documentation at https://prestodb.io/docs/current/functions/array.html
///

template <typename TInput, typename TOutput>
template <typename TInput, typename TOutput, bool kPropagateElementNull>
class ArraySumFunction : public exec::VectorFunction {
public:
template <bool mayHaveNulls, typename DataAtFunc, typename IsNullFunc>
TOutput applyCore(
vector_size_t row,
const ArrayVector* arrayVector,
DataAtFunc&& dataAtFunc,
IsNullFunc&& isNullFunc) const {
IsNullFunc&& isNullFunc,
bool& nullResult) const {
auto start = arrayVector->offsetAt(row);
auto end = start + arrayVector->sizeAt(row);
TOutput sum = 0;
Expand All @@ -53,14 +54,21 @@ class ArraySumFunction : public exec::VectorFunction {

for (auto i = start; i < end; i++) {
if constexpr (mayHaveNulls) {
bool isNull = isNullFunc(i);
if (!isNull) {
if (isNullFunc(i)) {
if constexpr (kPropagateElementNull) {
nullResult = true;
return 0;
}
} else {
addElement(sum, dataAtFunc(i));
}
} else {
addElement(sum, dataAtFunc(i));
}
}
if constexpr (mayHaveNulls && kPropagateElementNull) {
nullResult = false;
}
return sum;
}

Expand All @@ -73,9 +81,18 @@ class ArraySumFunction : public exec::VectorFunction {
DataAtFunc&& dataAtFunc,
IsNullFunc&& isNullFunc) const {
context.applyToSelectedNoThrow(rows, [&](auto row) {
resultValues->set(
row,
applyCore<mayHaveNulls>(row, arrayVector, dataAtFunc, isNullFunc));
bool nullResult;
auto sum = applyCore<mayHaveNulls>(
row, arrayVector, dataAtFunc, isNullFunc, nullResult);
if constexpr (mayHaveNulls && kPropagateElementNull) {
if (nullResult) {
resultValues->setNull(row, true);
} else {
resultValues->set(row, sum);
}
} else {
resultValues->set(row, sum);
}
});
}

Expand Down Expand Up @@ -132,19 +149,22 @@ class ArraySumFunction : public exec::VectorFunction {
exec::LocalDecodedVector elements(context, *elementsVector, elementsRows);

TOutput sum;
bool nullResult = false;
try {
if (elementsVector->mayHaveNulls()) {
sum = applyCore<true>(
arrayRow,
arrayVector,
[&](auto index) { return elements->valueAt<TInput>(index); },
[&](auto index) { return elements->isNullAt(index); });
[&](auto index) { return elements->isNullAt(index); },
nullResult);
} else {
sum = applyCore<false>(
arrayRow,
arrayVector,
[&](auto index) { return elements->valueAt<TInput>(index); },
[&](auto index) { return elements->isNullAt(index); });
[&](auto index) { return elements->isNullAt(index); },
nullResult);
}
} catch (...) {
context.setErrors(rows, std::current_exception());
Expand All @@ -154,7 +174,7 @@ class ArraySumFunction : public exec::VectorFunction {
std::make_shared<ConstantVector<TOutput>>(
context.pool(),
rows.end(),
false /*isNull*/,
nullResult,
outputType,
std::move(sum)),
rows,
Expand Down Expand Up @@ -201,6 +221,7 @@ class ArraySumFunction : public exec::VectorFunction {
};

// Create function.
template <bool kPropagateElementNull>
std::shared_ptr<exec::VectorFunction> create(
const std::string& /* name */,
const std::vector<exec::VectorFunctionArg>& inputArgs,
Expand All @@ -211,30 +232,38 @@ std::shared_ptr<exec::VectorFunction> create(
case TypeKind::TINYINT: {
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::TINYINT>::NativeType,
int64_t>>();
int64_t,
kPropagateElementNull>>();
}
case TypeKind::SMALLINT: {
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::SMALLINT>::NativeType,
int64_t>>();
int64_t,
kPropagateElementNull>>();
}
case TypeKind::INTEGER: {
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::INTEGER>::NativeType,
int64_t>>();
int64_t,
kPropagateElementNull>>();
}
case TypeKind::BIGINT: {
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::BIGINT>::NativeType,
int64_t>>();
int64_t,
kPropagateElementNull>>();
}
case TypeKind::REAL: {
return std::make_shared<
ArraySumFunction<TypeTraits<TypeKind::REAL>::NativeType, double>>();
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::REAL>::NativeType,
double,
kPropagateElementNull>>();
}
case TypeKind::DOUBLE: {
return std::make_shared<
ArraySumFunction<TypeTraits<TypeKind::DOUBLE>::NativeType, double>>();
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::DOUBLE>::NativeType,
double,
kPropagateElementNull>>();
}
default: {
VELOX_FAIL("Unsupported Type")
Expand Down Expand Up @@ -266,6 +295,11 @@ std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
} // namespace

// Register function.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(udf_array_sum, signatures(), create);
void registerVectorFunction_udf_array_sum(const std::string& name) {
facebook::velox::exec::registerStatefulVectorFunction(
name, signatures(), create<false>);
facebook::velox::exec::registerStatefulVectorFunction(
name + "_propagate_element_null", signatures(), create<true>);
}

} // namespace facebook::velox::functions
Loading

0 comments on commit 19be763

Please sign in to comment.