Skip to content

Commit

Permalink
Add support for custom comparison in Presto's contains UDF (#11227)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #11227

Update Presto's contains UDF to work with types that provide custom comparison.
We can reuse the implementation for complex types, since that just uses the compare
function provided by the Vector. With
#11022 this just
invokes the Type's custom implementation.

Reviewed By: xiaoxmeng

Differential Revision: D64211715

fbshipit-source-id: 59bf87a0e586be92455f91e502c9fb90919ebb99
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Oct 11, 2024
1 parent 903ae35 commit c434ed8
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 10 deletions.
35 changes: 25 additions & 10 deletions velox/functions/prestosql/ArrayContains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,16 +237,31 @@ class ArrayContainsFunction : public exec::VectorFunction {

exec::LocalDecodedVector searchHolder(context, *searchVector, rows);

VELOX_DYNAMIC_TYPE_DISPATCH(
applyTyped,
searchVector->typeKind(),
rows,
*arrayHolder.get(),
*elementsHolder.get(),
*searchHolder.get(),
context,
*flatResult,
throwOnNestedNull_);
if (searchVector->type()->providesCustomComparison()) {
// We use applyComplexType for types that provide custom comparison
// operators because the main difference between applyComplexType and
// applyTyped is that applyComplexType calls the Vector's equalValueAt
// method, which calls the Types custom comparison operator internally.
applyComplexType(
rows,
*arrayHolder.get(),
*elementsHolder.get(),
*searchHolder.get(),
context,
*flatResult,
throwOnNestedNull_);
} else {
VELOX_DYNAMIC_TYPE_DISPATCH(
applyTyped,
searchVector->typeKind(),
rows,
*arrayHolder.get(),
*elementsHolder.get(),
*searchHolder.get(),
context,
*flatResult,
throwOnNestedNull_);
}
}

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
Expand Down
93 changes: 93 additions & 0 deletions velox/functions/prestosql/tests/ArrayContainsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

using namespace facebook::velox;
using namespace facebook::velox::test;
Expand Down Expand Up @@ -481,4 +482,96 @@ TEST_F(ArrayContainsTest, floatNaNs) {
testFloatingPointNaNs<float>();
testFloatingPointNaNs<double>();
}

TEST_F(ArrayContainsTest, timestampWithTimeZone) {
auto arrayVector = makeArrayVector(
{0, 4, 7, 7, 13, 15},
makeNullableFlatVector<int64_t>(
{pack(1, 1),
pack(2, 2),
pack(3, 3),
pack(4, 4),
pack(3, 5),
pack(4, 6),
pack(5, 7),
pack(5, 8),
pack(6, 9),
std::nullopt,
pack(7, 10),
pack(8, 11),
pack(9, 12),
pack(7, 13),
std::nullopt,
pack(10, 14),
pack(9, 15),
pack(8, 16),
pack(7, 17)},
TIMESTAMP_WITH_TIME_ZONE()));

const auto testContains =
[&](std::optional<int64_t> needle,
const std::vector<std::optional<bool>>& expected) {
const auto searchVector = makeConstant(
needle, arrayVector->size(), TIMESTAMP_WITH_TIME_ZONE());

testContainsGeneric(arrayVector, searchVector, expected);
};

testContains(
pack(1, 1), {true, false, false, std::nullopt, std::nullopt, false});
testContains(
pack(3, 3), {true, true, false, std::nullopt, std::nullopt, false});
testContains(pack(5, 1), {false, true, false, true, std::nullopt, false});
testContains(pack(7, 2), {false, false, false, true, true, true});
testContains(
pack(-2, 1), {false, false, false, std::nullopt, std::nullopt, false});
testContains(
std::nullopt,
{std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt});

// Test wrapped in a complex value.
arrayVector = makeArrayVector(
{0, 4, 7, 7, 12, 13},
makeRowVector({makeNullableFlatVector<int64_t>(
{pack(1, 1),
pack(2, 2),
pack(3, 3),
pack(4, 4),
pack(3, 5),
pack(4, 6),
pack(5, 7),
pack(5, 8),
pack(6, 9),
pack(7, 10),
pack(8, 11),
pack(9, 12),
pack(7, 13),
pack(10, 14),
pack(9, 15),
pack(8, 16),
pack(7, 17)},
TIMESTAMP_WITH_TIME_ZONE())}));

const auto testContainsRow =
[&](int64_t needle, const std::vector<std::optional<bool>>& expected) {
const auto searchVector = BaseVector::wrapInConstant(
arrayVector->size(),
0,
makeRowVector({makeFlatVector(
std::vector<int64_t>{needle}, TIMESTAMP_WITH_TIME_ZONE())}));

testContainsGeneric(arrayVector, searchVector, expected);
};

testContainsRow(pack(1, 1), {true, false, false, false, false, false});
testContainsRow(pack(3, 3), {true, true, false, false, false, false});
testContainsRow(pack(5, 1), {false, true, false, true, false, false});
testContainsRow(pack(7, 2), {false, false, false, true, true, true});
testContainsRow(pack(-2, 1), {false, false, false, false, false, false});
}
} // namespace

0 comments on commit c434ed8

Please sign in to comment.