Skip to content

Commit

Permalink
Add decimal support for min_by and max_by functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Karteekmurthys committed Feb 16, 2024
1 parent 7b68a82 commit 0f1bcd4
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 29 deletions.
13 changes: 9 additions & 4 deletions velox/exec/tests/utils/QueryAssertions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,16 @@ ::duckdb::Value duckValueAt<TypeKind::HUGEINT>(
const VectorPtr& vector,
vector_size_t index) {
using T = typename KindToFlatVector<TypeKind::HUGEINT>::WrapperType;
auto type = vector->type()->asLongDecimal();
auto val = vector->as<SimpleVector<T>>()->valueAt(index);
auto duckVal = ::duckdb::hugeint_t();
duckVal.lower = (val << 64) >> 64;
duckVal.upper = (val >> 64);
return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale());
if (vector->type()->isLongDecimal()) {
auto type = vector->type()->asLongDecimal();
return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale());
}
// Flat vector is HUGEINT type and not the logical decimal type.
return ::duckdb::Value::HUGEINT(duckVal);
}

template <>
Expand Down Expand Up @@ -252,8 +256,9 @@ velox::variant variantAt<TypeKind::HUGEINT>(
::duckdb::DataChunk* dataChunk,
int32_t row,
int32_t column) {
auto hugeInt = ::duckdb::HugeIntValue::Get(dataChunk->GetValue(column, row));
return velox::variant(HugeInt::build(hugeInt.upper, hugeInt.lower));
auto unscaledValue =
dataChunk->GetValue(column, row).GetValue<::duckdb::hugeint_t>();
return variant(HugeInt::build(unscaledValue.upper, unscaledValue.lower));
}

