// Patch summary (reviewer preamble; sits before the first `diff --git` header).
//
// Adds TypeKind::HUGEINT (int128_t) handling to min_by/max_by across four files:
//   * velox/exec/tests/utils/QueryAssertions.cpp
//       - duckValueAt: returns ::duckdb::Value::DECIMAL only when the vector type
//         isLongDecimal(); otherwise returns ::duckdb::Value::HUGEINT.
//       - variantAt: reads the value through GetValue<::duckdb::hugeint_t>()
//         instead of ::duckdb::HugeIntValue::Get.
//   * velox/functions/lib/aggregates/MinMaxByAggregatesBase.h
//       - both `create` dispatch switches gain a TypeKind::HUGEINT case.
//   * velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp
//       - MinMaxByNAccumulator picks its heap allocator type with
//         std::conditional_t (AlignedStlAllocator vs StlAllocator).
//       - LongDecimalMinByNAggregate / LongDecimalMaxByNAggregate override
//         accumulatorAlignmentSize() to return sizeof(int128_t).
//       - createNArg switches gain HUGEINT cases; "hugeint" is added to
//         supportedCompareTypes and a DECIMAL(a_precision, a_scale) comparison
//         signature is registered.
//   * velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp
//       - HUGEINT added to kSupportedTypes, the dispatch macros, SetUp, and a
//         buildDataVector specialization that uses HugeInt::build.
//
// Review fix applied in this copy: in MinMaxByAggregatesBase.h the HUGEINT case
// of the value-type switch now forwards `throwOnNestedNulls`, matching the
// BIGINT and REAL cases around it. Previously the flag was dropped for HUGEINT
// values, so the callee would have used its own default instead of the caller's
// setting. NOTE(review): the callee presumably declares a default for that
// parameter (the three-argument call was accepted) — confirm in the full header.
//
// NOTE(review): this copy appears to have lost its original line breaks and most
// angle-bracketed template argument lists (e.g. `std::make_unique>(resultType)`).
// They are left as found; restore from the original patch before applying.
diff --git a/velox/exec/tests/utils/QueryAssertions.cpp b/velox/exec/tests/utils/QueryAssertions.cpp index c271b0d4a49f..b55da96d8033 100644 --- a/velox/exec/tests/utils/QueryAssertions.cpp +++ b/velox/exec/tests/utils/QueryAssertions.cpp @@ -132,12 +132,16 @@ ::duckdb::Value duckValueAt( const VectorPtr& vector, vector_size_t index) { using T = typename KindToFlatVector::WrapperType; - auto type = vector->type()->asLongDecimal(); auto val = vector->as>()->valueAt(index); auto duckVal = ::duckdb::hugeint_t(); duckVal.lower = (val << 64) >> 64; duckVal.upper = (val >> 64); - return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale()); + if (vector->type()->isLongDecimal()) { + auto type = vector->type()->asLongDecimal(); + return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale()); + } + // Flat vector is HUGEINT type and not the logical decimal type. + return ::duckdb::Value::HUGEINT(duckVal); } template <> @@ -253,8 +257,9 @@ velox::variant variantAt( ::duckdb::DataChunk* dataChunk, int32_t row, int32_t column) { - auto hugeInt = ::duckdb::HugeIntValue::Get(dataChunk->GetValue(column, row)); - return velox::variant(HugeInt::build(hugeInt.upper, hugeInt.lower)); + auto unscaledValue = + dataChunk->GetValue(column, row).GetValue<::duckdb::hugeint_t>(); + return variant(HugeInt::build(unscaledValue.upper, unscaledValue.lower)); } template <> diff --git a/velox/functions/lib/aggregates/MinMaxByAggregatesBase.h b/velox/functions/lib/aggregates/MinMaxByAggregatesBase.h index 82cf11507e6a..7a98f547c3c7 100644 --- a/velox/functions/lib/aggregates/MinMaxByAggregatesBase.h +++ b/velox/functions/lib/aggregates/MinMaxByAggregatesBase.h @@ -581,6 +581,9 @@ std::unique_ptr create( case TypeKind::DOUBLE: return std::make_unique>( resultType); + case TypeKind::HUGEINT: + return std::make_unique>( + resultType); case TypeKind::VARBINARY: [[fallthrough]]; case TypeKind::VARCHAR: @@ -635,6 +638,9 @@ std::unique_ptr create( case TypeKind::BIGINT: return 
create( resultType, compareType, errorMessage, throwOnNestedNulls); + case TypeKind::HUGEINT: + return create( + resultType, compareType, errorMessage, throwOnNestedNulls); case TypeKind::REAL: return create( resultType, compareType, errorMessage, throwOnNestedNulls); diff --git a/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp index 370610bb02aa..a26380f3e412 100644 --- a/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp @@ -90,11 +90,15 @@ struct MinMaxByNAccumulator { int64_t n{0}; using Pair = std::pair>; - using Heap = std::vector>; + using Allocator = std::conditional_t< + std::is_same_v || std::is_same_v, + AlignedStlAllocator, + StlAllocator>; + using Heap = std::vector; Heap heapValues; explicit MinMaxByNAccumulator(HashStringAllocator* allocator) - : heapValues{StlAllocator(allocator)} {} + : heapValues{Allocator(allocator)} {} int64_t getN() const { return n; @@ -990,6 +994,16 @@ class MinByNAggregate : public MinMaxByNAggregate> { : MinMaxByNAggregate>(resultType) {} }; +template +class LongDecimalMinByNAggregate : public MinByNAggregate { + public: + explicit LongDecimalMinByNAggregate(TypePtr resultType) + : MinByNAggregate(resultType) {} + int32_t accumulatorAlignmentSize() const override { + return static_cast(sizeof(int128_t)); + } +}; + template class MinByNAggregate : public MinMaxByNAggregate< @@ -1011,6 +1025,16 @@ class MaxByNAggregate : public MinMaxByNAggregate> { : MinMaxByNAggregate>(resultType) {} }; +template +class LongDecimalMaxByNAggregate : public MaxByNAggregate { + public: + explicit LongDecimalMaxByNAggregate(TypePtr resultType) + : MaxByNAggregate(resultType) {} + int32_t accumulatorAlignmentSize() const override { + return static_cast(sizeof(int128_t)); + } +}; + template class MaxByNAggregate : public MinMaxByNAggregate< @@ -1041,6 +1065,13 @@ std::unique_ptr createNArg( return 
std::make_unique>(resultType); case TypeKind::BIGINT: return std::make_unique>(resultType); + case TypeKind::HUGEINT: + if constexpr (std::is_same_v< + NAggregate, + MinByNAggregate>) { + return std::make_unique>(resultType); + } + return std::make_unique>(resultType); case TypeKind::REAL: return std::make_unique>(resultType); case TypeKind::DOUBLE: @@ -1077,6 +1108,9 @@ std::unique_ptr createNArg( case TypeKind::BIGINT: return createNArg( resultType, compareType, errorMessage); + case TypeKind::HUGEINT: + return createNArg( + resultType, compareType, errorMessage); case TypeKind::REAL: return createNArg( resultType, compareType, errorMessage); @@ -1139,12 +1173,14 @@ exec::AggregateRegistrationResult registerMinMaxBy( .argumentType("V") .argumentType("C") .build()); + // Add signatures for 3-arg version of min_by/max_by. const std::vector supportedCompareTypes = { "boolean", "tinyint", "smallint", "integer", "bigint", + "hugeint", "real", "double", "varchar", @@ -1164,7 +1200,18 @@ exec::AggregateRegistrationResult registerMinMaxBy( .argumentType("bigint") .build()); } - + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .typeVariable("V") + .returnType("array(V)") + .intermediateType( + "row(bigint,array(DECIMAL(a_precision, a_scale)),array(V))") + .argumentType("V") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("bigint") + .build()); return exec::registerAggregateFunction( name, std::move(signatures), diff --git a/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp b/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp index 2f5cd38b0fb1..b628c8914993 100644 --- a/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp @@ -45,7 +45,8 @@ const std::vector kSupportedTypes = { TypeKind::REAL, TypeKind::DOUBLE, TypeKind::VARCHAR, - 
TypeKind::TIMESTAMP}; + TypeKind::TIMESTAMP, + TypeKind::HUGEINT}; std::vector getTestParams() { std::vector params; @@ -75,6 +76,9 @@ std::vector getTestParams() { case TypeKind::BIGINT: \ testFunc(); \ break; \ + case TypeKind::HUGEINT: \ + testFunc(); \ + break; \ case TypeKind::REAL: \ testFunc(); \ break; \ @@ -111,6 +115,9 @@ std::vector getTestParams() { case TypeKind::BIGINT: \ EXECUTE_TEST_BY_VALUE_TYPE(testFunc, int64_t); \ break; \ + case TypeKind::HUGEINT: \ + EXECUTE_TEST_BY_VALUE_TYPE(testFunc, int128_t); \ + break; \ case TypeKind::REAL: \ EXECUTE_TEST_BY_VALUE_TYPE(testFunc, float); \ break; \ @@ -203,6 +210,21 @@ class MinMaxByAggregationTestBase : public AggregationTestBase { std::vector rowVectors_; }; +template <> +FlatVectorPtr MinMaxByAggregationTestBase::buildDataVector( + vector_size_t size, + folly::Range values) { + if (values.empty()) { + return makeFlatVector( + size, [](auto row) { return HugeInt::build(row - 3, row - 3); }); + } else { + VELOX_CHECK_EQ(values.size(), size); + return makeFlatVector(size, [&](auto row) { + return HugeInt::build(values[row], values[row]); + }); + } +} + // Build a flat vector with StringView. The value in the returned flat vector // is in ascending order. template <> @@ -274,6 +296,8 @@ VectorPtr MinMaxByAggregationTestBase::buildDataVector( return buildDataVector(size, values); case TypeKind::DOUBLE: return buildDataVector(size, values); + case TypeKind::HUGEINT: + return buildDataVector(size, values); case TypeKind::VARCHAR: return buildDataVector(size, values); case TypeKind::TIMESTAMP: @@ -327,6 +351,9 @@ void MinMaxByAggregationTestBase::SetUp() { case TypeKind::BIGINT: dataVectorsByType_.emplace(type, buildDataVector(numValues_)); break; + case TypeKind::HUGEINT: + dataVectorsByType_.emplace(type, buildDataVector(numValues_)); + break; case TypeKind::REAL: dataVectorsByType_.emplace(type, buildDataVector(numValues_)); break;