Skip to content

Commit

Permalink
Add decimal support for min_by and max_by functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Karteekmurthys authored and karteekmurthys committed Apr 22, 2024
1 parent 28a1f18 commit aa95010
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 8 deletions.
13 changes: 9 additions & 4 deletions velox/exec/tests/utils/QueryAssertions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,16 @@ ::duckdb::Value duckValueAt<TypeKind::HUGEINT>(
const VectorPtr& vector,
vector_size_t index) {
using T = typename KindToFlatVector<TypeKind::HUGEINT>::WrapperType;
auto type = vector->type()->asLongDecimal();
auto val = vector->as<SimpleVector<T>>()->valueAt(index);
auto duckVal = ::duckdb::hugeint_t();
duckVal.lower = (val << 64) >> 64;
duckVal.upper = (val >> 64);
return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale());
if (vector->type()->isLongDecimal()) {
auto type = vector->type()->asLongDecimal();
return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale());
}
// Flat vector is HUGEINT type and not the logical decimal type.
return ::duckdb::Value::HUGEINT(duckVal);
}

template <>
Expand Down Expand Up @@ -253,8 +257,9 @@ velox::variant variantAt<TypeKind::HUGEINT>(
::duckdb::DataChunk* dataChunk,
int32_t row,
int32_t column) {
auto hugeInt = ::duckdb::HugeIntValue::Get(dataChunk->GetValue(column, row));
return velox::variant(HugeInt::build(hugeInt.upper, hugeInt.lower));
auto unscaledValue =
dataChunk->GetValue(column, row).GetValue<::duckdb::hugeint_t>();
return variant(HugeInt::build(unscaledValue.upper, unscaledValue.lower));
}

