Skip to content

Commit

Permalink
Add support for custom comparison in Presto's arrays intersect/except…
Browse files Browse the repository at this point in the history
…/overlap UDFs (#11164)

Summary:
Pull Request resolved: #11164

Update Presto's array_intersect, array_except, and array_overalp UDFs to work with
types that provide custom comparison.  We can reuse the implementation of ValueSet
for complex types, since that just uses the compare and hash functions provided by
the Vector.  With #11022 these just
invoke the Type's custom implementations of these functions.

Reviewed By: xiaoxmeng

Differential Revision: D63851047

fbshipit-source-id: 9633c8b40d9cfe13964cb5e0675a30b95d378bbc
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Oct 10, 2024
1 parent 01ed92c commit a6899b7
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 28 deletions.
89 changes: 61 additions & 28 deletions velox/functions/prestosql/ArrayIntersectExcept.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,27 @@ struct SetWithNull {
bool hasNull{false};
};

struct ComplexTypeEntry {
// This class is used as the entry in a set when the native type cannot be used
// directly. In particular, for complex types and custom types that provide
// custom comparison operators.
struct WrappedVectorEntry {
const uint64_t hash;
const BaseVector* baseVector;
const vector_size_t index;
};

template <>
struct SetWithNull<ComplexTypeEntry> {
struct SetWithNull<WrappedVectorEntry> {
struct Hash {
size_t operator()(const ComplexTypeEntry& entry) const {
size_t operator()(const WrappedVectorEntry& entry) const {
return entry.hash;
}
};

struct EqualTo {
bool operator()(const ComplexTypeEntry& left, const ComplexTypeEntry& right)
const {
bool operator()(
const WrappedVectorEntry& left,
const WrappedVectorEntry& right) const {
return left.baseVector
->equalValueAt(
right.baseVector,
Expand All @@ -77,7 +81,7 @@ struct SetWithNull<ComplexTypeEntry> {
}
};

folly::F14FastSet<ComplexTypeEntry, Hash, EqualTo> set;
folly::F14FastSet<WrappedVectorEntry, Hash, EqualTo> set;
bool hasNull{false};

SetWithNull(vector_size_t initialSetSize = kInitialSetSize) {
Expand All @@ -88,15 +92,15 @@ struct SetWithNull<ComplexTypeEntry> {
const auto vector = decodedElements->base();
const auto index = decodedElements->index(offset);
const uint64_t hash = vector->hashValueAt(index);
return set.insert(ComplexTypeEntry{hash, vector, index}).second;
return set.insert(WrappedVectorEntry{hash, vector, index}).second;
}

size_t count(const DecodedVector* decodedElements, vector_size_t offset)
const {
const auto vector = decodedElements->base();
const auto index = decodedElements->index(offset);
const uint64_t hash = vector->hashValueAt(index);
return set.count(ComplexTypeEntry{hash, vector, index});
return set.count(WrappedVectorEntry{hash, vector, index});
}

void reset() {
Expand Down Expand Up @@ -460,17 +464,9 @@ SetWithNull<T> validateConstantVectorAndGenerateSet(
return constantSet;
}

template <bool isIntersect, TypeKind kind>
template <bool isIntersect, typename SetEntryT>
std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
const std::vector<exec::VectorFunctionArg>& inputArgs) {
using T = std::conditional_t<
TypeTraits<kind>::isPrimitiveType,
typename TypeTraits<kind>::NativeType,
ComplexTypeEntry>;

VELOX_CHECK_EQ(inputArgs.size(), 2);
BaseVector* rhs = inputArgs[1].constantValue.get();

const BaseVector* rhs) {
// We don't optimize the case where lhs is a constant expression for
// array_intersect() because that would make this function non-deterministic.
// For example, a constant lhs would mean the constantSet is created based on
Expand All @@ -480,10 +476,31 @@ std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
//
// If rhs is a constant value:
if (rhs != nullptr) {
return std::make_shared<ArrayIntersectExceptFunction<isIntersect, T>>(
validateConstantVectorAndGenerateSet<T>(rhs));
return std::make_shared<
ArrayIntersectExceptFunction<isIntersect, SetEntryT>>(
validateConstantVectorAndGenerateSet<SetEntryT>(rhs));
} else {
return std::make_shared<
ArrayIntersectExceptFunction<isIntersect, SetEntryT>>();
}
}

template <bool isIntersect, TypeKind kind>
std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
const std::vector<exec::VectorFunctionArg>& inputArgs,
const TypePtr& elementType) {
VELOX_CHECK_EQ(inputArgs.size(), 2);
const BaseVector* rhs = inputArgs[1].constantValue.get();

if (elementType->providesCustomComparison()) {
return createTypedArraysIntersectExcept<isIntersect, WrappedVectorEntry>(
rhs);
} else {
return std::make_shared<ArrayIntersectExceptFunction<isIntersect, T>>();
using T = std::conditional_t<
TypeTraits<kind>::isPrimitiveType,
typename TypeTraits<kind>::NativeType,
WrappedVectorEntry>;
return createTypedArraysIntersectExcept<isIntersect, T>(rhs);
}
}

Expand All @@ -498,7 +515,8 @@ std::shared_ptr<exec::VectorFunction> createArrayIntersect(
createTypedArraysIntersectExcept,
/* isIntersect */ true,
elementType->kind(),
inputArgs);
inputArgs,
elementType);
}

std::shared_ptr<exec::VectorFunction> createArrayExcept(
Expand All @@ -512,7 +530,8 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
createTypedArraysIntersectExcept,
/* isIntersect */ false,
elementType->kind(),
inputArgs);
inputArgs,
elementType);
}

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures(
Expand All @@ -533,19 +552,33 @@ const std::shared_ptr<exec::VectorFunction> createTypedArraysOverlap(
VELOX_CHECK_EQ(inputArgs.size(), 2);
auto left = inputArgs[0].constantValue.get();
auto right = inputArgs[1].constantValue.get();
bool usesCustomComparison =
inputArgs[0].type->childAt(0)->providesCustomComparison();
using T = std::conditional_t<
TypeTraits<kind>::isPrimitiveType,
typename TypeTraits<kind>::NativeType,
ComplexTypeEntry>;
WrappedVectorEntry>;

if (left == nullptr && right == nullptr) {
return std::make_shared<ArraysOverlapFunction<T>>();
if (usesCustomComparison) {
return std::make_shared<ArraysOverlapFunction<WrappedVectorEntry>>();
} else {
return std::make_shared<ArraysOverlapFunction<T>>();
}
}
auto isLeftConstant = (left != nullptr);
auto baseVector = isLeftConstant ? left : right;
auto constantSet = validateConstantVectorAndGenerateSet<T>(baseVector);
return std::make_shared<ArraysOverlapFunction<T>>(
std::move(constantSet), isLeftConstant);

if (usesCustomComparison) {
auto constantSet =
validateConstantVectorAndGenerateSet<WrappedVectorEntry>(baseVector);
return std::make_shared<ArraysOverlapFunction<WrappedVectorEntry>>(
std::move(constantSet), isLeftConstant);
} else {
auto constantSet = validateConstantVectorAndGenerateSet<T>(baseVector);
return std::make_shared<ArraysOverlapFunction<T>>(
std::move(constantSet), isLeftConstant);
}
}

std::shared_ptr<exec::VectorFunction> createArraysOverlapFunction(
Expand Down
101 changes: 101 additions & 0 deletions velox/functions/prestosql/tests/ArrayExceptTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <optional>
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/vector/tests/TestingDictionaryArrayElementsFunction.h"

using namespace facebook::velox;
Expand Down Expand Up @@ -394,3 +395,103 @@ TEST_F(ArrayExceptTest, dictionaryEncodedElementsInConstant) {
"array_except(c0, testing_dictionary_array_elements(ARRAY [0, 1, 3, 2, 2, 3, 2]))",
{array});
}

TEST_F(ArrayExceptTest, timestampWithTimezone) {
auto testArrayExcept =
[this](
const std::vector<std::optional<int64_t>>& inputArray1,
const std::vector<std::optional<int64_t>>& inputArray2,
const std::vector<std::optional<int64_t>>& expectedArrayForward,
const std::vector<std::optional<int64_t>>& expectedArrayBackward) {
const auto input1 = makeArrayVector(
{0},
makeNullableFlatVector(inputArray1, TIMESTAMP_WITH_TIME_ZONE()));
const auto input2 = makeArrayVector(
{0},
makeNullableFlatVector(inputArray2, TIMESTAMP_WITH_TIME_ZONE()));
const auto expectedForward = makeArrayVector(
{0},
makeNullableFlatVector(
expectedArrayForward, TIMESTAMP_WITH_TIME_ZONE()));
const auto expectedBackward = makeArrayVector(
{0},
makeNullableFlatVector(
expectedArrayBackward, TIMESTAMP_WITH_TIME_ZONE()));

testExpr(expectedForward, "array_except(c0, c1)", {input1, input2});
testExpr(expectedBackward, "array_except(c1, c0)", {input1, input2});
};

testArrayExcept(
{pack(1, 0),
pack(-2, 1),
pack(3, 2),
std::nullopt,
pack(4, 3),
pack(5, 4),
pack(6, 5),
std::nullopt},
{pack(1, 10), pack(-2, 11), pack(4, 12)},
{pack(3, 2), std::nullopt, pack(5, 4), pack(6, 5)},
{});
testArrayExcept(
{pack(1, 0), pack(2, 1), pack(-2, 2), pack(1, 3)},
{pack(1, 10), pack(-2, 11), pack(4, 12)},
{pack(2, 1)},
{pack(4, 12)});
testArrayExcept(
{pack(3, 0), pack(8, 1), std::nullopt},
{pack(1, 10), pack(-2, 11), pack(4, 12)},
{pack(3, 0), pack(8, 1), std::nullopt},
{pack(1, 10), pack(-2, 11), pack(4, 12)});
testArrayExcept(
{pack(1, 0),
pack(1, 1),
pack(-2, 2),
pack(-2, 3),
pack(-2, 4),
pack(4, 5),
pack(8, 6)},
{pack(1, 10), pack(-2, 11), pack(4, 12)},
{pack(8, 6)},
{});
testArrayExcept(
{pack(1, 0),
pack(-2, 1),
pack(3, 2),
std::nullopt,
pack(4, 3),
pack(5, 4),
pack(6, 5),
std::nullopt},
{pack(10, 10), pack(-24, 11), pack(43, 12)},
{pack(1, 0),
pack(-2, 1),
pack(3, 2),
std::nullopt,
pack(4, 3),
pack(5, 4),
pack(6, 5)},
{pack(10, 10), pack(-24, 11), pack(43, 12)});
testArrayExcept(
{pack(1, 0), pack(2, 1), pack(-2, 2), pack(1, 3)},
{std::nullopt, pack(-2, 10), pack(2, 11)},
{pack(1, 0)},
{std::nullopt});
testArrayExcept(
{pack(3, 0), pack(8, 1), std::nullopt},
{std::nullopt, std::nullopt, std::nullopt},
{pack(3, 0), pack(8, 1)},
{});
testArrayExcept(
{pack(1, 0),
pack(1, 1),
pack(-2, 2),
pack(-2, 3),
pack(-2, 4),
pack(4, 5),
pack(8, 6)},
{pack(8, 10), pack(1, 11), pack(8, 12), pack(1, 13)},
{pack(-2, 2), pack(4, 5)},
{});
}
95 changes: 95 additions & 0 deletions velox/functions/prestosql/tests/ArrayIntersectTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <optional>
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/vector/tests/TestingDictionaryArrayElementsFunction.h"

using namespace facebook::velox;
Expand Down Expand Up @@ -388,3 +389,97 @@ TEST_F(ArrayIntersectTest, dictionaryEncodedElementsInConstant) {
"array_intersect(c0, testing_dictionary_array_elements(ARRAY [2, 2, 3, 1, 2, 2]))",
{array});
}

TEST_F(ArrayIntersectTest, timestampWithTimezone) {
auto testArrayIntersect =
[this](
const std::vector<std::optional<int64_t>>& inputArray1,
const std::vector<std::optional<int64_t>>& inputArray2,
const std::vector<std::optional<int64_t>>& expectedArrayForward,
const std::vector<std::optional<int64_t>>& expectedArrayBackward) {
const auto input1 = makeArrayVector(
{0},
makeNullableFlatVector(inputArray1, TIMESTAMP_WITH_TIME_ZONE()));
const auto input2 = makeArrayVector(
{0},
makeNullableFlatVector(inputArray2, TIMESTAMP_WITH_TIME_ZONE()));
const auto expectedForward = makeArrayVector(
{0},
makeNullableFlatVector(
expectedArrayForward, TIMESTAMP_WITH_TIME_ZONE()));
const auto expectedBackward = makeArrayVector(
{0},
makeNullableFlatVector(
expectedArrayBackward, TIMESTAMP_WITH_TIME_ZONE()));

testExpr(expectedForward, "array_intersect(c0, c1)", {input1, input2});
testExpr(expectedBackward, "array_intersect(c1, c0)", {input1, input2});
};

testArrayIntersect(
{pack(1, 0),
pack(-2, 1),
pack(3, 2),
std::nullopt,
pack(4, 3),
pack(5, 4),
pack(6, 5),
std::nullopt},
{pack(1, 10), pack(-2, 11), pack(4, 12)},
{pack(1, 0), pack(-2, 1), pack(4, 3)},
{pack(1, 10), pack(-2, 11), pack(4, 12)});
testArrayIntersect(
{pack(1, 0), pack(2, 1), pack(-2, 2), pack(1, 3)},
{pack(1, 10), pack(-2, 11), pack(4, 12)},
{pack(1, 0), pack(-2, 2)},
{pack(1, 10), pack(-2, 11)});
testArrayIntersect(
{pack(3, 0), pack(8, 1), std::nullopt},
{pack(1, 10), pack(-2, 11), pack(4, 12)},
{},
{});
testArrayIntersect(
{pack(1, 0),
pack(1, 1),
pack(-2, 2),
pack(-2, 3),
pack(-2, 4),
pack(4, 5),
pack(8, 6)},
{pack(1, 10), pack(-2, 11), pack(4, 12)},
{pack(1, 0), pack(-2, 2), pack(4, 5)},
{pack(1, 10), pack(-2, 11), pack(4, 12)});
testArrayIntersect(
{pack(1, 0),
pack(-2, 1),
pack(3, 2),
std::nullopt,
pack(4, 3),
pack(5, 4),
pack(6, 5),
std::nullopt},
{pack(10, 10), pack(-24, 11), pack(43, 12)},
{},
{});
testArrayIntersect(
{pack(1, 0), pack(2, 1), pack(-2, 2), pack(1, 3)},
{std::nullopt, pack(-2, 10), pack(2, 11)},
{pack(2, 1), pack(-2, 2)},
{pack(-2, 10), pack(2, 11)});
testArrayIntersect(
{pack(3, 0), pack(8, 1), std::nullopt},
{std::nullopt, std::nullopt, std::nullopt},
{std::nullopt},
{std::nullopt});
testArrayIntersect(
{pack(1, 0),
pack(1, 1),
pack(-2, 2),
pack(-2, 3),
pack(-2, 4),
pack(4, 5),
pack(8, 6)},
{pack(8, 10), pack(1, 11), pack(8, 12), pack(1, 13)},
{pack(1, 0), pack(8, 6)},
{pack(8, 10), pack(1, 11)});
}
Loading

0 comments on commit a6899b7

Please sign in to comment.