Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](function)adjust aggregate function's nullable property #37330

Merged
merged 10 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions be/src/agent/be_exec_version_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@ class BeExecVersionManager {
* a. change the impl of percentile (need fix)
* b. clear old version of version 3->4
* c. change FunctionIsIPAddressInRange from AlwaysNotNullable to DependOnArguments
* d. change some agg function nullable property: PR #37215
*/
constexpr inline int BeExecVersionManager::max_be_exec_version = 5;
constexpr inline int BeExecVersionManager::min_be_exec_version = 0;

/// functional
constexpr inline int BITMAP_SERDE = 3;
constexpr inline int USE_NEW_SERDE = 4; // release on DORIS version 2.1
constexpr inline int OLD_WAL_SERDE = 3; // use to solve compatibility issues, see pr #32299
constexpr inline int USE_NEW_SERDE = 4; // release on DORIS version 2.1
constexpr inline int OLD_WAL_SERDE = 3; // use to solve compatibility issues, see pr #32299
constexpr inline int AGG_FUNCTION_NULLABLE = 5; // change some agg nullable property: PR #37215

} // namespace doris
2 changes: 1 addition & 1 deletion be/src/pipeline/exec/aggregation_sink_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ Status AggSinkOperatorX::prepare(RuntimeState* state) {
RETURN_IF_ERROR(_aggregate_evaluators[i]->prepare(
state, DataSinkOperatorX<AggSinkLocalState>::_child_x->row_desc(),
intermediate_slot_desc, output_slot_desc));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

_offsets_of_aggregate_states.resize(_aggregate_evaluators.size());
Expand Down Expand Up @@ -832,7 +833,6 @@ Status AggSinkOperatorX::open(RuntimeState* state) {

for (auto& _aggregate_evaluator : _aggregate_evaluators) {
RETURN_IF_ERROR(_aggregate_evaluator->open(state));
_aggregate_evaluator->set_version(state->be_exec_version());
}

return Status::OK();
Expand Down
1 change: 1 addition & 0 deletions be/src/pipeline/exec/analytic_source_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ Status AnalyticSourceOperatorX::prepare(RuntimeState* state) {
SlotDescriptor* output_slot_desc = _output_tuple_desc->slots()[i];
RETURN_IF_ERROR(_agg_functions[i]->prepare(state, _child_x->row_desc(),
intermediate_slot_desc, output_slot_desc));
_agg_functions[i]->set_version(state->be_exec_version());
_change_to_nullable_flags.push_back(output_slot_desc->is_nullable() &&
!_agg_functions[i]->data_type()->is_nullable());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ Status DistinctStreamingAggOperatorX::prepare(RuntimeState* state) {
SlotDescriptor* output_slot_desc = _output_tuple_desc->slots()[j];
RETURN_IF_ERROR(_aggregate_evaluators[i]->prepare(
state, _child_x->row_desc(), intermediate_slot_desc, output_slot_desc));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

for (size_t i = 0; i < _aggregate_evaluators.size(); ++i) {
Expand Down Expand Up @@ -421,7 +422,6 @@ Status DistinctStreamingAggOperatorX::open(RuntimeState* state) {

for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
RETURN_IF_ERROR(_aggregate_evaluators[i]->open(state));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

return Status::OK();
Expand Down
2 changes: 1 addition & 1 deletion be/src/pipeline/exec/streaming_aggregation_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,7 @@ Status StreamingAggOperatorX::prepare(RuntimeState* state) {
SlotDescriptor* output_slot_desc = _output_tuple_desc->slots()[j];
RETURN_IF_ERROR(_aggregate_evaluators[i]->prepare(
state, _child_x->row_desc(), intermediate_slot_desc, output_slot_desc));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

_offsets_of_aggregate_states.resize(_aggregate_evaluators.size());
Expand Down Expand Up @@ -1239,7 +1240,6 @@ Status StreamingAggOperatorX::open(RuntimeState* state) {

for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
RETURN_IF_ERROR(_aggregate_evaluators[i]->open(state));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

return Status::OK();
Expand Down
25 changes: 19 additions & 6 deletions be/src/vec/aggregate_functions/aggregate_function_covar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,19 @@ AggregateFunctionPtr create_function_single_value(const String& name,
}

template <bool is_nullable>
AggregateFunctionPtr create_aggregate_function_covariance_samp_old(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
return create_function_single_value<AggregateFunctionSamp_OLDER, CovarSampName, SampData_OLDER,
is_nullable>(name, argument_types, result_is_nullable,
NULLABLE);
}

AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
return create_function_single_value<AggregateFunctionSamp, CovarSampName, SampData,
is_nullable>(name, argument_types, result_is_nullable,
NULLABLE);
return create_function_single_value<AggregateFunctionSamp, CovarSampName, SampData>(
name, argument_types, result_is_nullable, NOTNULLABLE);
}

AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name,
Expand All @@ -81,9 +88,15 @@ void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& facto
factory.register_alias("covar", "covar_pop");
}

void register_aggregate_function_covar_samp_old(AggregateFunctionSimpleFactory& factory) {
factory.register_alternative_function(
"covar_samp", create_aggregate_function_covariance_samp_old<NOTNULLABLE>);
factory.register_alternative_function(
"covar_samp", create_aggregate_function_covariance_samp_old<NULLABLE>, NULLABLE);
}

void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory& factory) {
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NOTNULLABLE>);
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NULLABLE>,
NULLABLE);
factory.register_function_both("covar_samp", create_aggregate_function_covariance_samp);
register_aggregate_function_covar_samp_old(factory);
}
} // namespace doris::vectorized
37 changes: 34 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_covar.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include "agent/be_exec_version_manager.h"
#define POP true
#define NOTPOP false
#define NULLABLE true
Expand Down Expand Up @@ -224,7 +225,7 @@ struct PopData : Data {
};

template <typename T, typename Data>
struct SampData : Data {
struct SampData_OLDER : Data {
using ColVecResult =
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128V2>, ColumnFloat64>;
void insert_result_into(IColumn& to) const {
Expand All @@ -243,6 +244,24 @@ struct SampData : Data {
}
};

template <typename T, typename Data>
struct SampData : Data {
using ColVecResult =
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128V2>, ColumnFloat64>;
void insert_result_into(IColumn& to) const {
auto& col = assert_cast<ColVecResult&>(to);
if (this->count == 1 || this->count == 0) {
col.insert_default();
} else {
if constexpr (IsDecimalNumber<T>) {
col.get_data().push_back(this->get_samp_result().value());
} else {
col.get_data().push_back(this->get_samp_result());
}
}
}
};

template <typename Data>
struct CovarName : Data {
static const char* name() { return "covar"; }
Expand All @@ -269,7 +288,11 @@ class AggregateFunctionSampCovariance
if constexpr (is_pop) {
return Data::get_return_type();
} else {
return make_nullable(Data::get_return_type());
if (IAggregateFunction::version < AGG_FUNCTION_NULLABLE) {
return make_nullable(Data::get_return_type());
} else {
return Data::get_return_type();
}
}
}

Expand All @@ -278,7 +301,7 @@ class AggregateFunctionSampCovariance
if constexpr (is_pop) {
this->data(place).add(columns[0], columns[1], row_num);
} else {
if constexpr (is_nullable) {
if constexpr (is_nullable) { //this if check could remove with old function
const auto* nullable_column_x = check_and_get_column<ColumnNullable>(columns[0]);
const auto* nullable_column_y = check_and_get_column<ColumnNullable>(columns[1]);
if (!nullable_column_x->is_null_at(row_num) &&
Expand Down Expand Up @@ -313,6 +336,14 @@ class AggregateFunctionSampCovariance
}
};

template <typename Data, bool is_nullable>
class AggregateFunctionSamp_OLDER final
: public AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable> {
public:
AggregateFunctionSamp_OLDER(const DataTypes& argument_types_)
: AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable>(argument_types_) {}
};

template <typename Data, bool is_nullable>
class AggregateFunctionSamp final
: public AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable> {
Expand Down
82 changes: 65 additions & 17 deletions be/src/vec/aggregate_functions/aggregate_function_percentile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,47 @@
namespace doris::vectorized {

template <bool is_nullable>
AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_percentile_approx_older(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
return nullptr;
}
if (argument_types.size() == 2) {
return creator_without_type::create<
AggregateFunctionPercentileApproxTwoParams<is_nullable>>(
remove_nullable(argument_types), result_is_nullable);
AggregateFunctionPercentileApproxTwoParams_OLDER<is_nullable>>((argument_types),
result_is_nullable);
}
if (argument_types.size() == 3) {
return creator_without_type::create<
AggregateFunctionPercentileApproxThreeParams<is_nullable>>(
AggregateFunctionPercentileApproxThreeParams_OLDER<is_nullable>>(
remove_nullable(argument_types), result_is_nullable);
}
return nullptr;
}

AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
return nullptr;
}
if (argument_types.size() == 2) {
return creator_without_type::create<AggregateFunctionPercentileApproxTwoParams>(
argument_types, result_is_nullable);
}
if (argument_types.size() == 3) {
return creator_without_type::create<AggregateFunctionPercentileApproxThreeParams>(
argument_types, result_is_nullable);
}
return nullptr;
}

template <bool is_nullable>
AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(
AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted_older(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
Expand All @@ -55,17 +73,35 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(
}
if (argument_types.size() == 3) {
return creator_without_type::create<
AggregateFunctionPercentileApproxWeightedThreeParams<is_nullable>>(
AggregateFunctionPercentileApproxWeightedThreeParams_OLDER<is_nullable>>(
remove_nullable(argument_types), result_is_nullable);
}
if (argument_types.size() == 4) {
return creator_without_type::create<
AggregateFunctionPercentileApproxWeightedFourParams<is_nullable>>(
AggregateFunctionPercentileApproxWeightedFourParams_OLDER<is_nullable>>(
remove_nullable(argument_types), result_is_nullable);
}
return nullptr;
}

AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
return nullptr;
}
if (argument_types.size() == 3) {
return creator_without_type::create<AggregateFunctionPercentileApproxWeightedThreeParams>(
argument_types, result_is_nullable);
}
if (argument_types.size() == 4) {
return creator_without_type::create<AggregateFunctionPercentileApproxWeightedFourParams>(
argument_types, result_is_nullable);
}
return nullptr;
}

void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("percentile",
creator_with_integer_type::creator<AggregateFunctionPercentile>);
Expand All @@ -74,14 +110,26 @@ void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& fact
creator_with_integer_type::creator<AggregateFunctionPercentileArray>);
}

void register_percentile_approx_old_function(AggregateFunctionSimpleFactory& factory) {
factory.register_alternative_function(
"percentile_approx", create_aggregate_function_percentile_approx_older<false>, false);
factory.register_alternative_function(
"percentile_approx", create_aggregate_function_percentile_approx_older<true>, true);
factory.register_alternative_function(
"percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted_older<false>, false);
factory.register_alternative_function(
"percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted_older<true>, true);
}

void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) {
factory.register_function("percentile_approx",
create_aggregate_function_percentile_approx<false>, false);
factory.register_function("percentile_approx",
create_aggregate_function_percentile_approx<true>, true);
factory.register_function("percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted<false>, false);
factory.register_function("percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted<true>, true);
factory.register_function_both("percentile_approx",
create_aggregate_function_percentile_approx);
factory.register_function_both("percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted);

register_percentile_approx_old_function(factory);
}

} // namespace doris::vectorized
Loading
Loading