diff --git a/src/query/expression/src/types/decimal.rs b/src/query/expression/src/types/decimal.rs index 1c3d133248aa..deafe5186264 100644 --- a/src/query/expression/src/types/decimal.rs +++ b/src/query/expression/src/types/decimal.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::fmt::Debug; +use std::marker::PhantomData; use std::ops::Range; use common_arrow::arrow::buffer::Buffer; @@ -26,8 +27,150 @@ use serde::Deserialize; use serde::Serialize; use super::SimpleDomain; +use crate::types::ArgType; +use crate::types::DataType; +use crate::types::GenericMap; +use crate::types::ValueType; use crate::utils::arrow::buffer_into_mut; use crate::Column; +use crate::ColumnBuilder; +use crate::Domain; +use crate::Scalar; +use crate::ScalarRef; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DecimalType(PhantomData); + +pub type Decimal128Type = DecimalType; +pub type Decimal256Type = DecimalType; + +impl ValueType for DecimalType { + type Scalar = DecimalScalar; + type ScalarRef<'a> = DecimalScalar; + type Column = DecimalColumn; + type Domain = DecimalDomain; + type ColumnIterator<'a> = std::iter::Cloned>; + type ColumnBuilder = DecimalColumnBuilder; + + #[inline] + fn upcast_gat<'short, 'long: 'short>(long: DecimalScalar) -> DecimalScalar { + long + } + + fn to_owned_scalar<'a>(scalar: Self::ScalarRef<'a>) -> Self::Scalar { + scalar + } + + fn to_scalar_ref<'a>(scalar: &'a Self::Scalar) -> Self::ScalarRef<'a> { + *scalar + } + + fn try_downcast_scalar<'a>(scalar: &'a ScalarRef) -> Option> { + match scalar { + ScalarRef::Decimal(scalar) => Some(*scalar), + _ => None, + } + } + + fn try_downcast_column<'a>(col: &'a Column) -> Option { + col.as_decimal().cloned() + } + + fn try_downcast_domain(domain: &Domain) -> Option { + domain.as_decimal().map(DecimalDomain::clone) + } + + fn try_downcast_builder<'a>( + builder: &'a mut ColumnBuilder, + ) -> Option<&'a mut Self::ColumnBuilder> { + match builder { + ColumnBuilder::Decimal(builder) => Some(builder), + _ => None, + } + } + + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { + Scalar::Decimal(scalar) + } + + fn upcast_column(col: Self::Column) -> Column { + Column::Decimal(col) + } + + fn upcast_domain(domain: Self::Domain) -> Domain { + Domain::Decimal(domain) + } + + fn column_len<'a>(col: &'a Self::Column) -> usize { + col.len() + } + + fn index_column<'a>(col: &'a Self::Column, index: usize) -> Option> { + col.index(index) + } + + unsafe fn index_column_unchecked<'a>( + col: &'a Self::Column, + index: usize, + ) -> Self::ScalarRef<'a> { + col.index_unchecked(index) + } + + fn slice_column<'a>(col: &'a Self::Column, range: Range) -> Self::Column { + col.slice(range) + } + + fn iter_column<'a>(col: &'a Self::Column) -> Self::ColumnIterator<'a> { + todo!() + } + + fn column_to_builder(col: Self::Column) -> Self::ColumnBuilder { + DecimalColumnBuilder::from_column(col) + } + + fn builder_len(builder: &Self::ColumnBuilder) -> usize { + builder.len() + } + + fn push_item(builder: &mut Self::ColumnBuilder, item: Self::Scalar) { + builder.push(item); + } + + fn push_default(builder: &mut Self::ColumnBuilder) { + builder.push_default(); + } + + fn append_column(builder: &mut Self::ColumnBuilder, other: &Self::Column) { + builder.append_column(other) + } + + fn build_column(builder: Self::ColumnBuilder) -> Self::Column { + builder.build() + } + + fn build_scalar(builder: Self::ColumnBuilder) -> Self::Scalar { + builder.build_scalar() + } +} + +impl ArgType for DecimalType { + fn data_type() -> DataType { + DataType::Decimal(Num::data_type()) + } + + fn full_domain() -> Self::Domain { + todo!() + } + + fn create_builder(capacity: usize, generics: &GenericMap) -> Self::ColumnBuilder { + match generics[0] { + DataType::Decimal(decimal_ty) => { + DecimalColumnBuilder::with_capacity(&decimal_ty, capacity) + } + _ => unreachable!(), + } + } +} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumAsInner)] pub enum DecimalDataType { @@ -65,7 +208,9 @@ pub struct DecimalSize { pub scale: u8, } -pub trait Decimal: Sized { +pub trait Decimal: + Copy + Debug + Default + Clone + Copy + PartialEq + Eq + PartialOrd + Ord + Sync + Send + 'static +{ fn one() -> Self; // 10**scale fn e(n: u32) -> Self; @@ -82,6 +227,8 @@ pub trait Decimal: Sized { fn to_column(value: Vec, size: DecimalSize) -> DecimalColumn { Self::to_column_from_buffer(value.into(), size) } + + fn data_type() -> DecimalDataType; } impl Decimal for i128 { @@ -117,6 +264,13 @@ impl Decimal for i128 { DecimalColumn::Decimal256(_, _) => None, } } + + fn data_type() -> DecimalDataType { + DecimalDataType::Decimal128(DecimalSize { + precision: 0, + scale: 0, + }) + } } impl Decimal for i256 { @@ -153,6 +307,13 @@ impl Decimal for i256 { DecimalColumn::Decimal256(c, size) => Some((c.clone(), *size)), } } + + fn data_type() -> DecimalDataType { + DecimalDataType::Decimal256(DecimalSize { + precision: 0, + scale: 0, + }) + } } static MAX_DECIMAL128_PRECISION: u8 = 38; diff --git a/src/query/functions/src/aggregates/aggregate_min_max_any.rs b/src/query/functions/src/aggregates/aggregate_min_max_any.rs index d1131d13c139..f45bc9b8d214 100644 --- a/src/query/functions/src/aggregates/aggregate_min_max_any.rs +++ b/src/query/functions/src/aggregates/aggregate_min_max_any.rs @@ -20,12 +20,15 @@ use std::sync::Arc; use common_arrow::arrow::bitmap::Bitmap; use common_exception::ErrorCode; use common_exception::Result; +use common_expression::types::decimal::*; use common_expression::types::number::*; use common_expression::types::*; +use common_expression::with_decimal_mapped_type; use common_expression::with_number_mapped_type; use common_expression::Column; use common_expression::ColumnBuilder; use common_expression::Scalar; +use ethnum::i256; use super::aggregate_function_factory::AggregateFunctionDescription; use super::aggregate_scalar_state::need_manual_drop_state; @@ -176,11 +179,13 @@ where pub fn try_create_aggregate_min_max_any_function( display_name: &str, - _params: Vec, + params: Vec, argument_types: Vec, ) -> Result> { assert_unary_arguments(display_name, argument_types.len())?; let mut data_type = argument_types[0].clone(); + println!("the data_type is {:?}", data_type.clone()); + println!("the params is {:?}", params.clone()); let need_drop = need_manual_drop_state(&data_type); // null use dummy func, it's already covered in `AggregateNullResultFunction` @@ -211,6 +216,21 @@ pub fn try_create_aggregate_min_max_any_function( } }) } + DataType::Decimal(decimal_type) => { + let scale = decimal_type.scale(); + let precision = decimal_type.precision(); + let size = DecimalSize { scale, precision }; + with_decimal_mapped_type!(|DECIMAL| match decimal_type { + DecimalDataType::DECIMAL(size) => { + type State = ScalarState, CMP>; + AggregateMinMaxAnyFunction::, CMP, State>::try_create( + display_name, + data_type, + need_drop, + ) + } + }) + } _ => { type State = ScalarState; AggregateMinMaxAnyFunction::::try_create(