From d18ceefad555a051033b8bd10c0643f3ff71ca91 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Thu, 16 Dec 2021 06:34:20 +0100 Subject: [PATCH] Added support for scalar comparison of dictionary. --- Cargo.toml | 2 +- src/compute/comparison/mod.rs | 13 +++++++ src/compute/take/mod.rs | 2 + src/types/index.rs | 71 ++++++++++++---------------------- tests/it/compute/comparison.rs | 1 + 5 files changed, 41 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1803bea47ec..1b02c0a6010 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -158,7 +158,7 @@ compute_bitwise = [] compute_boolean = [] compute_boolean_kleene = [] compute_cast = ["lexical-core", "compute_take"] -compute_comparison = [] +compute_comparison = ["compute_take"] compute_concatenate = [] compute_contains = [] compute_filter = [] diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index 836eab497b4..e214711c918 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -56,6 +56,7 @@ pub mod utf8; mod simd; pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; +use super::take::take_boolean; pub(crate) use primitive::{ compare_values_op as primitive_compare_values_op, compare_values_op_scalar as primitive_compare_values_op_scalar, @@ -266,6 +267,14 @@ macro_rules! compare_scalar { let rhs = rhs.as_any().downcast_ref::>().unwrap(); binary::$op::(lhs, rhs.value().unwrap()) } + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let values = $op(lhs.values().as_ref(), rhs); + + take_boolean(&values, lhs.keys()) + }) + } _ => todo!("Comparisons of {:?} are not yet supported", lhs.data_type()), } }}; @@ -363,6 +372,10 @@ pub fn can_gt_eq(data_type: &DataType) -> bool { // The list of operations currently supported. fn can_partial_eq_and_ord(data_type: &DataType) -> bool { + if let DataType::Dictionary(_, values, _) = data_type.to_logical_type() { + return can_partial_eq_and_ord(values.as_ref()); + } + matches!( data_type, DataType::Boolean diff --git a/src/compute/take/mod.rs b/src/compute/take/mod.rs index b127ddd5c4a..3d71e098d0a 100644 --- a/src/compute/take/mod.rs +++ b/src/compute/take/mod.rs @@ -33,6 +33,8 @@ mod primitive; mod structure; mod utf8; +pub(crate) use boolean::take as take_boolean; + /// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls. /// The returned array has a length equal to `indices.len()`. pub fn take(values: &dyn Array, indices: &PrimitiveArray) -> Result> { diff --git a/src/types/index.rs b/src/types/index.rs index ae040ec7709..b44b3957e79 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -11,10 +11,10 @@ pub trait Index: + std::ops::AddAssign + std::ops::Sub + num_traits::One - + PartialOrd + num_traits::Num - + Ord + num_traits::CheckedAdd + + PartialOrd + + Ord { /// Convert itself to [`usize`]. fn to_usize(&self) -> usize; @@ -32,53 +32,30 @@ pub trait Index: } } -impl Index for i32 { - #[inline] - fn to_usize(&self) -> usize { - *self as usize - } - - #[inline] - fn from_usize(value: usize) -> Option { - Self::try_from(value).ok() - } -} - -impl Index for i64 { - #[inline] - fn to_usize(&self) -> usize { - *self as usize - } - - #[inline] - fn from_usize(value: usize) -> Option { - Self::try_from(value).ok() - } -} - -impl Index for u32 { - #[inline] - fn to_usize(&self) -> usize { - *self as usize - } - - #[inline] - fn from_usize(value: usize) -> Option { - Self::try_from(value).ok() - } +macro_rules! index { + ($t:ty) => { + impl Index for $t { + #[inline] + fn to_usize(&self) -> usize { + *self as usize + } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } + } + }; } -impl Index for u64 { - #[inline] - fn to_usize(&self) -> usize { - *self as usize - } - - #[inline] - fn from_usize(value: usize) -> Option { - Self::try_from(value).ok() - } -} +index!(i8); +index!(i16); +index!(i32); +index!(i64); +index!(u8); +index!(u16); +index!(u32); +index!(u64); /// Range of [`Index`], equivalent to `(a..b)`. /// `Step` is unstable in Rust, which does not allow us to implement (a..b) for [`Index`]. diff --git a/tests/it/compute/comparison.rs b/tests/it/compute/comparison.rs index 1c89cc84620..61123985793 100644 --- a/tests/it/compute/comparison.rs +++ b/tests/it/compute/comparison.rs @@ -41,6 +41,7 @@ fn consistency() { Duration(TimeUnit::Millisecond), Duration(TimeUnit::Microsecond), Duration(TimeUnit::Nanosecond), + Dictionary(IntegerType::Int32, Box::new(LargeBinary)), ]; // array <> array