Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](function) stddev with DecimalV2 type will result in an error #38731

Merged
merged 2 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,6 @@ AggregateFunctionPtr create_function_single_value(const String& name,
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH

#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return creator_without_type::create<AggregateFunctionTemplate< \
NameData<Data<TYPE, BaseDatadecimal<TYPE, is_stddev>>>, is_nullable>>( \
custom_nullable ? remove_nullable(argument_types) : argument_types, \
result_is_nullable);
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH

LOG(WARNING) << fmt::format("create_function_single_value with unknowed type {}",
argument_types[0]->get_name());
return nullptr;
Expand Down
105 changes: 4 additions & 101 deletions be/src/vec/aggregate_functions/aggregate_function_stddev.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,16 @@
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <boost/iterator/iterator_facade.hpp>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'boost/iterator/iterator_facade.hpp' file not found [clang-diagnostic-error]

#include <boost/iterator/iterator_facade.hpp>
         ^

#include <cmath>
#include <memory>
#include <type_traits>

#include "agent/be_exec_version_manager.h"
#include "olap/olap_common.h"
#include "runtime/decimalv2_value.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_nullable.h"
#include "vec/common/assert_cast.h"
#include "vec/core/field.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_number.h"
Expand Down Expand Up @@ -126,103 +122,6 @@ struct BaseData {
int64_t count;
};

// Aggregation state for the decimal variants of the stddev/variance functions.
// Keeps a running mean, an accumulator `m2` (the value that the result
// functions divide by the row count) and the number of rows seen, all in
// DecimalV2Value arithmetic.
//
// Template parameters:
//   T         - element type of the ColumnDecimal<T> that add() reads from.
//   is_stddev - when true, get_result() takes the square root of its argument;
//               when false, the argument is returned unchanged.
template <typename T, bool is_stddev>
struct BaseDatadecimal {
// Starts with mean = 0, m2 = 0 and count = 0.
BaseDatadecimal() : mean(0), m2(0), count(0) {}
virtual ~BaseDatadecimal() = default;

// Serializes the state into `buf` in the order: mean, m2, count.
void write(BufferWritable& buf) const {
write_binary(mean, buf);
write_binary(m2, buf);
write_binary(count, buf);
}

// Restores the state from `buf`; must consume fields in the same order
// that write() emits them (mean, m2, count).
void read(BufferReadable& buf) {
read_binary(mean, buf);
read_binary(m2, buf);
read_binary(count, buf);
}

// Returns the state to default-constructed DecimalV2Value members and a zero count.
void reset() {
mean = DecimalV2Value();
m2 = DecimalV2Value();
count = {};
}

// Final transformation shared by the pop/samp results: square root for the
// stddev instantiation, identity otherwise. The branch is chosen at compile time.
DecimalV2Value get_result(DecimalV2Value res) const {
if constexpr (is_stddev) {
return DecimalV2Value::sqrt(res);
} else {
return res;
}
}

// Population result: get_result(m2 / count).
// A single row short-circuits to a default-constructed DecimalV2Value.
// NOTE(review): count == 0 is not special-cased here, so the divisor would be
// built from 0 — presumably callers never ask for a result on an empty state,
// or DecimalV2Value's operator/ tolerates it; confirm.
DecimalV2Value get_pop_result() const {
DecimalV2Value new_count = DecimalV2Value();
if (count == 1) {
return new_count;
}
DecimalV2Value res = m2 / new_count.assign_from_double(count);
return get_result(res);
}

// Sample result: get_result(m2 / (count - 1)).
// NOTE(review): with count == 1 the divisor is built from 0 (and from -1 when
// count == 0); there is no guard in this function — verify how
// DecimalV2Value's operator/ behaves for that divisor.
DecimalV2Value get_samp_result() const {
DecimalV2Value new_count = DecimalV2Value();
DecimalV2Value res = m2 / new_count.assign_from_double(count - 1);
return get_result(res);
}

// Folds another partial state `rhs` into this one.
// Does nothing when `rhs` has seen no rows.
void merge(const BaseDatadecimal& rhs) {
if (rhs.count == 0) {
return;
}
// Both row counts are converted to DecimalV2Value so that the whole
// combination below is carried out in decimal arithmetic.
DecimalV2Value new_count = DecimalV2Value();
new_count.assign_from_double(count);
DecimalV2Value rhs_count = DecimalV2Value();
rhs_count.assign_from_double(rhs.count);

// delta is the difference of the two means; the merged mean is rhs.mean
// shifted by delta weighted with this side's share of the rows, and m2
// gains a delta^2 correction term weighted by (rhs_count * count / total).
DecimalV2Value delta = mean - rhs.mean;
DecimalV2Value sum_count = new_count + rhs_count;
mean = rhs.mean + delta * (new_count / sum_count);
m2 = rhs.m2 + m2 + (delta * delta) * (rhs_count * new_count / sum_count);
count += rhs.count;
}

// Adds row `row_num` of `column` to the state.
// `column` is cast to ColumnDecimal<T>; the cast is not re-validated here
// beyond whatever assert_cast checks.
void add(const IColumn* column, size_t row_num) {
const auto& sources = assert_cast<const ColumnDecimal<T>&>(*column);
Field field = sources[row_num];
auto decimal_field = field.template get<DecimalField<T>>();
// Rescale the raw value to DecimalV2Value's representation: scale down when
// the input scale exceeds DecimalV2Value::SCALE, otherwise scale up.
// NOTE(review): presumably ONE_BILLION == 10^SCALE so both ratios are exact
// powers of ten — confirm against DecimalV2Value.
int128_t value;
if (decimal_field.get_scale() > DecimalV2Value::SCALE) {
value = static_cast<int128_t>(decimal_field.get_value()) /
(decimal_field.get_scale_multiplier() / DecimalV2Value::ONE_BILLION);
} else {
value = static_cast<int128_t>(decimal_field.get_value()) *
(DecimalV2Value::ONE_BILLION / decimal_field.get_scale_multiplier());
}
DecimalV2Value source_data = DecimalV2Value(value);

// Decimal copies of the current count and of count + 1 for the update below.
DecimalV2Value new_count = DecimalV2Value();
new_count.assign_from_double(count);
DecimalV2Value increase_count = DecimalV2Value();
increase_count.assign_from_double(1 + count);

// Incremental update: r is the new value's deviation from the old mean
// divided by the new row count; the mean moves by r and m2 grows by
// old_count * delta * r. Note that m2 multiplies two deviations together.
DecimalV2Value delta = source_data - mean;
DecimalV2Value r = delta / increase_count;
mean += r;
m2 += new_count * delta * r;
count += 1;
}

// Result type of the aggregate: DataTypeDecimal<Decimal128V2> constructed
// with (27, 9) — presumably precision 27 and scale 9; confirm against the
// DataTypeDecimal constructor.
static DataTypePtr get_return_type() {
return std::make_shared<DataTypeDecimal<Decimal128V2>>(27, 9);
}

// Running mean of the rows seen so far.
DecimalV2Value mean;
// Accumulator that get_pop_result()/get_samp_result() divide by the row count.
DecimalV2Value m2;
// Number of rows folded into this state.
int64_t count;
};

template <typename T, typename Data>
struct PopData : Data {
using ColVecResult =
Expand All @@ -237,6 +136,10 @@ struct PopData : Data {
}
};

// The stddev/variance family of functions below does not support Decimal
// inputs: the computation squares intermediate values, which can easily
// overflow the range of the Decimal type.

template <typename Data>
struct StddevName : Data {
static const char* name() { return "stddev"; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
Expand All @@ -49,9 +48,7 @@ public class Stddev extends NullableAggregateFunction
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE));

/**
* constructor with 1 argument.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
Expand All @@ -49,9 +48,7 @@ public class StddevSamp extends NullableAggregateFunction
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE));

/**
* constructor with 1 argument.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
Expand All @@ -49,9 +48,7 @@ public class Variance extends NullableAggregateFunction
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE));

/**
* constructor with 1 argument.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
Expand All @@ -48,9 +47,7 @@ public class VarianceSamp extends NullableAggregateFunction
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE));

/**
* constructor with 1 argument.
Expand Down
Loading