From 0f1bcd4a7b0cedb3652c0aae8517fbace7e752ec Mon Sep 17 00:00:00 2001 From: Karteekmurthys Date: Wed, 7 Feb 2024 21:30:26 -0800 Subject: [PATCH] Add decimal support for min_by and max_by functions --- velox/exec/tests/utils/QueryAssertions.cpp | 13 +- .../lib/aggregates/MinMaxByAggregatesBase.h | 6 + .../aggregates/MinMaxByAggregates.cpp | 53 ++++- .../tests/MinMaxByAggregationTest.cpp | 200 ++++++++++++++++-- velox/type/Type.cpp | 2 + velox/type/Variant.cpp | 8 +- velox/type/Variant.h | 3 +- 7 files changed, 256 insertions(+), 29 deletions(-) diff --git a/velox/exec/tests/utils/QueryAssertions.cpp b/velox/exec/tests/utils/QueryAssertions.cpp index 0a4ebdf3ccfd..f84be11f87bc 100644 --- a/velox/exec/tests/utils/QueryAssertions.cpp +++ b/velox/exec/tests/utils/QueryAssertions.cpp @@ -131,12 +131,16 @@ ::duckdb::Value duckValueAt( const VectorPtr& vector, vector_size_t index) { using T = typename KindToFlatVector::WrapperType; - auto type = vector->type()->asLongDecimal(); auto val = vector->as>()->valueAt(index); auto duckVal = ::duckdb::hugeint_t(); duckVal.lower = (val << 64) >> 64; duckVal.upper = (val >> 64); - return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale()); + if (vector->type()->isLongDecimal()) { + auto type = vector->type()->asLongDecimal(); + return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale()); + } + // Flat vector is HUGEINT type and not the logical decimal type. + return ::duckdb::Value::HUGEINT(duckVal); } template <> @@ -252,8 +256,9 @@ velox::variant variantAt( ::duckdb::DataChunk* dataChunk, int32_t row, int32_t column) { - auto hugeInt = ::duckdb::HugeIntValue::Get(dataChunk->GetValue(column, row)); - return velox::variant(HugeInt::build(hugeInt.upper, hugeInt.lower)); + auto unscaledValue = + dataChunk->GetValue(column, row).GetValue<::duckdb::hugeint_t>(); + return variant(HugeInt::build(unscaledValue.upper, unscaledValue.lower)); } template <> diff --git a/velox/functions/lib/aggregates/MinMaxByAggregatesBase.h b/velox/functions/lib/aggregates/MinMaxByAggregatesBase.h index 64e4c7cf6913..712ebf39384d 100644 --- a/velox/functions/lib/aggregates/MinMaxByAggregatesBase.h +++ b/velox/functions/lib/aggregates/MinMaxByAggregatesBase.h @@ -580,6 +580,9 @@ std::unique_ptr create( case TypeKind::DOUBLE: return std::make_unique>( resultType); + case TypeKind::HUGEINT: + return std::make_unique>( + resultType); case TypeKind::VARBINARY: [[fallthrough]]; case TypeKind::VARCHAR: @@ -634,6 +637,9 @@ std::unique_ptr create( case TypeKind::BIGINT: return create( resultType, compareType, errorMessage, throwOnNestedNulls); + case TypeKind::HUGEINT: + return create( + resultType, compareType, errorMessage); case TypeKind::REAL: return create( resultType, compareType, errorMessage, throwOnNestedNulls); diff --git a/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp index 877134ceff88..8a932ecde000 100644 --- a/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp @@ -89,12 +89,16 @@ struct MinMaxByNAccumulator { int64_t n{0}; using Pair = std::pair>; + using Allocator = std::conditional_t< + std::is_same_v || std::is_same_v, + AlignedStlAllocator, + StlAllocator>; using Queue = - std::priority_queue>, Compare>; + std::priority_queue, Compare>; Queue topPairs; explicit MinMaxByNAccumulator(HashStringAllocator* allocator) - : topPairs{Compare{}, StlAllocator(allocator)} {} + : topPairs{Compare{}, Allocator(allocator)} {} int64_t getN() const { return n; @@ -967,6 +971,16 @@ class MinByNAggregate : public MinMaxByNAggregate> { : MinMaxByNAggregate>(resultType) {} }; +template +class LongDecimalMinByNAggregate : public MinByNAggregate { + public: + explicit LongDecimalMinByNAggregate(TypePtr resultType) + : MinByNAggregate(resultType) {} + int32_t accumulatorAlignmentSize() const override { + return static_cast(sizeof(int128_t)); + } +}; + template class MinByNAggregate : public MinMaxByNAggregate< @@ -988,6 +1002,16 @@ class MaxByNAggregate : public MinMaxByNAggregate> { : MinMaxByNAggregate>(resultType) {} }; +template +class LongDecimalMaxByNAggregate : public MaxByNAggregate { + public: + explicit LongDecimalMaxByNAggregate(TypePtr resultType) + : MaxByNAggregate(resultType) {} + int32_t accumulatorAlignmentSize() const override { + return static_cast(sizeof(int128_t)); + } +}; + template class MaxByNAggregate : public MinMaxByNAggregate< @@ -1018,6 +1042,13 @@ std::unique_ptr createNArg( return std::make_unique>(resultType); case TypeKind::BIGINT: return std::make_unique>(resultType); + case TypeKind::HUGEINT: + if constexpr (std::is_same_v< + NAggregate, + MinByNAggregate>) { + return std::make_unique>(resultType); + } + return std::make_unique>(resultType); case TypeKind::REAL: return std::make_unique>(resultType); case TypeKind::DOUBLE: @@ -1054,6 +1085,9 @@ std::unique_ptr createNArg( case TypeKind::BIGINT: return createNArg( resultType, compareType, errorMessage); + case TypeKind::HUGEINT: + return createNArg( + resultType, compareType, errorMessage); case TypeKind::REAL: return createNArg( resultType, compareType, errorMessage); @@ -1113,12 +1147,14 @@ exec::AggregateRegistrationResult registerMinMaxBy(const std::string& name) { .argumentType("V") .argumentType("C") .build()); + // Add signatures for 3-arg version of min_by/max_by. const std::vector supportedCompareTypes = { "boolean", "tinyint", "smallint", "integer", "bigint", + "hugeint", "real", "double", "varchar", @@ -1138,7 +1174,18 @@ exec::AggregateRegistrationResult registerMinMaxBy(const std::string& name) { .argumentType("bigint") .build()); } - + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .typeVariable("V") + .returnType("array(V)") + .intermediateType( + "row(bigint,array(DECIMAL(a_precision, a_scale)),array(V))") + .argumentType("V") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("bigint") + .build()); return exec::registerAggregateFunction( name, std::move(signatures), diff --git a/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp b/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp index 5ac1386527e6..0a818c5ad7b3 100644 --- a/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp @@ -42,6 +42,7 @@ const std::unordered_set kSupportedTypes = { TypeKind::SMALLINT, TypeKind::INTEGER, TypeKind::BIGINT, + TypeKind::HUGEINT, TypeKind::REAL, TypeKind::DOUBLE, TypeKind::VARCHAR, @@ -75,6 +76,9 @@ std::vector getTestParams() { case TypeKind::BIGINT: \ testFunc(); \ break; \ + case TypeKind::HUGEINT: \ + testFunc(); \ + break; \ case TypeKind::REAL: \ testFunc(); \ break; \ @@ -111,6 +115,9 @@ std::vector getTestParams() { case TypeKind::BIGINT: \ EXECUTE_TEST_BY_VALUE_TYPE(testFunc, int64_t); \ break; \ + case TypeKind::HUGEINT: \ + EXECUTE_TEST_BY_VALUE_TYPE(testFunc, int128_t); \ + break; \ case TypeKind::REAL: \ EXECUTE_TEST_BY_VALUE_TYPE(testFunc, float); \ break; \ @@ -189,27 +196,31 @@ class MinMaxByAggregationTestBase : public AggregationTestBase { vector_size_t size, folly::Range values); - const RowTypePtr rowType_{ - ROW({"c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9"}, - { - BOOLEAN(), - TINYINT(), - SMALLINT(), - INTEGER(), - BIGINT(), - REAL(), - DOUBLE(), - VARCHAR(), - DATE(), - TIMESTAMP(), - })}; + RowTypePtr rowType_; + RowTypePtr rowTypeWithHugeInt_; // Specify the number of values in each typed data vector in // 'dataVectorsByType_'. const int numValues_; std::unordered_map dataVectorsByType_; std::vector rowVectors_; + std::vector rowVectorsWithHugeInt_; }; +template <> +FlatVectorPtr MinMaxByAggregationTestBase::buildDataVector( + vector_size_t size, + folly::Range values) { + if (values.empty()) { + return makeFlatVector( + size, [](auto row) { return HugeInt::build(row - 3, row - 3); }); + } else { + VELOX_CHECK_EQ(values.size(), size); + return makeFlatVector(size, [&](auto row) { + return HugeInt::build(values[row], values[row]); + }); + } +} + // Build a flat vector with StringView. The value in the returned flat vector // is in ascending order. template <> @@ -277,6 +288,8 @@ VectorPtr MinMaxByAggregationTestBase::buildDataVector( return buildDataVector(size, values); case TypeKind::BIGINT: return buildDataVector(size, values); + case TypeKind::HUGEINT: + return buildDataVector(size, values); case TypeKind::REAL: return buildDataVector(size, values); case TypeKind::DOUBLE: @@ -309,6 +322,29 @@ std::string asSql(bool value) { void MinMaxByAggregationTestBase::SetUp() { AggregationTestBase::SetUp(); AggregationTestBase::disallowInputShuffle(); + std::vector columnNames{ + "c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9"}; + std::vector typePtrs{ + BOOLEAN(), + TINYINT(), + SMALLINT(), + INTEGER(), + BIGINT(), + REAL(), + DOUBLE(), + VARCHAR(), + DATE(), + TIMESTAMP()}; + auto colNamesCopy(columnNames); + auto typePtrsCopy(typePtrs); + rowType_ = { + ROW(std::forward&&>(columnNames), + std::forward&&>(typePtrs))}; + colNamesCopy.push_back("c10"); + typePtrsCopy.push_back(HUGEINT()); + rowTypeWithHugeInt_ = { + ROW(std::forward&&>(colNamesCopy), + std::forward&&>(typePtrsCopy))}; for (const TypeKind type : kSupportedTypes) { switch (type) { @@ -327,6 +363,9 @@ void MinMaxByAggregationTestBase::SetUp() { case TypeKind::BIGINT: dataVectorsByType_.emplace(type, buildDataVector(numValues_)); break; + case TypeKind::HUGEINT: + dataVectorsByType_.emplace(type, buildDataVector(numValues_)); + break; case TypeKind::REAL: dataVectorsByType_.emplace(type, buildDataVector(numValues_)); break; @@ -348,6 +387,8 @@ void MinMaxByAggregationTestBase::SetUp() { ASSERT_EQ(dataVectorsByType_.size(), kSupportedTypes.size()); rowVectors_ = makeVectors(rowType_, 5, 10); createDuckDbTable(rowVectors_); + rowVectorsWithHugeInt_ = makeVectors(rowTypeWithHugeInt_, 5, 10); + createDuckDbTable(rowVectorsWithHugeInt_); }; class MinMaxByGlobalByAggregationTest @@ -361,7 +402,8 @@ class MinMaxByGlobalByAggregationTest const std::vector& vectors, const std::string& aggName, const std::string& valueColumnName, - const std::string& comparisonColumnName) { + const std::string& comparisonColumnName, + bool testWithTableScan = true) { const std::string funcName = aggName == kMaxBy ? "max" : "min"; const std::string verifyDuckDbSql = fmt::format( "SELECT {} FROM tmp WHERE {} = ( SELECT {} ({}) FROM tmp) LIMIT 1", @@ -373,7 +415,8 @@ class MinMaxByGlobalByAggregationTest "{}({}, {})", aggName, valueColumnName, comparisonColumnName); SCOPED_TRACE( fmt::format("{}\nverifyDuckDbSql: {}", aggregate, verifyDuckDbSql)); - testAggregations(vectors, {}, {aggregate}, {}, verifyDuckDbSql); + testAggregations( + vectors, {}, {aggregate}, {}, verifyDuckDbSql, {}, testWithTableScan); } template @@ -573,6 +616,68 @@ class MinMaxByGlobalByAggregationTest } }; +TEST_F(MinMaxByGlobalByAggregationTest, decimalSignatureTest) { + auto decimalType = DECIMAL(38, 4); + auto rowType = ROW({"c0", "c1"}, {decimalType, decimalType}); + auto rowVector = makeRowVector( + {makeNullableFlatVector( + {dataAt(0), + dataAt(1), + dataAt(2), + dataAt(3), + dataAt(4), + std::nullopt}, + decimalType), + makeNullableFlatVector( + {dataAt(5), + dataAt(4), + dataAt(3), + dataAt(2), + dataAt(1), + dataAt(0)}, + decimalType)}); + char delim = ','; + + // Decimal value type and compare type. + testAggregations( + {rowVector}, + {}, + {"min_by(c0, c1)", "max_by(c0, c1)"}, + {}, + {makeRowVector( + {makeNullableFlatVector({std::nullopt}, decimalType), + makeNullableFlatVector( + {dataAt(0)}, decimalType)})}, + {}, + false); + // Incremental Aggregation checks consistency of intermediate results by + // calling extractAccumulators twice. But, for 3 parameters min_by/max_by + // aggregate function, extractAccumulator cleans up the priority queue holding + // the result and the second call will return empty result. Thus sub-test is + // not applicable. + this->disableTestIncremental(); + testAggregations( + {rowVector}, + {}, + {"min_by(c0, c1, 4)", "max_by(c0, c1, 4)"}, + {}, + {makeRowVector( + {makeNullableArrayVector( + {{std::nullopt, + dataAt(4), + dataAt(3), + dataAt(2)}}, + ARRAY(decimalType)), + makeNullableArrayVector( + {{dataAt(0), + dataAt(1), + dataAt(2), + dataAt(3)}}, + ARRAY(decimalType))})}, + {}, + false); +} + TEST_P(MinMaxByGlobalByAggregationTest, minByFinalGlobalBy) { EXECUTE_TEST(minByGlobalByTest); } @@ -593,7 +698,24 @@ TEST_P(MinMaxByGlobalByAggregationTest, randomMinByGlobalBy) { GetParam().valueType == TypeKind::TIMESTAMP) { return; } - + if (GetParam().comparisonType == TypeKind::HUGEINT || + GetParam().valueType == TypeKind::HUGEINT) { + // Plans with tablescan for HUGEINT fail + auto valColName = (GetParam().valueType == TypeKind::HUGEINT) + ? "c10" + : getColumnName(GetParam().valueType); + auto compColName = (GetParam().comparisonType == TypeKind::HUGEINT) + ? "c10" + : getColumnName(GetParam().comparisonType); + + testGlobalAggregation( + rowVectorsWithHugeInt_, + kMinBy, + valColName, + compColName, + false /*Skip tablescan tests*/); + return; + } testGlobalAggregation( rowVectors_, kMinBy, @@ -607,6 +729,25 @@ TEST_P(MinMaxByGlobalByAggregationTest, randomMaxByGlobalBy) { return; } + if (GetParam().comparisonType == TypeKind::HUGEINT || + GetParam().valueType == TypeKind::HUGEINT) { + // Plans with tablescan for HUGEINT fail + auto valColName = (GetParam().valueType == TypeKind::HUGEINT) + ? "c10" + : getColumnName(GetParam().valueType); + auto compColName = (GetParam().comparisonType == TypeKind::HUGEINT) + ? "c10" + : getColumnName(GetParam().comparisonType); + + testGlobalAggregation( + rowVectorsWithHugeInt_, + kMaxBy, + valColName, + compColName, + false /*Skip tablescan tests*/); + return; + } + testGlobalAggregation( rowVectors_, kMaxBy, @@ -623,6 +764,11 @@ TEST_P( return; } + bool runTableScanTests = true; + if (GetParam().comparisonType == TypeKind::HUGEINT || + GetParam().valueType == TypeKind::HUGEINT) { + runTableScanTests = false; + } // Enable disk spilling test with distinct comparison values. AggregationTestBase::allowInputShuffle(); @@ -662,9 +808,9 @@ TEST_P( } createDuckDbTable(rowVectors); - testGlobalAggregation(rowVectors, kMinBy, "c0", "c1"); + testGlobalAggregation(rowVectors, kMinBy, "c0", "c1", runTableScanTests); - testGlobalAggregation(rowVectors, kMaxBy, "c0", "c1"); + testGlobalAggregation(rowVectors, kMaxBy, "c0", "c1", runTableScanTests); } VELOX_INSTANTIATE_TEST_SUITE_P( @@ -1032,6 +1178,11 @@ TEST_P(MinMaxByGroupByAggregationTest, randomMinByGroupBy) { return; } + if (GetParam().comparisonType == TypeKind::HUGEINT || + GetParam().valueType == TypeKind::HUGEINT) { + GTEST_SKIP() + << "HUGEINT value type is skipped. DUCKDB min_by/max_by for HUGEINT are cast to double."; + } testGroupByAggregation( rowVectors_, kMinBy, @@ -1046,6 +1197,12 @@ TEST_P(MinMaxByGroupByAggregationTest, randomMaxByGroupBy) { return; } + if (GetParam().comparisonType == TypeKind::HUGEINT || + GetParam().valueType == TypeKind::HUGEINT) { + GTEST_SKIP() + << "HUGEINT value type is skipped. DUCKDB min_by/max_by for HUGEINT are cast to double."; + } + testGroupByAggregation( rowVectors_, kMaxBy, @@ -1063,6 +1220,11 @@ TEST_P( return; } + if (GetParam().comparisonType == TypeKind::HUGEINT || + GetParam().valueType == TypeKind::HUGEINT) { + GTEST_SKIP() + << "HUGEINT value type is skipped. DUCKDB min_by/max_by for HUGEINT are cast to double."; + } // Enable disk spilling test with distinct comparison values. AggregationTestBase::allowInputShuffle(); diff --git a/velox/type/Type.cpp b/velox/type/Type.cpp index 921536dbf400..e8c0503a017a 100644 --- a/velox/type/Type.cpp +++ b/velox/type/Type.cpp @@ -900,6 +900,8 @@ TypePtr fromKindToScalerType(TypeKind kind) { return SMALLINT(); case TypeKind::BIGINT: return BIGINT(); + case TypeKind::HUGEINT: + return HUGEINT(); case TypeKind::INTEGER: return INTEGER(); case TypeKind::REAL: diff --git a/velox/type/Variant.cpp b/velox/type/Variant.cpp index 1f15cff4c6fb..fd5bd18c721d 100644 --- a/velox/type/Variant.cpp +++ b/velox/type/Variant.cpp @@ -272,8 +272,10 @@ std::string variant::toJson(const TypePtr& type) const { return target; } case TypeKind::HUGEINT: { - VELOX_CHECK(type->isLongDecimal()) { + if (type && type->isLongDecimal()) { return DecimalUtil::toString(value(), type); + } else { + return std::to_string(value()); } } case TypeKind::TINYINT: @@ -399,8 +401,10 @@ std::string variant::toJsonUnsafe(const TypePtr& type) const { return target; } case TypeKind::HUGEINT: { - VELOX_CHECK(type && type->isLongDecimal()) { + if (type && type->isLongDecimal()) { return DecimalUtil::toString(value(), type); + } else { + return std::to_string(value()); } } case TypeKind::TINYINT: diff --git a/velox/type/Variant.h b/velox/type/Variant.h index 0f1451fffab2..69aa483b8833 100644 --- a/velox/type/Variant.h +++ b/velox/type/Variant.h @@ -603,8 +603,9 @@ struct VariantConverter { return convert(value); case TypeKind::VARBINARY: return convert(value); - case TypeKind::TIMESTAMP: case TypeKind::HUGEINT: + return convert(value); + case TypeKind::TIMESTAMP: // Default date/timestamp conversion is prone to errors and implicit // assumptions. Block converting timestamp to integer, double and // std::string types. The callers should implement their own conversion