template <>
Expand Down
6 changes: 6 additions & 0 deletions velox/functions/lib/aggregates/MinMaxByAggregatesBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,9 @@ std::unique_ptr<exec::Aggregate> create(
case TypeKind::DOUBLE:
return std::make_unique<Aggregate<W, double, isMaxFunc, Comparator>>(
resultType);
case TypeKind::HUGEINT:
return std::make_unique<Aggregate<W, int128_t, isMaxFunc, Comparator>>(
resultType);
case TypeKind::VARBINARY:
[[fallthrough]];
case TypeKind::VARCHAR:
Expand Down Expand Up @@ -635,6 +638,9 @@ std::unique_ptr<exec::Aggregate> create(
case TypeKind::BIGINT:
return create<Aggregate, isMaxFunc, Comparator, int64_t>(
resultType, compareType, errorMessage, throwOnNestedNulls);
case TypeKind::HUGEINT:
return create<Aggregate, isMaxFunc, Comparator, int128_t>(
resultType, compareType, errorMessage);
case TypeKind::REAL:
return create<Aggregate, isMaxFunc, Comparator, float>(
resultType, compareType, errorMessage, throwOnNestedNulls);
Expand Down
53 changes: 50 additions & 3 deletions velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,15 @@ struct MinMaxByNAccumulator {
int64_t n{0};

using Pair = std::pair<C, std::optional<V>>;
using Heap = std::vector<Pair, StlAllocator<Pair>>;
using Allocator = std::conditional_t<
std::is_same_v<int128_t, V> || std::is_same_v<int128_t, C>,
AlignedStlAllocator<Pair, sizeof(int128_t)>,
StlAllocator<Pair>>;
using Heap = std::vector<Pair, Allocator>;
Heap heapValues;

explicit MinMaxByNAccumulator(HashStringAllocator* allocator)
: heapValues{StlAllocator<Pair>(allocator)} {}
: heapValues{Allocator(allocator)} {}

int64_t getN() const {
return n;
Expand Down Expand Up @@ -990,6 +994,16 @@ class MinByNAggregate : public MinMaxByNAggregate<V, C, Less<V, C>> {
: MinMaxByNAggregate<V, C, Less<V, C>>(resultType) {}
};

template <typename V>
class LongDecimalMinByNAggregate : public MinByNAggregate<V, int128_t> {
public:
explicit LongDecimalMinByNAggregate(TypePtr resultType)
: MinByNAggregate<V, int128_t>(resultType) {}
int32_t accumulatorAlignmentSize() const override {
return static_cast<int32_t>(sizeof(int128_t));
}
};

template <typename C>
class MinByNAggregate<ComplexType, C>
: public MinMaxByNAggregate<
Expand All @@ -1011,6 +1025,16 @@ class MaxByNAggregate : public MinMaxByNAggregate<V, C, Greater<V, C>> {
: MinMaxByNAggregate<V, C, Greater<V, C>>(resultType) {}
};

template <typename V>
class LongDecimalMaxByNAggregate : public MaxByNAggregate<V, int128_t> {
public:
explicit LongDecimalMaxByNAggregate(TypePtr resultType)
: MaxByNAggregate<V, int128_t>(resultType) {}
int32_t accumulatorAlignmentSize() const override {
return static_cast<int32_t>(sizeof(int128_t));
}
};

template <typename C>
class MaxByNAggregate<ComplexType, C>
: public MinMaxByNAggregate<
Expand Down Expand Up @@ -1041,6 +1065,13 @@ std::unique_ptr<exec::Aggregate> createNArg(
return std::make_unique<NAggregate<W, int32_t>>(resultType);
case TypeKind::BIGINT:
return std::make_unique<NAggregate<W, int64_t>>(resultType);
case TypeKind::HUGEINT:
if constexpr (std::is_same_v<
NAggregate<W, int128_t>,
MinByNAggregate<W, int128_t>>) {
return std::make_unique<LongDecimalMinByNAggregate<W>>(resultType);
}
return std::make_unique<LongDecimalMaxByNAggregate<W>>(resultType);
case TypeKind::REAL:
return std::make_unique<NAggregate<W, float>>(resultType);
case TypeKind::DOUBLE:
Expand Down Expand Up @@ -1077,6 +1108,9 @@ std::unique_ptr<exec::Aggregate> createNArg(
case TypeKind::BIGINT:
return createNArg<NAggregate, int64_t>(
resultType, compareType, errorMessage);
case TypeKind::HUGEINT:
return createNArg<NAggregate, int128_t>(
resultType, compareType, errorMessage);
case TypeKind::REAL:
return createNArg<NAggregate, float>(
resultType, compareType, errorMessage);
Expand Down Expand Up @@ -1139,12 +1173,14 @@ exec::AggregateRegistrationResult registerMinMaxBy(
.argumentType("V")
.argumentType("C")
.build());
// Add signatures for 3-arg version of min_by/max_by.
const std::vector<std::string> supportedCompareTypes = {
"boolean",
"tinyint",
"smallint",
"integer",
"bigint",
"hugeint",
"real",
"double",
"varchar",
Expand All @@ -1164,7 +1200,18 @@ exec::AggregateRegistrationResult registerMinMaxBy(
.argumentType("bigint")
.build());
}

signatures.push_back(
exec::AggregateFunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.typeVariable("V")
.returnType("array(V)")
.intermediateType(
"row(bigint,array(DECIMAL(a_precision, a_scale)),array(V))")
.argumentType("V")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("bigint")
.build());
return exec::registerAggregateFunction(
name,
std::move(signatures),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ const std::vector<TypeKind> kSupportedTypes = {
TypeKind::REAL,
TypeKind::DOUBLE,
TypeKind::VARCHAR,
TypeKind::TIMESTAMP};
TypeKind::TIMESTAMP,
TypeKind::HUGEINT};

std::vector<TestParam> getTestParams() {
std::vector<TestParam> params;
Expand Down Expand Up @@ -75,6 +76,9 @@ std::vector<TestParam> getTestParams() {
case TypeKind::BIGINT: \
testFunc<valueType, int64_t>(); \
break; \
case TypeKind::HUGEINT: \
testFunc<valueType, int128_t>(); \
break; \
case TypeKind::REAL: \
testFunc<valueType, float>(); \
break; \
Expand Down Expand Up @@ -111,6 +115,9 @@ std::vector<TestParam> getTestParams() {
case TypeKind::BIGINT: \
EXECUTE_TEST_BY_VALUE_TYPE(testFunc, int64_t); \
break; \
case TypeKind::HUGEINT: \
EXECUTE_TEST_BY_VALUE_TYPE(testFunc, int128_t); \
break; \
case TypeKind::REAL: \
EXECUTE_TEST_BY_VALUE_TYPE(testFunc, float); \
break; \
Expand Down Expand Up @@ -203,6 +210,21 @@ class MinMaxByAggregationTestBase : public AggregationTestBase {
std::vector<RowVectorPtr> rowVectors_;
};

template <>
FlatVectorPtr<int128_t> MinMaxByAggregationTestBase::buildDataVector(
vector_size_t size,
folly::Range<const int*> values) {
if (values.empty()) {
return makeFlatVector<int128_t>(
size, [](auto row) { return HugeInt::build(row - 3, row - 3); });
} else {
VELOX_CHECK_EQ(values.size(), size);
return makeFlatVector<int128_t>(size, [&](auto row) {
return HugeInt::build(values[row], values[row]);
});
}
}

// Build a flat vector with StringView. The value in the returned flat vector
// is in ascending order.
template <>
Expand Down Expand Up @@ -274,6 +296,8 @@ VectorPtr MinMaxByAggregationTestBase::buildDataVector(
return buildDataVector<float>(size, values);
case TypeKind::DOUBLE:
return buildDataVector<double>(size, values);
case TypeKind::HUGEINT:
return buildDataVector<int128_t>(size, values);
case TypeKind::VARCHAR:
return buildDataVector<StringView>(size, values);
case TypeKind::TIMESTAMP:
Expand Down Expand Up @@ -327,6 +351,9 @@ void MinMaxByAggregationTestBase::SetUp() {
case TypeKind::BIGINT:
dataVectorsByType_.emplace(type, buildDataVector<int64_t>(numValues_));
break;
case TypeKind::HUGEINT:
dataVectorsByType_.emplace(type, buildDataVector<int128_t>(numValues_));
break;
case TypeKind::REAL:
dataVectorsByType_.emplace(type, buildDataVector<float>(numValues_));
break;
Expand Down

0 comments on commit aa95010

Please sign in to comment.