template <>
Expand Down
6 changes: 6 additions & 0 deletions velox/functions/lib/aggregates/MinMaxByAggregatesBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,9 @@ std::unique_ptr<exec::Aggregate> create(
case TypeKind::DOUBLE:
return std::make_unique<Aggregate<W, double, isMaxFunc, Comparator>>(
resultType);
case TypeKind::HUGEINT:
return std::make_unique<Aggregate<W, int128_t, isMaxFunc, Comparator>>(
resultType);
case TypeKind::VARBINARY:
[[fallthrough]];
case TypeKind::VARCHAR:
Expand Down Expand Up @@ -634,6 +637,9 @@ std::unique_ptr<exec::Aggregate> create(
case TypeKind::BIGINT:
return create<Aggregate, isMaxFunc, Comparator, int64_t>(
resultType, compareType, errorMessage, throwOnNestedNulls);
case TypeKind::HUGEINT:
return create<Aggregate, isMaxFunc, Comparator, int128_t>(
resultType, compareType, errorMessage);
case TypeKind::REAL:
return create<Aggregate, isMaxFunc, Comparator, float>(
resultType, compareType, errorMessage, throwOnNestedNulls);
Expand Down
53 changes: 50 additions & 3 deletions velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,16 @@ struct MinMaxByNAccumulator {
int64_t n{0};

using Pair = std::pair<C, std::optional<V>>;
using Allocator = std::conditional_t<
std::is_same_v<int128_t, V> || std::is_same_v<int128_t, C>,
AlignedStlAllocator<Pair, sizeof(int128_t)>,
StlAllocator<Pair>>;
using Queue =
std::priority_queue<Pair, std::vector<Pair, StlAllocator<Pair>>, Compare>;
std::priority_queue<Pair, std::vector<Pair, Allocator>, Compare>;
Queue topPairs;

explicit MinMaxByNAccumulator(HashStringAllocator* allocator)
: topPairs{Compare{}, StlAllocator<Pair>(allocator)} {}
: topPairs{Compare{}, Allocator(allocator)} {}

int64_t getN() const {
return n;
Expand Down Expand Up @@ -967,6 +971,16 @@ class MinByNAggregate : public MinMaxByNAggregate<V, C, Less<V, C>> {
: MinMaxByNAggregate<V, C, Less<V, C>>(resultType) {}
};

/// Variant of MinByNAggregate whose second template argument is fixed to
/// int128_t. The only behavior it adds on top of the base class is reporting
/// sizeof(int128_t) as the accumulator alignment.
/// NOTE(review): presumably this is needed so accumulators that hold 128-bit
/// values are laid out with suitable alignment -- confirm against the base
/// Aggregate contract.
///
/// @tparam V First template argument forwarded unchanged to MinByNAggregate.
template <typename V>
class LongDecimalMinByNAggregate : public MinByNAggregate<V, int128_t> {
  using Base = MinByNAggregate<V, int128_t>;

 public:
  /// @param resultType Forwarded as-is to the MinByNAggregate constructor.
  explicit LongDecimalMinByNAggregate(TypePtr resultType) : Base(resultType) {}

  /// @return sizeof(int128_t), narrowed to int32_t.
  int32_t accumulatorAlignmentSize() const override {
    constexpr auto kInt128Bytes = static_cast<int32_t>(sizeof(int128_t));
    return kInt128Bytes;
  }
};

template <typename C>
class MinByNAggregate<ComplexType, C>
: public MinMaxByNAggregate<
Expand All @@ -988,6 +1002,16 @@ class MaxByNAggregate : public MinMaxByNAggregate<V, C, Greater<V, C>> {
: MinMaxByNAggregate<V, C, Greater<V, C>>(resultType) {}
};

/// Variant of MaxByNAggregate whose second template argument is fixed to
/// int128_t. The only behavior it adds on top of the base class is reporting
/// sizeof(int128_t) as the accumulator alignment.
/// NOTE(review): presumably this is needed so accumulators that hold 128-bit
/// values are laid out with suitable alignment -- confirm against the base
/// Aggregate contract.
///
/// @tparam V First template argument forwarded unchanged to MaxByNAggregate.
template <typename V>
class LongDecimalMaxByNAggregate : public MaxByNAggregate<V, int128_t> {
  using Base = MaxByNAggregate<V, int128_t>;

 public:
  /// @param resultType Forwarded as-is to the MaxByNAggregate constructor.
  explicit LongDecimalMaxByNAggregate(TypePtr resultType) : Base(resultType) {}

  /// @return sizeof(int128_t), narrowed to int32_t.
  int32_t accumulatorAlignmentSize() const override {
    constexpr auto kInt128Bytes = static_cast<int32_t>(sizeof(int128_t));
    return kInt128Bytes;
  }
};

template <typename C>
class MaxByNAggregate<ComplexType, C>
: public MinMaxByNAggregate<
Expand Down Expand Up @@ -1018,6 +1042,13 @@ std::unique_ptr<exec::Aggregate> createNArg(
return std::make_unique<NAggregate<W, int32_t>>(resultType);
case TypeKind::BIGINT:
return std::make_unique<NAggregate<W, int64_t>>(resultType);
case TypeKind::HUGEINT:
if constexpr (std::is_same_v<
NAggregate<W, int128_t>,
MinByNAggregate<W, int128_t>>) {
return std::make_unique<LongDecimalMinByNAggregate<W>>(resultType);
}
return std::make_unique<LongDecimalMaxByNAggregate<W>>(resultType);
case TypeKind::REAL:
return std::make_unique<NAggregate<W, float>>(resultType);
case TypeKind::DOUBLE:
Expand Down Expand Up @@ -1054,6 +1085,9 @@ std::unique_ptr<exec::Aggregate> createNArg(
case TypeKind::BIGINT:
return createNArg<NAggregate, int64_t>(
resultType, compareType, errorMessage);
case TypeKind::HUGEINT:
return createNArg<NAggregate, int128_t>(
resultType, compareType, errorMessage);
case TypeKind::REAL:
return createNArg<NAggregate, float>(
resultType, compareType, errorMessage);
Expand Down Expand Up @@ -1113,12 +1147,14 @@ exec::AggregateRegistrationResult registerMinMaxBy(const std::string& name) {
.argumentType("V")
.argumentType("C")
.build());
// Add signatures for 3-arg version of min_by/max_by.
const std::vector<std::string> supportedCompareTypes = {
"boolean",
"tinyint",
"smallint",
"integer",
"bigint",
"hugeint",
"real",
"double",
"varchar",
Expand All @@ -1138,7 +1174,18 @@ exec::AggregateRegistrationResult registerMinMaxBy(const std::string& name) {
.argumentType("bigint")
.build());
}

signatures.push_back(
exec::AggregateFunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.typeVariable("V")
.returnType("array(V)")
.intermediateType(
"row(bigint,array(DECIMAL(a_precision, a_scale)),array(V))")
.argumentType("V")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("bigint")
.build());
return exec::registerAggregateFunction(
name,
std::move(signatures),
Expand Down
Loading

0 comments on commit 0f1bcd4

Please sign in to comment.