Skip to content

Commit

Permalink
[fix](decimal256) support decimal256 for many functions (#42136) (#42353
Browse files Browse the repository at this point in the history
)

BP #42136
  • Loading branch information
jacktengg authored Oct 24, 2024
1 parent 32b13b2 commit 64d0c55
Show file tree
Hide file tree
Showing 62 changed files with 949 additions and 214 deletions.
2 changes: 1 addition & 1 deletion be/src/runtime/runtime_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class RuntimeState {
_query_options.check_overflow_for_decimal;
}

bool enable_decima256() const {
bool enable_decimal256() const {
return _query_options.__isset.enable_decimal256 && _query_options.enable_decimal256;
}

Expand Down
4 changes: 4 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class Arena;
class IColumn;
class IDataType;

struct AggregateFunctionAttr {
bool enable_decimal256 {false};
};

template <bool nullable, typename ColVecType>
class AggregateFunctionBitmapCount;
template <typename Op>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_approx_count_distinct(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
WhichDataType which(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE, COLUMN_TYPE) \
Expand Down
15 changes: 12 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_avg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,17 @@ template <typename T>
using AggregateFuncAvgDecimal256 = typename AvgDecimal256<T>::Function;

void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("avg", creator_with_type::creator<AggregateFuncAvg>);
factory.register_function_both("avg_decimal256",
creator_with_type::creator<AggregateFuncAvgDecimal256>);
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (attr.enable_decimal256) {
return creator_with_type::creator<AggregateFuncAvgDecimal256>(name, types,
result_is_nullable, attr);
} else {
return creator_with_type::creator<AggregateFuncAvg>(name, types, result_is_nullable,
attr);
}
};
factory.register_function_both("avg", creator);
}
} // namespace doris::vectorized
9 changes: 5 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_type) {
return nullptr;
}

AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_bitmap_union_count(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return std::make_shared<AggregateFunctionBitmapCount<true, ColumnBitmap>>(argument_types);
Expand All @@ -53,7 +53,8 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::str

AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return AggregateFunctionPtr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_types)

AggregateFunctionPtr create_aggregate_function_bitmap_agg(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return AggregateFunctionPtr(create_with_int_data_type<true>(argument_types));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& n

AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() == 1) {
if (name == "array_agg") {
return create_aggregate_function_collect_impl<std::false_type, std::true_type>(
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_corr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ struct CorrMoment {

AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_binary(name, argument_types);
return create_with_two_basic_numeric_types<CorrMoment>(argument_types[0], argument_types[1],
argument_types, result_is_nullable);
Expand Down
9 changes: 5 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_count(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_at_most<1>(name, argument_types);

return std::make_shared<AggregateFunctionCount>(argument_types);
}

AggregateFunctionPtr create_aggregate_function_count_not_null_unary(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_count_not_null_unary(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_at_most<1>(name, argument_types);

return std::make_shared<AggregateFunctionCountNotNullUnary>(argument_types);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_count_by_enum(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() < 1) {
LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}",
argument_types.size(), name);
Expand Down
6 changes: 4 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_covar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,17 @@ AggregateFunctionPtr create_function_single_value(const String& name,
template <bool is_nullable>
AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionSamp, CovarSampName, SampData,
is_nullable>(name, argument_types, result_is_nullable,
NULLABLE);
}

AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionPop, CovarName, PopData>(
name, argument_types, result_is_nullable, NOTNULLABLE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,17 @@ const std::string DISTINCT_FUNCTION_PREFIX = "multi_distinct_";

void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
// 1. we should get not nullable types;
DataTypes nested_types(types.size());
std::transform(types.begin(), types.end(), nested_types.begin(),
[](const auto& e) { return remove_nullable(e); });
auto function_combinator = std::make_shared<AggregateFunctionCombinatorDistinct>();
auto transform_arguments = function_combinator->transform_arguments(nested_types);
auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size());
auto nested_function = factory.get(nested_function_name, transform_arguments);
auto nested_function = factory.get(nested_function_name, transform_arguments, false,
BeExecVersionManager::get_newest_version(), attr);
return function_combinator->transform_aggregate_function(nested_function, types,
result_is_nullable);
};
Expand Down
8 changes: 5 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_foreach.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
namespace doris::vectorized {

void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const bool result_is_nullable) -> AggregateFunctionPtr {
AggregateFunctionCreator creator =
[&](const std::string& name, const DataTypes& types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) -> AggregateFunctionPtr {
const std::string& suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX;
DataTypes transform_arguments;
for (const auto& t : types) {
Expand All @@ -45,7 +46,8 @@ void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFacto
}
auto nested_function_name = name.substr(0, name.size() - suffix.size());
auto nested_function =
factory.get(nested_function_name, transform_arguments, result_is_nullable);
factory.get(nested_function_name, transform_arguments, result_is_nullable,
BeExecVersionManager::get_newest_version(), attr);
if (!nested_function) {
throw Exception(
ErrorCode::INTERNAL_ERROR,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl
}

AggregateFunctionPtr create_aggregate_function_group_array_intersect(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_unary(name, argument_types);
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ const std::string AggregateFunctionGroupConcatImplStr::separator = ",";

AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() == 1) {
return creator_without_type::create<
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_typ

AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
WhichDataType type(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE) \
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ AggregateFunctionPtr create_agg_function_map_agg(const DataTypes& argument_types

AggregateFunctionPtr create_aggregate_function_map_agg(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
WhichDataType type(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace doris::vectorized {
template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_single_value(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_unary(name, argument_types);

AggregateFunctionPtr res(creator_with_numeric_type::create<AggregateFunctionsSingleValue, Data,
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_min_max.h
Original file line number Diff line number Diff line change
Expand Up @@ -671,5 +671,6 @@ class AggregateFunctionsSingleValue final
template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_single_value(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable);
const bool result_is_nullable,
const AggregateFunctionAttr& attr = {});
} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ template <template <typename> class AggregateFunctionTemplate,
template <typename, typename> class Data>
AggregateFunctionPtr create_aggregate_function_min_max_by(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() != 2) {
return nullptr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ template <template <typename> class Impl>
AggregateFunctionPtr create_aggregate_function_orthogonal(const std::string& name,
const DataTypes& argument_types,

const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.empty()) {
LOG(WARNING) << "Incorrect number of arguments for aggregate function " << name;
return nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
namespace doris::vectorized {

template <bool is_nullable>
AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_percentile_approx(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
Expand Down
16 changes: 9 additions & 7 deletions be/src/vec/aggregate_functions/aggregate_function_product.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,20 @@ struct AggregateFunctionProductData<Decimal128V2> {

void reset(Decimal128V2 value) { product = std::move(value); }
};
template <typename T>
concept DecimalTypeConcept = std::is_same_v<T, Decimal128V3> || std::is_same_v<T, Decimal256>;

template <>
struct AggregateFunctionProductData<Decimal128V3> {
Decimal128V3 product {};
template <DecimalTypeConcept T>
struct AggregateFunctionProductData<T> {
T product {};

template <typename NestedType>
void add(Decimal<NestedType> value, Decimal128V3 multiplier) {
void add(Decimal<NestedType> value, T multiplier) {
product *= value;
product /= multiplier;
}

void merge(const AggregateFunctionProductData& other, Decimal128V3 multiplier) {
void merge(const AggregateFunctionProductData& other, T multiplier) {
product *= other.product;
product /= multiplier;
}
Expand All @@ -96,9 +98,9 @@ struct AggregateFunctionProductData<Decimal128V3> {

void read(BufferReadable& buffer) { read_binary(product, buffer); }

Decimal128V2 get() const { return product; }
T get() const { return product; }

void reset(Decimal128V2 value) { product = value; }
void reset(T value) { product = std::move(value); }
};

template <typename T, typename TResult, typename Data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@

namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_quantile_state_union(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_quantile_state_union(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return std::make_shared<
Expand Down
Loading

0 comments on commit 64d0c55

Please sign in to comment.