diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index 8e760da21909..5ae603885161 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -1221,7 +1221,7 @@ mod tests { .into_iter() .collect(); let sliced_input = sliced_input.slice(4, 2); - let sliced_input = as_primitive_array::(&sliced_input); + let sliced_input = sliced_input.as_primitive::(); assert_eq!(sliced_input, &input); @@ -1244,7 +1244,7 @@ mod tests { .into_iter() .collect(); let sliced_input = sliced_input.slice(4, 2); - let sliced_input = as_boolean_array(&sliced_input); + let sliced_input = sliced_input.as_boolean(); assert_eq!(sliced_input, &input); @@ -1267,7 +1267,7 @@ mod tests { .into_iter() .collect(); let sliced_input = sliced_input.slice(4, 2); - let sliced_input = as_string_array(&sliced_input); + let sliced_input = sliced_input.as_string::(); assert_eq!(sliced_input, &input); @@ -1290,7 +1290,7 @@ mod tests { .into_iter() .collect(); let sliced_input = sliced_input.slice(4, 2); - let sliced_input = as_generic_binary_array::(&sliced_input); + let sliced_input = sliced_input.as_binary::(); assert_eq!(sliced_input, &input); diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs index 8e2b7915357a..de4b0ccb8858 100644 --- a/arrow-arith/src/arithmetic.rs +++ b/arrow-arith/src/arithmetic.rs @@ -728,20 +728,20 @@ pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { - let l = as_primitive_array::(left); + let l = left.as_primitive::(); match right.data_type() { DataType::Interval(IntervalUnit::YearMonth) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date32Type::add_year_months)?; Ok(Arc::new(res)) } DataType::Interval(IntervalUnit::DayTime) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date32Type::add_day_time)?; Ok(Arc::new(res)) } DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date32Type::add_month_day_nano)?; Ok(Arc::new(res)) } @@ -752,20 +752,20 @@ pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { - let l = as_primitive_array::(left); + let l = left.as_primitive::(); match right.data_type() { DataType::Interval(IntervalUnit::YearMonth) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date64Type::add_year_months)?; Ok(Arc::new(res)) } DataType::Interval(IntervalUnit::DayTime) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date64Type::add_day_time)?; Ok(Arc::new(res)) } DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date64Type::add_month_day_nano)?; Ok(Arc::new(res)) } @@ -808,20 +808,20 @@ pub fn add_dyn_checked( ) } DataType::Date32 => { - let l = as_primitive_array::(left); + let l = left.as_primitive::(); match right.data_type() { DataType::Interval(IntervalUnit::YearMonth) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date32Type::add_year_months)?; Ok(Arc::new(res)) } DataType::Interval(IntervalUnit::DayTime) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date32Type::add_day_time)?; Ok(Arc::new(res)) } DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date32Type::add_month_day_nano)?; Ok(Arc::new(res)) } @@ -832,20 +832,20 @@ pub fn add_dyn_checked( } } DataType::Date64 => { - let l = as_primitive_array::(left); + let l = left.as_primitive::(); match right.data_type() { DataType::Interval(IntervalUnit::YearMonth) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date64Type::add_year_months)?; Ok(Arc::new(res)) } DataType::Interval(IntervalUnit::DayTime) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date64Type::add_day_time)?; Ok(Arc::new(res)) } DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = as_primitive_array::(right); + let r = right.as_primitive::(); let res = math_op(l, r, Date64Type::add_month_day_nano)?; Ok(Arc::new(res)) } @@ -2079,8 +2079,7 @@ mod tests { fn test_primitive_array_add_scalar_sliced() { let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); let a = a.slice(1, 4); - let a = as_primitive_array(&a); - let actual = add_scalar(a, 3).unwrap(); + let actual = add_scalar(a.as_primitive(), 3).unwrap(); let expected = Int32Array::from(vec![None, Some(12), Some(11), None]); assert_eq!(actual, expected); } @@ -2110,8 +2109,7 @@ mod tests { fn test_primitive_array_subtract_scalar_sliced() { let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); let a = a.slice(1, 4); - let a = as_primitive_array(&a); - let actual = subtract_scalar(a, 3).unwrap(); + let actual = subtract_scalar(a.as_primitive(), 3).unwrap(); let expected = Int32Array::from(vec![None, Some(6), Some(5), None]); assert_eq!(actual, expected); } @@ -2141,8 +2139,7 @@ mod tests { fn test_primitive_array_multiply_scalar_sliced() { let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); let a = a.slice(1, 4); - let a = as_primitive_array(&a); - let actual = multiply_scalar(a, 3).unwrap(); + let actual = multiply_scalar(a.as_primitive(), 3).unwrap(); let expected = Int32Array::from(vec![None, Some(27), Some(24), None]); assert_eq!(actual, expected); } @@ -2171,7 +2168,7 @@ mod tests { assert_eq!(0, c.value(4)); let c = modulus_dyn(&a, &b).unwrap(); - let c = as_primitive_array::(&c); + let c = c.as_primitive::(); assert_eq!(0, c.value(0)); assert_eq!(3, c.value(1)); assert_eq!(0, c.value(2)); @@ -2262,8 +2259,7 @@ mod tests { fn test_primitive_array_divide_scalar_sliced() { let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); let a = a.slice(1, 4); - let a = as_primitive_array(&a); - let actual = divide_scalar(a, 3).unwrap(); + let actual = divide_scalar(a.as_primitive(), 3).unwrap(); let expected = Int32Array::from(vec![None, Some(3), Some(2), None]); assert_eq!(actual, expected); } @@ -2277,7 +2273,7 @@ mod tests { assert_eq!(c, expected); let c = modulus_scalar_dyn::(&a, b).unwrap(); - let c = as_primitive_array::(&c); + let c = c.as_primitive::(); let expected = Int32Array::from(vec![0, 2, 0, 2, 1]); assert_eq!(c, &expected); } @@ -2286,13 +2282,13 @@ mod tests { fn test_int_array_modulus_scalar_sliced() { let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); let a = a.slice(1, 4); - let a = as_primitive_array(&a); + let a = a.as_primitive(); let actual = modulus_scalar(a, 3).unwrap(); let expected = Int32Array::from(vec![None, Some(0), Some(2), None]); assert_eq!(actual, expected); let actual = modulus_scalar_dyn::(a, 3).unwrap(); - let actual = as_primitive_array::(&actual); + let actual = actual.as_primitive::(); let expected = Int32Array::from(vec![None, Some(0), Some(2), None]); assert_eq!(actual, &expected); } @@ -2313,7 +2309,7 @@ mod tests { assert_eq!(0, result.value(0)); let result = modulus_scalar_dyn::(&a, -1).unwrap(); - let result = as_primitive_array::(&result); + let result = result.as_primitive::(); assert_eq!(0, result.value(0)); } @@ -3295,7 +3291,8 @@ mod tests { .unwrap(); let result = add_scalar_dyn::(&a, 1).unwrap(); - let result = as_primitive_array::(&result) + let result = result + .as_primitive::() .clone() .with_precision_and_scale(38, 2) .unwrap(); diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index 0a8815cc8059..96a856d46b67 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -523,7 +523,7 @@ mod tests { let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]); let input_slice = input.slice(1, 4); - let input_slice: &Float64Array = as_primitive_array(&input_slice); + let input_slice: &Float64Array = input_slice.as_primitive(); let result = unary(input_slice, |n| n.round()); assert_eq!( result, diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index ee58a485c71c..0862230a499e 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -16,7 +16,7 @@ // under the License. use crate::builder::{PrimitiveDictionaryBuilder, StringDictionaryBuilder}; -use crate::cast::as_primitive_array; +use crate::cast::AsArray; use crate::iterator::ArrayIter; use crate::types::*; use crate::{ @@ -410,8 +410,8 @@ impl DictionaryArray { return Err(self); } - let key_array = as_primitive_array::(self.keys()).clone(); - let value_array = as_primitive_array::(self.values()).clone(); + let key_array = self.keys().clone(); + let value_array = self.values().as_primitive::().clone(); drop(self.data); drop(self.keys); diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs index 6cd627cbd838..c9651f0b2019 100644 --- a/arrow-array/src/array/map_array.rs +++ b/arrow-array/src/array/map_array.rs @@ -253,7 +253,7 @@ impl std::fmt::Debug for MapArray { #[cfg(test)] mod tests { - use crate::cast::as_primitive_array; + use crate::cast::AsArray; use crate::types::UInt32Type; use crate::{Int32Array, UInt32Array}; use std::sync::Arc; @@ -522,7 +522,7 @@ mod tests { assert_eq!( &values_data, - as_primitive_array::(map_array.values()) + map_array.values().as_primitive::() ); assert_eq!(&DataType::UInt32, map_array.value_type()); assert_eq!(3, map_array.len()); diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index f62da38fb241..3aefb53b83f6 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -555,7 +555,7 @@ mod tests { use super::*; use crate::builder::PrimitiveRunBuilder; - use crate::cast::as_primitive_array; + use crate::cast::AsArray; use crate::types::{Int16Type, Int32Type, Int8Type, UInt32Type}; use crate::{Array, Int32Array, StringArray}; @@ -877,8 +877,7 @@ mod tests { builder.extend(input_array.clone().into_iter()); let run_array = builder.finish(); - let physical_values_array = - as_primitive_array::(run_array.values()); + let physical_values_array = run_array.values().as_primitive::(); // create an array consisting of all the indices repeated twice and shuffled. let mut logical_indices: Vec = (0_u32..(logical_len as u32)).collect(); @@ -913,7 +912,7 @@ mod tests { PrimitiveRunBuilder::::with_capacity(input_array.len()); builder.extend(input_array.iter().copied()); let run_array = builder.finish(); - let physical_values_array = as_primitive_array::(run_array.values()); + let physical_values_array = run_array.values().as_primitive::(); // test for all slice lengths. for slice_len in 1..=total_len { diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs index 5a4d2af7ca45..fe227226f77d 100644 --- a/arrow-array/src/array/union_array.rs +++ b/arrow-array/src/array/union_array.rs @@ -398,7 +398,7 @@ mod tests { use super::*; use crate::builder::UnionBuilder; - use crate::cast::{as_primitive_array, as_string_array}; + use crate::cast::AsArray; use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type}; use crate::RecordBatch; use crate::{Float64Array, Int32Array, Int64Array, StringArray}; @@ -1078,36 +1078,36 @@ mod tests { let v = array.value(0); assert_eq!(v.data_type(), &DataType::Int32); assert_eq!(v.len(), 1); - assert_eq!(as_primitive_array::(v.as_ref()).value(0), 5); + assert_eq!(v.as_primitive::().value(0), 5); let v = array.value(1); assert_eq!(v.data_type(), &DataType::Utf8); assert_eq!(v.len(), 1); - assert_eq!(as_string_array(v.as_ref()).value(0), "foo"); + assert_eq!(v.as_string::().value(0), "foo"); let v = array.value(2); assert_eq!(v.data_type(), &DataType::Int32); assert_eq!(v.len(), 1); - assert_eq!(as_primitive_array::(v.as_ref()).value(0), 6); + assert_eq!(v.as_primitive::().value(0), 6); let v = array.value(3); assert_eq!(v.data_type(), &DataType::Utf8); assert_eq!(v.len(), 1); - assert_eq!(as_string_array(v.as_ref()).value(0), "bar"); + assert_eq!(v.as_string::().value(0), "bar"); let v = array.value(4); assert_eq!(v.data_type(), &DataType::Float64); assert_eq!(v.len(), 1); - assert_eq!(as_primitive_array::(v.as_ref()).value(0), 10.0); + assert_eq!(v.as_primitive::().value(0), 10.0); let v = array.value(5); assert_eq!(v.data_type(), &DataType::Int32); assert_eq!(v.len(), 1); - assert_eq!(as_primitive_array::(v.as_ref()).value(0), 4); + assert_eq!(v.as_primitive::().value(0), 4); let v = array.value(6); assert_eq!(v.data_type(), &DataType::Utf8); assert_eq!(v.len(), 1); - assert_eq!(as_string_array(v.as_ref()).value(0), "baz"); + assert_eq!(v.as_string::().value(0), "baz"); } } diff --git a/arrow-array/src/builder/generic_byte_run_builder.rs b/arrow-array/src/builder/generic_byte_run_builder.rs index 5c15b1544ed3..9c26d7be6904 100644 --- a/arrow-array/src/builder/generic_byte_run_builder.rs +++ b/arrow-array/src/builder/generic_byte_run_builder.rs @@ -40,7 +40,7 @@ use arrow_buffer::ArrowNativeType; /// # use arrow_array::{GenericByteArray, BinaryArray}; /// # use arrow_array::types::{BinaryType, Int16Type}; /// # use arrow_array::{Array, Int16Array}; -/// # use arrow_array::cast::as_generic_binary_array; +/// # use arrow_array::cast::AsArray; /// /// let mut builder = /// GenericByteRunBuilder::::new(); @@ -59,7 +59,7 @@ use arrow_buffer::ArrowNativeType; /// assert!(av.is_null(3)); /// /// // Values are polymorphic and so require a downcast. -/// let ava: &BinaryArray = as_generic_binary_array(av.as_ref()); +/// let ava: &BinaryArray = av.as_binary(); /// /// assert_eq!(ava.value(0), b"abc"); /// assert_eq!(ava.value(2), b"def"); @@ -318,7 +318,7 @@ where /// # use arrow_array::builder::StringRunBuilder; /// # use arrow_array::{Int16Array, StringArray}; /// # use arrow_array::types::Int16Type; -/// # use arrow_array::cast::as_string_array; +/// # use arrow_array::cast::AsArray; /// /// let mut builder = StringRunBuilder::::new(); /// @@ -332,7 +332,7 @@ where /// /// // Values are polymorphic and so require a downcast. /// let av = array.values(); -/// let ava: &StringArray = as_string_array(av.as_ref()); +/// let ava: &StringArray = av.as_string::(); /// /// assert_eq!(ava.value(0), "abc"); /// assert!(av.is_null(1)); @@ -353,8 +353,8 @@ pub type LargeStringRunBuilder = GenericByteRunBuilder; /// /// # use arrow_array::builder::BinaryRunBuilder; /// # use arrow_array::{BinaryArray, Int16Array}; +/// # use arrow_array::cast::AsArray; /// # use arrow_array::types::Int16Type; -/// # use arrow_array::cast::as_generic_binary_array; /// /// let mut builder = BinaryRunBuilder::::new(); /// @@ -368,7 +368,7 @@ pub type LargeStringRunBuilder = GenericByteRunBuilder; /// /// // Values are polymorphic and so require a downcast. /// let av = array.values(); -/// let ava: &BinaryArray = as_generic_binary_array::(av.as_ref()); +/// let ava: &BinaryArray = av.as_binary(); /// /// assert_eq!(ava.value(0), b"abc"); /// assert!(av.is_null(1)); @@ -387,7 +387,7 @@ mod tests { use super::*; use crate::array::Array; - use crate::cast::as_string_array; + use crate::cast::AsArray; use crate::types::{Int16Type, Int32Type}; use crate::GenericByteArray; use crate::Int16RunArray; @@ -518,7 +518,7 @@ mod tests { assert_eq!(array.len(), 10); assert_eq!(array.run_ends().values(), &[3, 5, 8, 10]); - let str_array = as_string_array(array.values().as_ref()); + let str_array = array.values().as_string::(); assert_eq!(str_array.value(0), "a"); assert_eq!(str_array.value(1), ""); assert_eq!(str_array.value(2), "b"); diff --git a/arrow-array/src/builder/generic_list_builder.rs b/arrow-array/src/builder/generic_list_builder.rs index 6228475542bd..c8643ac2822d 100644 --- a/arrow-array/src/builder/generic_list_builder.rs +++ b/arrow-array/src/builder/generic_list_builder.rs @@ -206,7 +206,7 @@ where mod tests { use super::*; use crate::builder::{Int32Builder, ListBuilder}; - use crate::cast::as_primitive_array; + use crate::cast::AsArray; use crate::types::Int32Type; use crate::{Array, Int32Array}; use arrow_buffer::Buffer; @@ -405,8 +405,7 @@ mod tests { assert_eq!(array.value_offsets(), [0, 4, 4, 6, 6]); assert_eq!(array.null_count(), 1); assert!(array.is_null(3)); - let a_values = array.values(); - let elements = as_primitive_array::(a_values.as_ref()); + let elements = array.values().as_primitive::(); assert_eq!(elements.values(), &[1, 2, 7, 0, 4, 5]); assert_eq!(elements.null_count(), 1); assert!(elements.is_null(3)); diff --git a/arrow-array/src/builder/primitive_run_builder.rs b/arrow-array/src/builder/primitive_run_builder.rs index e7c822ee6b19..30750b6f3421 100644 --- a/arrow-array/src/builder/primitive_run_builder.rs +++ b/arrow-array/src/builder/primitive_run_builder.rs @@ -30,7 +30,7 @@ use arrow_buffer::ArrowNativeType; /// ``` /// /// # use arrow_array::builder::PrimitiveRunBuilder; -/// # use arrow_array::cast::as_primitive_array; +/// # use arrow_array::cast::AsArray; /// # use arrow_array::types::{UInt32Type, Int16Type}; /// # use arrow_array::{Array, UInt32Array, Int16Array}; /// @@ -53,7 +53,7 @@ use arrow_buffer::ArrowNativeType; /// assert!(!av.is_null(2)); /// /// // Values are polymorphic and so require a downcast. -/// let ava: &UInt32Array = as_primitive_array::(av.as_ref()); +/// let ava: &UInt32Array = av.as_primitive::(); /// /// assert_eq!(ava, &UInt32Array::from(vec![Some(1234), None, Some(5678)])); /// ``` @@ -265,7 +265,7 @@ where #[cfg(test)] mod tests { use crate::builder::PrimitiveRunBuilder; - use crate::cast::as_primitive_array; + use crate::cast::AsArray; use crate::types::{Int16Type, UInt32Type}; use crate::{Array, UInt32Array}; @@ -293,7 +293,7 @@ mod tests { assert!(!av.is_null(2)); // Values are polymorphic and so require a downcast. - let ava: &UInt32Array = as_primitive_array::(av.as_ref()); + let ava: &UInt32Array = av.as_primitive::(); assert_eq!(ava, &UInt32Array::from(vec![Some(1234), None, Some(5678)])); } @@ -309,7 +309,7 @@ mod tests { assert_eq!(array.null_count(), 0); assert_eq!(array.run_ends().values(), &[1, 3, 5, 9, 10, 11]); assert_eq!( - as_primitive_array::(array.values().as_ref()).values(), + array.values().as_primitive::().values(), &[1, 2, 5, 4, 6, 2] ); } diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index 81d250cafffe..a39ff88c6bcd 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -709,6 +709,165 @@ where T::from(array.to_data()) } +mod private { + pub trait Sealed {} +} + +/// An extension trait for `dyn Array` that provides ergonomic downcasting +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, Int32Array}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::types::Int32Type; +/// let col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; +/// assert_eq!(col.as_primitive::().values(), &[1, 2, 3]); +/// ``` +pub trait AsArray: private::Sealed { + /// Downcast this to a [`BooleanArray`] returning `None` if not possible + fn as_boolean_opt(&self) -> Option<&BooleanArray>; + + /// Downcast this to a [`BooleanArray`] panicking if not possible + fn as_boolean(&self) -> &BooleanArray { + self.as_boolean_opt().expect("boolean array") + } + + /// Downcast this to a [`PrimitiveArray`] returning `None` if not possible + fn as_primitive_opt(&self) -> Option<&PrimitiveArray>; + + /// Downcast this to a [`PrimitiveArray`] panicking if not possible + fn as_primitive(&self) -> &PrimitiveArray { + self.as_primitive_opt().expect("primitive array") + } + + /// Downcast this to a [`GenericByteArray`] returning `None` if not possible + fn as_bytes_opt(&self) -> Option<&GenericByteArray>; + + /// Downcast this to a [`GenericByteArray`] panicking if not possible + fn as_bytes(&self) -> &GenericByteArray { + self.as_bytes_opt().expect("byte array") + } + + /// Downcast this to a [`GenericStringArray`] returning `None` if not possible + fn as_string_opt(&self) -> Option<&GenericStringArray> { + self.as_bytes_opt() + } + + /// Downcast this to a [`GenericStringArray`] panicking if not possible + fn as_string(&self) -> &GenericStringArray { + self.as_bytes_opt().expect("string array") + } + + /// Downcast this to a [`GenericBinaryArray`] returning `None` if not possible + fn as_binary_opt(&self) -> Option<&GenericBinaryArray> { + self.as_bytes_opt() + } + + /// Downcast this to a [`GenericBinaryArray`] panicking if not possible + fn as_binary(&self) -> &GenericBinaryArray { + self.as_bytes_opt().expect("binary array") + } + + /// Downcast this to a [`StructArray`] returning `None` if not possible + fn as_struct_opt(&self) -> Option<&StructArray>; + + /// Downcast this to a [`StructArray`] panicking if not possible + fn as_struct(&self) -> &StructArray { + self.as_struct_opt().expect("struct array") + } + + /// Downcast this to a [`GenericListArray`] returning `None` if not possible + fn as_list_opt(&self) -> Option<&GenericListArray>; + + /// Downcast this to a [`GenericListArray`] panicking if not possible + fn as_list(&self) -> &GenericListArray { + self.as_list_opt().expect("list array") + } + + /// Downcast this to a [`MapArray`] returning `None` if not possible + fn as_map_opt(&self) -> Option<&MapArray>; + + /// Downcast this to a [`MapArray`] panicking if not possible + fn as_map(&self) -> &MapArray { + self.as_map_opt().expect("map array") + } + + /// Downcast this to a [`DictionaryArray`] returning `None` if not possible + fn as_dictionary_opt(&self) + -> Option<&DictionaryArray>; + + /// Downcast this to a [`DictionaryArray`] panicking if not possible + fn as_dictionary(&self) -> &DictionaryArray { + self.as_dictionary_opt().expect("dictionary array") + } +} + +impl private::Sealed for dyn Array + '_ {} +impl AsArray for dyn Array + '_ { + fn as_boolean_opt(&self) -> Option<&BooleanArray> { + self.as_any().downcast_ref() + } + + fn as_primitive_opt(&self) -> Option<&PrimitiveArray> { + self.as_any().downcast_ref() + } + + fn as_bytes_opt(&self) -> Option<&GenericByteArray> { + self.as_any().downcast_ref() + } + + fn as_struct_opt(&self) -> Option<&StructArray> { + self.as_any().downcast_ref() + } + + fn as_list_opt(&self) -> Option<&GenericListArray> { + self.as_any().downcast_ref() + } + + fn as_map_opt(&self) -> Option<&MapArray> { + self.as_any().downcast_ref() + } + + fn as_dictionary_opt( + &self, + ) -> Option<&DictionaryArray> { + self.as_any().downcast_ref() + } +} + +impl private::Sealed for ArrayRef {} +impl AsArray for ArrayRef { + fn as_boolean_opt(&self) -> Option<&BooleanArray> { + self.as_ref().as_boolean_opt() + } + + fn as_primitive_opt(&self) -> Option<&PrimitiveArray> { + self.as_ref().as_primitive_opt() + } + + fn as_bytes_opt(&self) -> Option<&GenericByteArray> { + self.as_ref().as_bytes_opt() + } + + fn as_struct_opt(&self) -> Option<&StructArray> { + self.as_ref().as_struct_opt() + } + + fn as_list_opt(&self) -> Option<&GenericListArray> { + self.as_ref().as_list_opt() + } + + fn as_map_opt(&self) -> Option<&MapArray> { + self.as_any().downcast_ref() + } + + fn as_dictionary_opt( + &self, + ) -> Option<&DictionaryArray> { + self.as_ref().as_dictionary_opt() + } +} + #[cfg(test)] mod tests { use arrow_buffer::i256; diff --git a/arrow-array/src/lib.rs b/arrow-array/src/lib.rs index ada59564bf0e..ff1ddb1f67ce 100644 --- a/arrow-array/src/lib.rs +++ b/arrow-array/src/lib.rs @@ -43,17 +43,15 @@ //! } //! ``` //! -//! Additionally, there are convenient functions to do this casting -//! such as [`cast::as_primitive_array`] and [`cast::as_string_array`]: +//! The [`cast::AsArray`] extension trait can make this more ergonomic //! //! ``` //! # use arrow_array::Array; -//! # use arrow_array::cast::as_primitive_array; +//! # use arrow_array::cast::{AsArray, as_primitive_array}; //! # use arrow_array::types::Float32Type; //! //! fn as_f32_slice(array: &dyn Array) -> &[f32] { -//! // use as_primtive_array -//! as_primitive_array::(array).values() +//! array.as_primitive::().values() //! } //! ``` diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 43048c2aba45..ba909649da3a 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -445,9 +445,7 @@ fn cast_reinterpret_arrays< >( array: &dyn Array, ) -> Result { - Ok(Arc::new( - as_primitive_array::(array).reinterpret_cast::(), - )) + Ok(Arc::new(array.as_primitive::().reinterpret_cast::())) } fn cast_decimal_to_integer( @@ -716,7 +714,7 @@ pub fn cast_with_options( } (Decimal128(_, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal::( - as_primitive_array(array), + array.as_primitive(), *s1, *p2, *s2, @@ -725,7 +723,7 @@ pub fn cast_with_options( } (Decimal256(_, s1), Decimal256(p2, s2)) => { cast_decimal_to_decimal::( - as_primitive_array(array), + array.as_primitive(), *s1, *p2, *s2, @@ -734,7 +732,7 @@ pub fn cast_with_options( } (Decimal128(_, s1), Decimal256(p2, s2)) => { cast_decimal_to_decimal::( - as_primitive_array(array), + array.as_primitive(), *s1, *p2, *s2, @@ -743,7 +741,7 @@ pub fn cast_with_options( } (Decimal256(_, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal::( - as_primitive_array(array), + array.as_primitive(), *s1, *p2, *s2, @@ -888,69 +886,69 @@ pub fn cast_with_options( // cast data to decimal match from_type { UInt8 => cast_integer_to_decimal::<_, Decimal128Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, 10_i128, cast_options, ), UInt16 => cast_integer_to_decimal::<_, Decimal128Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, 10_i128, cast_options, ), UInt32 => cast_integer_to_decimal::<_, Decimal128Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, 10_i128, cast_options, ), UInt64 => cast_integer_to_decimal::<_, Decimal128Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, 10_i128, cast_options, ), Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, 10_i128, cast_options, ), Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, 10_i128, cast_options, ), Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, 10_i128, cast_options, ), Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, 10_i128, cast_options, ), Float32 => cast_floating_point_to_decimal128( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, cast_options, ), Float64 => cast_floating_point_to_decimal128( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, cast_options, @@ -977,69 +975,69 @@ pub fn cast_with_options( // cast data to decimal match from_type { UInt8 => cast_integer_to_decimal::<_, Decimal256Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, i256::from_i128(10_i128), cast_options, ), UInt16 => cast_integer_to_decimal::<_, Decimal256Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, i256::from_i128(10_i128), cast_options, ), UInt32 => cast_integer_to_decimal::<_, Decimal256Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, i256::from_i128(10_i128), cast_options, ), UInt64 => cast_integer_to_decimal::<_, Decimal256Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, i256::from_i128(10_i128), cast_options, ), Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, i256::from_i128(10_i128), cast_options, ), Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, i256::from_i128(10_i128), cast_options, ), Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, i256::from_i128(10_i128), cast_options, ), Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, i256::from_i128(10_i128), cast_options, ), Float32 => cast_floating_point_to_decimal256( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, cast_options, ), Float64 => cast_floating_point_to_decimal256( - as_primitive_array::(array), + array.as_primitive::(), *precision, *scale, cast_options, @@ -1133,9 +1131,9 @@ pub fn cast_with_options( Float64 => cast_string_to_numeric::(array, cast_options), Date32 => cast_string_to_date32::(array, cast_options), Date64 => cast_string_to_date64::(array, cast_options), - Binary => Ok(Arc::new(BinaryArray::from(as_string_array(array).clone()))), + Binary => Ok(Arc::new(BinaryArray::from(array.as_string::().clone()))), LargeBinary => { - let binary = BinaryArray::from(as_string_array(array).clone()); + let binary = BinaryArray::from(array.as_string::().clone()); cast_byte_container::(&binary) } LargeUtf8 => cast_byte_container::(array), @@ -1192,11 +1190,11 @@ pub fn cast_with_options( Utf8 => cast_byte_container::(array), Binary => { let large_binary = - LargeBinaryArray::from(as_largestring_array(array).clone()); + LargeBinaryArray::from(array.as_string::().clone()); cast_byte_container::(&large_binary) } LargeBinary => Ok(Arc::new(LargeBinaryArray::from( - as_largestring_array(array).clone(), + array.as_string::().clone(), ))), Time32(TimeUnit::Second) => { cast_string_to_time32second::(array, cast_options) @@ -1580,71 +1578,71 @@ pub fn cast_with_options( cast_reinterpret_arrays::(array) } (Date32, Date64) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Date64Type>(|x| x as i64 * MILLISECONDS_IN_DAY), )), (Date64, Date32) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Date32Type>(|x| (x / MILLISECONDS_IN_DAY) as i32), )), (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time32MillisecondType>(|x| x * MILLISECONDS as i32), )), (Time32(TimeUnit::Second), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time64MicrosecondType>(|x| x as i64 * MICROSECONDS), )), (Time32(TimeUnit::Second), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time64NanosecondType>(|x| x as i64 * NANOSECONDS), )), (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time32SecondType>(|x| x / MILLISECONDS as i32), )), (Time32(TimeUnit::Millisecond), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time64MicrosecondType>(|x| { x as i64 * (MICROSECONDS / MILLISECONDS) }), )), (Time32(TimeUnit::Millisecond), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time64NanosecondType>(|x| { x as i64 * (MICROSECONDS / NANOSECONDS) }), )), (Time64(TimeUnit::Microsecond), Time32(TimeUnit::Second)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time32SecondType>(|x| (x / MICROSECONDS) as i32), )), (Time64(TimeUnit::Microsecond), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time32MillisecondType>(|x| { (x / (MICROSECONDS / MILLISECONDS)) as i32 }), )), (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time64NanosecondType>(|x| x * (NANOSECONDS / MICROSECONDS)), )), (Time64(TimeUnit::Nanosecond), Time32(TimeUnit::Second)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time32SecondType>(|x| (x / NANOSECONDS) as i32), )), (Time64(TimeUnit::Nanosecond), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time32MillisecondType>(|x| { (x / (NANOSECONDS / MILLISECONDS)) as i32 }), )), (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Time64MicrosecondType>(|x| x / (NANOSECONDS / MICROSECONDS)), )), @@ -1662,14 +1660,14 @@ pub fn cast_with_options( } (Int64, Timestamp(unit, tz)) => Ok(make_timestamp_array( - as_primitive_array(array), + array.as_primitive(), unit.clone(), tz.clone(), )), (Timestamp(from_unit, _), Timestamp(to_unit, to_tz)) => { let array = cast_with_options(array, &Int64, cast_options)?; - let time_array = as_primitive_array::(array.as_ref()); + let time_array = array.as_primitive::(); let from_size = time_unit_multiple(from_unit); let to_size = time_unit_multiple(to_unit); // we either divide or multiply, depending on size of each unit @@ -1697,7 +1695,7 @@ pub fn cast_with_options( } (Timestamp(from_unit, _), Date32) => { let array = cast_with_options(array, &Int64, cast_options)?; - let time_array = as_primitive_array::(array.as_ref()); + let time_array = array.as_primitive::(); let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY; let mut b = Date32Builder::with_capacity(array.len()); @@ -1716,13 +1714,13 @@ pub fn cast_with_options( match cast_options.safe { true => { // change error to None - as_primitive_array::(array) + array.as_primitive::() .unary_opt::<_, Date64Type>(|x| { x.checked_mul(MILLISECONDS) }) } false => { - as_primitive_array::(array).try_unary::<_, Date64Type, _>( + array.as_primitive::().try_unary::<_, Date64Type, _>( |x| { x.mul_checked(MILLISECONDS) }, @@ -1734,17 +1732,17 @@ pub fn cast_with_options( cast_reinterpret_arrays::(array) } (Timestamp(TimeUnit::Microsecond, _), Date64) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Date64Type>(|x| x / (MICROSECONDS / MILLISECONDS)), )), (Timestamp(TimeUnit::Nanosecond, _), Date64) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, Date64Type>(|x| x / (NANOSECONDS / MILLISECONDS)), )), (Timestamp(TimeUnit::Second, tz), Time64(TimeUnit::Microsecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { Ok(time_to_time64us(as_time_res_with_timezone::< TimestampSecondType, @@ -1755,7 +1753,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Second, tz), Time64(TimeUnit::Nanosecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { Ok(time_to_time64ns(as_time_res_with_timezone::< TimestampSecondType, @@ -1766,7 +1764,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Microsecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { Ok(time_to_time64us(as_time_res_with_timezone::< TimestampMillisecondType, @@ -1777,7 +1775,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Nanosecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { Ok(time_to_time64ns(as_time_res_with_timezone::< TimestampMillisecondType, @@ -1788,7 +1786,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Microsecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { Ok(time_to_time64us(as_time_res_with_timezone::< TimestampMicrosecondType, @@ -1799,7 +1797,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Nanosecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { Ok(time_to_time64ns(as_time_res_with_timezone::< TimestampMicrosecondType, @@ -1810,7 +1808,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Microsecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { Ok(time_to_time64us(as_time_res_with_timezone::< TimestampNanosecondType, @@ -1821,7 +1819,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Nanosecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { Ok(time_to_time64ns(as_time_res_with_timezone::< TimestampNanosecondType, @@ -1832,7 +1830,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Second)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time32SecondType, ArrowError>(|x| { Ok(time_to_time32s(as_time_res_with_timezone::< TimestampSecondType, @@ -1843,7 +1841,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Millisecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { Ok(time_to_time32ms(as_time_res_with_timezone::< TimestampSecondType, @@ -1854,7 +1852,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Millisecond, tz), Time32(TimeUnit::Second)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time32SecondType, ArrowError>(|x| { Ok(time_to_time32s(as_time_res_with_timezone::< TimestampMillisecondType, @@ -1865,7 +1863,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Millisecond, tz), Time32(TimeUnit::Millisecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { Ok(time_to_time32ms(as_time_res_with_timezone::< TimestampMillisecondType, @@ -1876,7 +1874,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Second)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time32SecondType, ArrowError>(|x| { Ok(time_to_time32s(as_time_res_with_timezone::< TimestampMicrosecondType, @@ -1887,7 +1885,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Millisecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { Ok(time_to_time32ms(as_time_res_with_timezone::< TimestampMicrosecondType, @@ -1898,7 +1896,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Second)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time32SecondType, ArrowError>(|x| { Ok(time_to_time32s(as_time_res_with_timezone::< TimestampNanosecondType, @@ -1909,7 +1907,7 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Millisecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { Ok(time_to_time32ms(as_time_res_with_timezone::< TimestampNanosecondType, @@ -1919,38 +1917,38 @@ pub fn cast_with_options( } (Date64, Timestamp(TimeUnit::Second, None)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, TimestampSecondType>(|x| x / MILLISECONDS), )), (Date64, Timestamp(TimeUnit::Millisecond, None)) => { cast_reinterpret_arrays::(array) } (Date64, Timestamp(TimeUnit::Microsecond, None)) => Ok(Arc::new( - as_primitive_array::(array).unary::<_, TimestampMicrosecondType>( + array.as_primitive::().unary::<_, TimestampMicrosecondType>( |x| x * (MICROSECONDS / MILLISECONDS), ), )), (Date64, Timestamp(TimeUnit::Nanosecond, None)) => Ok(Arc::new( - as_primitive_array::(array).unary::<_, TimestampNanosecondType>( + array.as_primitive::().unary::<_, TimestampNanosecondType>( |x| x * (NANOSECONDS / MILLISECONDS), ), )), (Date32, Timestamp(TimeUnit::Second, None)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, TimestampSecondType>(|x| (x as i64) * SECONDS_IN_DAY), )), (Date32, Timestamp(TimeUnit::Millisecond, None)) => Ok(Arc::new( - as_primitive_array::(array).unary::<_, TimestampMillisecondType>( + array.as_primitive::().unary::<_, TimestampMillisecondType>( |x| (x as i64) * MILLISECONDS_IN_DAY, ), )), (Date32, Timestamp(TimeUnit::Microsecond, None)) => Ok(Arc::new( - as_primitive_array::(array).unary::<_, TimestampMicrosecondType>( + array.as_primitive::().unary::<_, TimestampMicrosecondType>( |x| (x as i64) * MICROSECONDS_IN_DAY, ), )), (Date32, Timestamp(TimeUnit::Nanosecond, None)) => Ok(Arc::new( - as_primitive_array::(array) + array.as_primitive::() .unary::<_, TimestampNanosecondType>(|x| (x as i64) * NANOSECONDS_IN_DAY), )), (Int64, Duration(TimeUnit::Second)) => { @@ -3736,7 +3734,7 @@ mod tests { let result = cast(&array, &DataType::Decimal128(2, 2)); assert!(result.is_ok()); let array = result.unwrap(); - let array: &Decimal128Array = as_primitive_array(&array); + let array: &Decimal128Array = array.as_primitive(); let err = array.validate_decimal_precision(2); assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. Max is 99", err.unwrap_err().to_string()); @@ -4306,7 +4304,7 @@ mod tests { let casted_array = cast(&array, &DataType::Decimal128(3, 1)); assert!(casted_array.is_ok()); let array = casted_array.unwrap(); - let array: &Decimal128Array = as_primitive_array(&array); + let array: &Decimal128Array = array.as_primitive(); let err = array.validate_decimal_precision(3); assert_eq!("Invalid argument error: 1000 is too large to store in a Decimal128 of precision 3. Max is 999", err.unwrap_err().to_string()); @@ -4316,7 +4314,7 @@ mod tests { let casted_array = cast(&array, &DataType::Decimal128(3, 1)); assert!(casted_array.is_ok()); let array = casted_array.unwrap(); - let array: &Decimal128Array = as_primitive_array(&array); + let array: &Decimal128Array = array.as_primitive(); let err = array.validate_decimal_precision(3); assert_eq!("Invalid argument error: 1000 is too large to store in a Decimal128 of precision 3. Max is 999", err.unwrap_err().to_string()); @@ -4475,7 +4473,7 @@ mod tests { let casted_array = cast(&array, &DataType::Decimal256(3, 1)); assert!(casted_array.is_ok()); let array = casted_array.unwrap(); - let array: &Decimal256Array = as_primitive_array(&array); + let array: &Decimal256Array = array.as_primitive(); let err = array.validate_decimal_precision(3); assert_eq!("Invalid argument error: 1000 is too large to store in a Decimal256 of precision 3. Max is 999", err.unwrap_err().to_string()); @@ -4603,14 +4601,14 @@ mod tests { ) .unwrap(); assert_eq!(5, b.len()); - let arr = b.as_any().downcast_ref::().unwrap(); + let arr = b.as_list::(); assert_eq!(&[0, 1, 2, 3, 4, 5], arr.value_offsets()); assert_eq!(1, arr.value_length(0)); assert_eq!(1, arr.value_length(1)); assert_eq!(1, arr.value_length(2)); assert_eq!(1, arr.value_length(3)); assert_eq!(1, arr.value_length(4)); - let c = as_primitive_array::(arr.values()); + let c = arr.values().as_primitive::(); assert_eq!(5, c.value(0)); assert_eq!(6, c.value(1)); assert_eq!(7, c.value(2)); @@ -4628,7 +4626,7 @@ mod tests { .unwrap(); assert_eq!(5, b.len()); assert_eq!(1, b.null_count()); - let arr = b.as_any().downcast_ref::().unwrap(); + let arr = b.as_list::(); assert_eq!(&[0, 1, 2, 3, 4, 5], arr.value_offsets()); assert_eq!(1, arr.value_length(0)); assert_eq!(1, arr.value_length(1)); @@ -4636,7 +4634,7 @@ mod tests { assert_eq!(1, arr.value_length(3)); assert_eq!(1, arr.value_length(4)); - let c = as_primitive_array::(arr.values()); + let c = arr.values().as_primitive::(); assert_eq!(1, c.null_count()); assert_eq!(5, c.value(0)); assert!(!c.is_valid(1)); @@ -4657,13 +4655,13 @@ mod tests { .unwrap(); assert_eq!(4, b.len()); assert_eq!(1, b.null_count()); - let arr = b.as_any().downcast_ref::().unwrap(); + let arr = b.as_list::(); assert_eq!(&[0, 1, 2, 3, 4], arr.value_offsets()); assert_eq!(1, arr.value_length(0)); assert_eq!(1, arr.value_length(1)); assert_eq!(1, arr.value_length(2)); assert_eq!(1, arr.value_length(3)); - let c = as_primitive_array::(arr.values()); + let c = arr.values().as_primitive::(); assert_eq!(1, c.null_count()); assert_eq!(7.0, c.value(0)); assert_eq!(8.0, c.value(1)); @@ -4802,7 +4800,7 @@ mod tests { assert_eq!(2, array.value_length(2)); // expect 4 nulls: negative numbers and overflow - let u16arr = as_primitive_array::(array.values()); + let u16arr = array.values().as_primitive::(); assert_eq!(4, u16arr.null_count()); // expect 4 nulls: negative numbers and overflow @@ -6946,7 +6944,7 @@ mod tests { let expected = $ARR_TYPE::from(vec![None; 6]); let cast_type = DataType::$DATATYPE; let cast_array = cast(&array, &cast_type).expect("cast failed"); - let cast_array = as_primitive_array::<$TYPE>(&cast_array); + let cast_array = cast_array.as_primitive::<$TYPE>(); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(cast_array, &expected); } @@ -7439,7 +7437,7 @@ mod tests { ); let casted_array = cast(&array, &output_type).unwrap(); - let decimal_arr = as_primitive_array::(&casted_array); + let decimal_arr = casted_array.as_primitive::(); assert_eq!("1123450", decimal_arr.value_as_string(0)); assert_eq!("2123460", decimal_arr.value_as_string(1)); @@ -7456,7 +7454,7 @@ mod tests { ])) as ArrayRef; let casted_array = cast(&array, &decimal_type).unwrap(); - let decimal_arr = as_primitive_array::(&casted_array); + let decimal_arr = casted_array.as_primitive::(); assert_eq!("1123450", decimal_arr.value_as_string(0)); assert_eq!("2123450", decimal_arr.value_as_string(1)); @@ -7469,7 +7467,7 @@ mod tests { ])) as ArrayRef; let casted_array = cast(&array, &decimal_type).unwrap(); - let decimal_arr = as_primitive_array::(&casted_array); + let decimal_arr = casted_array.as_primitive::(); assert_eq!("1120", decimal_arr.value_as_string(0)); assert_eq!("2120", decimal_arr.value_as_string(1)); @@ -7492,7 +7490,7 @@ mod tests { ); let casted_array = cast(&array, &output_type).unwrap(); - let decimal_arr = as_primitive_array::(&casted_array); + let decimal_arr = casted_array.as_primitive::(); assert_eq!("1200", decimal_arr.value_as_string(0)); @@ -7507,7 +7505,7 @@ mod tests { ); let casted_array = cast(&array, &output_type).unwrap(); - let decimal_arr = as_primitive_array::(&casted_array); + let decimal_arr = casted_array.as_primitive::(); assert_eq!("1300", decimal_arr.value_as_string(0)); } @@ -7632,7 +7630,7 @@ mod tests { assert!(can_cast_types(array.data_type(), &output_type)); let casted_array = cast(&array, &output_type).unwrap(); - let decimal_arr = as_primitive_array::(&casted_array); + let decimal_arr = casted_array.as_primitive::(); assert_eq!("123.45", decimal_arr.value_as_string(0)); assert_eq!("1.23", decimal_arr.value_as_string(1)); @@ -7653,7 +7651,7 @@ mod tests { assert!(can_cast_types(array.data_type(), &output_type)); let casted_array = cast(&array, &output_type).unwrap(); - let decimal_arr = as_primitive_array::(&casted_array); + let decimal_arr = casted_array.as_primitive::(); assert_eq!("123.450", decimal_arr.value_as_string(0)); assert_eq!("1.235", decimal_arr.value_as_string(1)); @@ -7751,7 +7749,7 @@ mod tests { fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) { let output_type = DataType::Decimal128(38, 2); let casted_array = cast(&overflow_array, &output_type).unwrap(); - let decimal_arr = as_primitive_array::(&casted_array); + let decimal_arr = casted_array.as_primitive::(); assert!(decimal_arr.is_null(0)); assert!(decimal_arr.is_null(1)); @@ -7797,7 +7795,7 @@ mod tests { fn test_cast_string_to_decimal256_overflow(overflow_array: ArrayRef) { let output_type = DataType::Decimal256(76, 2); let casted_array = cast(&overflow_array, &output_type).unwrap(); - let decimal_arr = as_primitive_array::(&casted_array); + let decimal_arr = casted_array.as_primitive::(); assert_eq!( "170141183460469231731687303715884105727.00", @@ -7916,7 +7914,7 @@ mod tests { ]); let array = Arc::new(a) as ArrayRef; let b = cast(&array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap(); - let v = as_primitive_array::(b.as_ref()); + let v = b.as_primitive::(); assert_eq!(v.value(0), 946728000000000000); assert_eq!(v.value(1), 1608035696000000000); @@ -7926,7 +7924,7 @@ mod tests { &DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".to_string())), ) .unwrap(); - let v = as_primitive_array::(b.as_ref()); + let v = b.as_primitive::(); assert_eq!(v.value(0), 946728000000000000); assert_eq!(v.value(1), 1608035696000000000); @@ -7936,7 +7934,7 @@ mod tests { &DataType::Timestamp(TimeUnit::Millisecond, Some("+02:00".to_string())), ) .unwrap(); - let v = as_primitive_array::(b.as_ref()); + let v = b.as_primitive::(); assert_eq!(v.value(0), 946728000000); assert_eq!(v.value(1), 1608035696000); @@ -7991,7 +7989,7 @@ mod tests { let s = BinaryArray::from(vec![v1, v2]); let options = CastOptions { safe: true }; let array = cast_with_options(&s, &DataType::Utf8, &options).unwrap(); - let a = as_string_array(array.as_ref()); + let a = array.as_string::(); a.data().validate_full().unwrap(); assert_eq!(a.null_count(), 1); diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs index 6e06a0e39dc0..c8025f000eab 100644 --- a/arrow-cast/src/display.rs +++ b/arrow-cast/src/display.rs @@ -258,10 +258,10 @@ fn make_formatter<'a>( array => array_format(array, options), DataType::Null => array_format(as_null_array(array), options), DataType::Boolean => array_format(as_boolean_array(array), options), - DataType::Utf8 => array_format(as_string_array(array), options), - DataType::LargeUtf8 => array_format(as_largestring_array(array), options), - DataType::Binary => array_format(as_generic_binary_array::(array), options), - DataType::LargeBinary => array_format(as_generic_binary_array::(array), options), + DataType::Utf8 => array_format(array.as_string::(), options), + DataType::LargeUtf8 => array_format(array.as_string::(), options), + DataType::Binary => array_format(array.as_binary::(), options), + DataType::LargeBinary => array_format(array.as_binary::(), options), DataType::FixedSizeBinary(_) => { let a = array.as_any().downcast_ref::().unwrap(); array_format(a, options) diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index 8b1cd2f79930..046bfafc4641 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -1146,7 +1146,7 @@ mod tests { use std::io::{Cursor, Write}; use tempfile::NamedTempFile; - use arrow_array::cast::as_boolean_array; + use arrow_array::cast::AsArray; use chrono::prelude::*; #[test] @@ -2059,14 +2059,14 @@ mod tests { assert_eq!(b.num_rows(), 4); assert_eq!(b.num_columns(), 2); - let c = as_boolean_array(b.column(0)); + let c = b.column(0).as_boolean(); assert_eq!(c.null_count(), 1); assert!(c.value(0)); assert!(!c.value(1)); assert!(c.is_null(2)); assert!(!c.value(3)); - let c = as_boolean_array(b.column(1)); + let c = b.column(1).as_boolean(); assert_eq!(c.null_count(), 1); assert!(!c.value(0)); assert!(c.value(1)); diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index b57692749878..2d859f608387 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -2024,7 +2024,7 @@ mod tests { ); let sliced = array.slice(1, 2); - let read_sliced: &UInt32Array = as_primitive_array(&sliced); + let read_sliced: &UInt32Array = sliced.as_primitive(); assert_eq!( vec![Some(2), Some(3)], read_sliced.iter().collect::>() @@ -2044,7 +2044,7 @@ mod tests { let mut reader = StreamReader::try_new(&outbuf[..], None).expect("new reader"); let read_batch = reader.next().unwrap().expect("read batch"); - let read_array: &UInt32Array = as_primitive_array(read_batch.column(0)); + let read_array: &UInt32Array = read_batch.column(0).as_primitive(); assert_eq!( vec![Some(2), Some(3)], read_array.iter().collect::>() diff --git a/arrow-json/src/raw/mod.rs b/arrow-json/src/raw/mod.rs index 57bec9ee49c0..2e5055bf149e 100644 --- a/arrow-json/src/raw/mod.rs +++ b/arrow-json/src/raw/mod.rs @@ -357,10 +357,7 @@ mod tests { use super::*; use crate::reader::infer_json_schema; use crate::ReaderBuilder; - use arrow_array::cast::{ - as_boolean_array, as_largestring_array, as_list_array, as_map_array, - as_primitive_array, as_string_array, as_struct_array, - }; + use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; use arrow_array::Array; use arrow_buffer::ArrowNativeType; @@ -431,29 +428,29 @@ mod tests { let batches = do_read(buf, 1024, false, schema); assert_eq!(batches.len(), 1); - let col1 = as_primitive_array::(batches[0].column(0)); + let col1 = batches[0].column(0).as_primitive::(); assert_eq!(col1.null_count(), 2); assert_eq!(col1.values(), &[1, 2, 2, 2, 0, 0]); assert!(col1.is_null(4)); assert!(col1.is_null(5)); - let col2 = as_primitive_array::(batches[0].column(1)); + let col2 = batches[0].column(1).as_primitive::(); assert_eq!(col2.null_count(), 0); assert_eq!(col2.values(), &[2, 4, 6, 5, 4, 7]); - let col3 = as_boolean_array(batches[0].column(2)); + let col3 = batches[0].column(2).as_boolean(); assert_eq!(col3.null_count(), 4); assert!(col3.value(0)); assert!(!col3.is_null(0)); assert!(!col3.value(1)); assert!(!col3.is_null(1)); - let col4 = as_primitive_array::(batches[0].column(3)); + let col4 = batches[0].column(3).as_primitive::(); assert_eq!(col4.null_count(), 3); assert!(col4.is_null(3)); assert_eq!(col4.values(), &[1, 2, 45, 0, 0, 0]); - let col5 = as_primitive_array::(batches[0].column(4)); + let col5 = batches[0].column(4).as_primitive::(); assert_eq!(col5.null_count(), 5); assert!(col5.is_null(0)); assert!(col5.is_null(2)); @@ -480,7 +477,7 @@ mod tests { let batches = do_read(buf, 1024, false, schema); assert_eq!(batches.len(), 1); - let col1 = as_string_array(batches[0].column(0)); + let col1 = batches[0].column(0).as_string::(); assert_eq!(col1.null_count(), 2); assert_eq!(col1.value(0), "1"); assert_eq!(col1.value(1), "hello"); @@ -488,7 +485,7 @@ mod tests { assert!(col1.is_null(3)); assert!(col1.is_null(4)); - let col2 = as_largestring_array(batches[0].column(1)); + let col2 = batches[0].column(1).as_string::(); assert_eq!(col2.null_count(), 1); assert_eq!(col2.value(0), "2"); assert_eq!(col2.value(1), "shoo"); @@ -537,41 +534,41 @@ mod tests { let batches = do_read(buf, 1024, false, schema); assert_eq!(batches.len(), 1); - let list = as_list_array(batches[0].column(0).as_ref()); + let list = batches[0].column(0).as_list::(); assert_eq!(list.len(), 3); assert_eq!(list.value_offsets(), &[0, 0, 2, 2]); assert_eq!(list.null_count(), 1); assert!(list.is_null(2)); - let list_values = as_primitive_array::(list.values().as_ref()); + let list_values = list.values().as_primitive::(); assert_eq!(list_values.values(), &[5, 6]); - let nested = as_struct_array(batches[0].column(1).as_ref()); - let a = as_primitive_array::(nested.column(0).as_ref()); + let nested = batches[0].column(1).as_struct(); + let a = nested.column(0).as_primitive::(); assert_eq!(list.null_count(), 1); assert_eq!(a.values(), &[1, 7, 0]); assert!(list.is_null(2)); - let b = as_primitive_array::(nested.column(1).as_ref()); + let b = nested.column(1).as_primitive::(); assert_eq!(b.null_count(), 2); assert_eq!(b.len(), 3); assert_eq!(b.value(0), 2); assert!(b.is_null(1)); assert!(b.is_null(2)); - let nested_list = as_struct_array(batches[0].column(2).as_ref()); + let nested_list = batches[0].column(2).as_struct(); assert_eq!(nested_list.len(), 3); assert_eq!(nested_list.null_count(), 1); assert!(nested_list.is_null(2)); - let list2 = as_list_array(nested_list.column(0).as_ref()); + let list2 = nested_list.column(0).as_list::(); assert_eq!(list2.len(), 3); assert_eq!(list2.null_count(), 1); assert_eq!(list2.value_offsets(), &[0, 2, 2, 2]); assert!(list2.is_null(2)); - let list2_values = as_struct_array(list2.values().as_ref()); + let list2_values = list2.values().as_struct(); - let c = as_primitive_array::(list2_values.column(0)); + let c = list2_values.column(0).as_primitive::(); assert_eq!(c.values(), &[3, 4]); } @@ -606,26 +603,26 @@ mod tests { let batches = do_read(buf, 1024, false, schema); assert_eq!(batches.len(), 1); - let nested = as_struct_array(batches[0].column(0).as_ref()); + let nested = batches[0].column(0).as_struct(); assert_eq!(nested.num_columns(), 1); - let a = as_primitive_array::(nested.column(0).as_ref()); + let a = nested.column(0).as_primitive::(); assert_eq!(a.null_count(), 0); assert_eq!(a.values(), &[1, 7]); - let nested_list = as_struct_array(batches[0].column(1).as_ref()); + let nested_list = batches[0].column(1).as_struct(); assert_eq!(nested_list.num_columns(), 1); assert_eq!(nested_list.null_count(), 0); - let list2 = as_list_array(nested_list.column(0).as_ref()); + let list2 = nested_list.column(0).as_list::(); assert_eq!(list2.value_offsets(), &[0, 2, 2]); assert_eq!(list2.null_count(), 0); - let child = as_struct_array(list2.values().as_ref()); + let child = list2.values().as_struct(); assert_eq!(child.num_columns(), 1); assert_eq!(child.len(), 2); assert_eq!(child.null_count(), 0); - let c = as_primitive_array::(child.column(0).as_ref()); + let c = child.column(0).as_primitive::(); assert_eq!(c.values(), &[5, 0]); assert_eq!(c.null_count(), 1); assert!(c.is_null(1)); @@ -650,15 +647,15 @@ mod tests { let batches = do_read(buf, 1024, false, schema); assert_eq!(batches.len(), 1); - let map = as_map_array(batches[0].column(0).as_ref()); - let map_keys = as_string_array(map.keys().as_ref()); - let map_values = as_list_array(map.values().as_ref()); + let map = batches[0].column(0).as_map(); + let map_keys = map.keys().as_string::(); + let map_values = map.values().as_list::(); assert_eq!(map.value_offsets(), &[0, 1, 3, 5]); let k: Vec<_> = map_keys.iter().map(|x| x.unwrap()).collect(); assert_eq!(&k, &["a", "a", "b", "c", "a"]); - let list_values = as_string_array(map_values.values().as_ref()); + let list_values = map_values.values().as_string::(); let lv: Vec<_> = list_values.iter().collect(); assert_eq!(&lv, &[Some("foo"), None, None, Some("baz")]); assert_eq!(map_values.value_offsets(), &[0, 2, 3, 3, 3, 4]); @@ -751,7 +748,7 @@ mod tests { let batches = do_read(buf, 1024, true, schema); assert_eq!(batches.len(), 1); - let col1 = as_string_array(batches[0].column(0)); + let col1 = batches[0].column(0).as_string::(); assert_eq!(col1.null_count(), 2); assert_eq!(col1.value(0), "1"); assert_eq!(col1.value(1), "2E0"); @@ -760,7 +757,7 @@ mod tests { assert!(col1.is_null(4)); assert!(col1.is_null(5)); - let col2 = as_string_array(batches[0].column(1)); + let col2 = batches[0].column(1).as_string::(); assert_eq!(col2.null_count(), 0); assert_eq!(col2.value(0), "2"); assert_eq!(col2.value(1), "4"); @@ -769,7 +766,7 @@ mod tests { assert_eq!(col2.value(4), "4e0"); assert_eq!(col2.value(5), "7"); - let col3 = as_string_array(batches[0].column(2)); + let col3 = batches[0].column(2).as_string::(); assert_eq!(col3.null_count(), 4); assert_eq!(col3.value(0), "true"); assert_eq!(col3.value(1), "false"); @@ -799,7 +796,7 @@ mod tests { let batches = do_read(buf, 1024, true, schema); assert_eq!(batches.len(), 1); - let col1 = as_primitive_array::(batches[0].column(0)); + let col1 = batches[0].column(0).as_primitive::(); assert_eq!(col1.null_count(), 2); assert!(col1.is_null(4)); assert!(col1.is_null(5)); @@ -808,14 +805,14 @@ mod tests { &[100, 200, 204, 1103420, 0, 0].map(T::Native::usize_as) ); - let col2 = as_primitive_array::(batches[0].column(1)); + let col2 = batches[0].column(1).as_primitive::(); assert_eq!(col2.null_count(), 0); assert_eq!( col2.values(), &[200, 400, 133700, 500, 4000, 123400].map(T::Native::usize_as) ); - let col3 = as_primitive_array::(batches[0].column(2)); + let col3 = batches[0].column(2).as_primitive::(); assert_eq!(col3.null_count(), 4); assert!(!col3.is_null(0)); assert!(!col3.is_null(1)); @@ -864,7 +861,7 @@ mod tests { TimeUnit::Nanosecond => 1, }; - let col1 = as_primitive_array::(batches[0].column(0)); + let col1 = batches[0].column(0).as_primitive::(); assert_eq!(col1.null_count(), 4); assert!(col1.is_null(2)); assert!(col1.is_null(3)); @@ -872,7 +869,7 @@ mod tests { assert!(col1.is_null(5)); assert_eq!(col1.values(), &[1, 2, 0, 0, 0, 0].map(T::Native::usize_as)); - let col2 = as_primitive_array::(batches[0].column(1)); + let col2 = batches[0].column(1).as_primitive::(); assert_eq!(col2.null_count(), 1); assert!(col2.is_null(5)); assert_eq!( @@ -887,7 +884,7 @@ mod tests { ] ); - let col3 = as_primitive_array::(batches[0].column(2)); + let col3 = batches[0].column(2).as_primitive::(); assert_eq!(col3.null_count(), 0); assert_eq!( col3.values(), @@ -901,7 +898,7 @@ mod tests { ] ); - let col4 = as_primitive_array::(batches[0].column(3)); + let col4 = batches[0].column(3).as_primitive::(); assert_eq!(col4.null_count(), 0); assert_eq!( @@ -957,7 +954,7 @@ mod tests { let batches = do_read(buf, 1024, true, schema); assert_eq!(batches.len(), 1); - let col1 = as_primitive_array::(batches[0].column(0)); + let col1 = batches[0].column(0).as_primitive::(); assert_eq!(col1.null_count(), 4); assert!(col1.is_null(2)); assert!(col1.is_null(3)); @@ -965,7 +962,7 @@ mod tests { assert!(col1.is_null(5)); assert_eq!(col1.values(), &[1, 2, 0, 0, 0, 0].map(T::Native::usize_as)); - let col2 = as_primitive_array::(batches[0].column(1)); + let col2 = batches[0].column(1).as_primitive::(); assert_eq!(col2.null_count(), 1); assert!(col2.is_null(5)); assert_eq!( @@ -981,7 +978,7 @@ mod tests { .map(T::Native::usize_as) ); - let col3 = as_primitive_array::(batches[0].column(2)); + let col3 = batches[0].column(2).as_primitive::(); assert_eq!(col3.null_count(), 0); assert_eq!( col3.values(), diff --git a/arrow-json/src/reader.rs b/arrow-json/src/reader.rs index 5d86f9a578c2..8e33613886f1 100644 --- a/arrow-json/src/reader.rs +++ b/arrow-json/src/reader.rs @@ -1844,10 +1844,7 @@ impl Iterator for Reader { #[allow(deprecated)] mod tests { use super::*; - use arrow_array::cast::{ - as_boolean_array, as_dictionary_array, as_primitive_array, as_string_array, - as_struct_array, - }; + use arrow_array::cast::AsArray; use arrow_buffer::{ArrowNativeType, ToByteSlice}; use arrow_schema::DataType::{Dictionary, List}; use flate2::read::GzDecoder; @@ -2133,20 +2130,12 @@ mod tests { let d = schema.column_with_name("d").unwrap(); assert_eq!(&DataType::Utf8, d.1.data_type()); - let aa = batch - .column(a.0) - .as_any() - .downcast_ref::() - .unwrap(); + let aa = batch.column(a.0).as_primitive::(); assert_eq!(1, aa.value(0)); assert_eq!(-10, aa.value(1)); assert_eq!(1627668684594000000, aa.value(2)); - let bb = batch - .column(b.0) - .as_any() - .downcast_ref::() - .unwrap(); - let bb = as_primitive_array::(bb.values()); + let bb = batch.column(b.0).as_list::(); + let bb = bb.values().as_primitive::(); assert_eq!(9, bb.len()); assert_eq!(2.0, bb.value(0)); assert_eq!(-6.1, bb.value(5)); @@ -2157,7 +2146,7 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - let cc = as_boolean_array(cc.values()); + let cc = cc.values().as_boolean(); assert_eq!(6, cc.len()); assert!(!cc.value(0)); assert!(!cc.value(4)); @@ -2271,7 +2260,7 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - let bb = as_primitive_array::(bb.values()); + let bb = bb.values().as_primitive::(); assert_eq!(10, bb.len()); assert_eq!(4.0, bb.value(9)); @@ -2285,7 +2274,7 @@ mod tests { *cc.data().buffers()[0], Buffer::from_slice_ref([0i32, 2, 2, 4, 5]) ); - let cc = as_boolean_array(cc.values()); + let cc = cc.values().as_boolean(); let cc_expected = BooleanArray::from(vec![ Some(false), Some(true), @@ -2306,7 +2295,7 @@ mod tests { Buffer::from_slice_ref([0i32, 1, 1, 2, 6]) ); - let dd = as_string_array(dd.values()); + let dd = dd.values().as_string::(); // values are 6 because a `d: null` is treated as a null slot // and a list's null slot can be omitted from the child (i.e. same offset) assert_eq!(6, dd.len()); @@ -2452,8 +2441,8 @@ mod tests { // compare list null buffers assert_eq!(read.nulls(), expected.nulls()); // build struct from list - let struct_array = as_struct_array(read.values()); - let expected_struct_array = as_struct_array(expected.values()); + let struct_array = read.values().as_struct(); + let expected_struct_array = expected.values().as_struct(); assert_eq!(7, struct_array.len()); assert_eq!(1, struct_array.null_count()); @@ -2767,14 +2756,13 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - let evs_list = as_dictionary_array::(evs_list.values()); + let evs_list = evs_list.values().as_dictionary::(); assert_eq!(6, evs_list.len()); assert!(evs_list.is_valid(1)); assert_eq!(DataType::Utf8, evs_list.value_type()); // dict from the events list - let dict_el = evs_list.values(); - let dict_el = dict_el.as_any().downcast_ref::().unwrap(); + let dict_el = evs_list.values().as_string::(); assert_eq!(3, dict_el.len()); assert_eq!("Elect Leader", dict_el.value(0)); assert_eq!("Do Ballot", dict_el.value(1)); @@ -2824,7 +2812,7 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - let evs_list = as_dictionary_array::(evs_list.values()); + let evs_list = evs_list.values().as_dictionary::(); assert_eq!(8, evs_list.len()); assert!(evs_list.is_valid(1)); assert_eq!(DataType::Utf8, evs_list.value_type()); diff --git a/arrow-json/src/writer.rs b/arrow-json/src/writer.rs index 27ae3876441d..bbc04c9dc096 100644 --- a/arrow-json/src/writer.rs +++ b/arrow-json/src/writer.rs @@ -112,7 +112,8 @@ where T: ArrowPrimitiveType, T::Native: JsonSerializable, { - Ok(as_primitive_array::(array) + Ok(array + .as_primitive::() .iter() .map(|maybe_value| match maybe_value { Some(v) => v.into_json_value().unwrap_or(Value::Null), @@ -146,7 +147,8 @@ fn struct_array_to_jsonmap_array( pub fn array_to_json_array(array: &ArrayRef) -> Result, ArrowError> { match array.data_type() { DataType::Null => Ok(iter::repeat(Value::Null).take(array.len()).collect()), - DataType::Boolean => Ok(as_boolean_array(array) + DataType::Boolean => Ok(array + .as_boolean() .iter() .map(|maybe_value| match maybe_value { Some(v) => v.into(), @@ -154,14 +156,16 @@ pub fn array_to_json_array(array: &ArrayRef) -> Result, ArrowError> { }) .collect()), - DataType::Utf8 => Ok(as_string_array(array) + DataType::Utf8 => Ok(array + .as_string::() .iter() .map(|maybe_value| match maybe_value { Some(v) => v.into(), None => Value::Null, }) .collect()), - DataType::LargeUtf8 => Ok(as_largestring_array(array) + DataType::LargeUtf8 => Ok(array + .as_string::() .iter() .map(|maybe_value| match maybe_value { Some(v) => v.into(), @@ -225,7 +229,7 @@ fn set_column_by_primitive_type( T: ArrowPrimitiveType, T::Native: JsonSerializable, { - let primitive_arr = as_primitive_array::(array); + let primitive_arr = array.as_primitive::(); rows.iter_mut() .zip(primitive_arr.iter()) @@ -369,7 +373,7 @@ fn set_column_for_json_rows( ))); } - let keys = as_string_array(keys); + let keys = keys.as_string::(); let values = array_to_json_array(values)?; let mut kv = keys.iter().zip(values.into_iter()); diff --git a/arrow-ord/src/comparison.rs b/arrow-ord/src/comparison.rs index eb672e769ac3..0f9414378c4a 100644 --- a/arrow-ord/src/comparison.rs +++ b/arrow-ord/src/comparison.rs @@ -829,14 +829,8 @@ pub fn eq_dyn_binary_scalar( right: &[u8], ) -> Result { match left.data_type() { - DataType::Binary => { - let left = as_generic_binary_array::(left); - eq_binary_scalar(left, right) - } - DataType::LargeBinary => { - let left = as_generic_binary_array::(left); - eq_binary_scalar(left, right) - } + DataType::Binary => eq_binary_scalar(left.as_binary::(), right), + DataType::LargeBinary => eq_binary_scalar(left.as_binary::(), right), _ => Err(ArrowError::ComputeError( "eq_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), )), @@ -850,14 +844,8 @@ pub fn neq_dyn_binary_scalar( right: &[u8], ) -> Result { match left.data_type() { - DataType::Binary => { - let left = as_generic_binary_array::(left); - neq_binary_scalar(left, right) - } - DataType::LargeBinary => { - let left = as_generic_binary_array::(left); - neq_binary_scalar(left, right) - } + DataType::Binary => neq_binary_scalar(left.as_binary::(), right), + DataType::LargeBinary => neq_binary_scalar(left.as_binary::(), right), _ => Err(ArrowError::ComputeError( "neq_dyn_binary_scalar only supports Binary or LargeBinary arrays" .to_string(), @@ -872,14 +860,8 @@ pub fn lt_dyn_binary_scalar( right: &[u8], ) -> Result { match left.data_type() { - DataType::Binary => { - let left = as_generic_binary_array::(left); - lt_binary_scalar(left, right) - } - DataType::LargeBinary => { - let left = as_generic_binary_array::(left); - lt_binary_scalar(left, right) - } + DataType::Binary => lt_binary_scalar(left.as_binary::(), right), + DataType::LargeBinary => lt_binary_scalar(left.as_binary::(), right), _ => Err(ArrowError::ComputeError( "lt_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), )), @@ -893,14 +875,8 @@ pub fn lt_eq_dyn_binary_scalar( right: &[u8], ) -> Result { match left.data_type() { - DataType::Binary => { - let left = as_generic_binary_array::(left); - lt_eq_binary_scalar(left, right) - } - DataType::LargeBinary => { - let left = as_generic_binary_array::(left); - lt_eq_binary_scalar(left, right) - } + DataType::Binary => lt_eq_binary_scalar(left.as_binary::(), right), + DataType::LargeBinary => lt_eq_binary_scalar(left.as_binary::(), right), _ => Err(ArrowError::ComputeError( "lt_eq_dyn_binary_scalar only supports Binary or LargeBinary arrays" .to_string(), @@ -915,14 +891,8 @@ pub fn gt_dyn_binary_scalar( right: &[u8], ) -> Result { match left.data_type() { - DataType::Binary => { - let left = as_generic_binary_array::(left); - gt_binary_scalar(left, right) - } - DataType::LargeBinary => { - let left = as_generic_binary_array::(left); - gt_binary_scalar(left, right) - } + DataType::Binary => gt_binary_scalar(left.as_binary::(), right), + DataType::LargeBinary => gt_binary_scalar(left.as_binary::(), right), _ => Err(ArrowError::ComputeError( "gt_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), )), @@ -936,14 +906,8 @@ pub fn gt_eq_dyn_binary_scalar( right: &[u8], ) -> Result { match left.data_type() { - DataType::Binary => { - let left = as_generic_binary_array::(left); - gt_eq_binary_scalar(left, right) - } - DataType::LargeBinary => { - let left = as_generic_binary_array::(left); - gt_eq_binary_scalar(left, right) - } + DataType::Binary => gt_eq_binary_scalar(left.as_binary::(), right), + DataType::LargeBinary => gt_eq_binary_scalar(left.as_binary::(), right), _ => Err(ArrowError::ComputeError( "gt_eq_dyn_binary_scalar only supports Binary or LargeBinary arrays" .to_string(), @@ -967,12 +931,10 @@ pub fn eq_dyn_utf8_scalar( )), }, DataType::Utf8 => { - let left = as_string_array(left); - eq_utf8_scalar(left, right) + eq_utf8_scalar(left.as_string::(), right) } DataType::LargeUtf8 => { - let left = as_largestring_array(left); - eq_utf8_scalar(left, right) + eq_utf8_scalar(left.as_string::(), right) } _ => Err(ArrowError::ComputeError( "eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), @@ -997,12 +959,10 @@ pub fn lt_dyn_utf8_scalar( )), }, DataType::Utf8 => { - let left = as_string_array(left); - lt_utf8_scalar(left, right) + lt_utf8_scalar(left.as_string::(), right) } DataType::LargeUtf8 => { - let left = as_largestring_array(left); - lt_utf8_scalar(left, right) + lt_utf8_scalar(left.as_string::(), right) } _ => Err(ArrowError::ComputeError( "lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), @@ -1027,12 +987,10 @@ pub fn gt_eq_dyn_utf8_scalar( )), }, DataType::Utf8 => { - let left = as_string_array(left); - gt_eq_utf8_scalar(left, right) + gt_eq_utf8_scalar(left.as_string::(), right) } DataType::LargeUtf8 => { - let left = as_largestring_array(left); - gt_eq_utf8_scalar(left, right) + gt_eq_utf8_scalar(left.as_string::(), right) } _ => Err(ArrowError::ComputeError( "gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), @@ -1057,12 +1015,10 @@ pub fn lt_eq_dyn_utf8_scalar( )), }, DataType::Utf8 => { - let left = as_string_array(left); - lt_eq_utf8_scalar(left, right) + lt_eq_utf8_scalar(left.as_string::(), right) } DataType::LargeUtf8 => { - let left = as_largestring_array(left); - lt_eq_utf8_scalar(left, right) + lt_eq_utf8_scalar(left.as_string::(), right) } _ => Err(ArrowError::ComputeError( "lt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), @@ -1087,12 +1043,10 @@ pub fn gt_dyn_utf8_scalar( )), }, DataType::Utf8 => { - let left = as_string_array(left); - gt_utf8_scalar(left, right) + gt_utf8_scalar(left.as_string::(), right) } DataType::LargeUtf8 => { - let left = as_largestring_array(left); - gt_utf8_scalar(left, right) + gt_utf8_scalar(left.as_string::(), right) } _ => Err(ArrowError::ComputeError( "gt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), @@ -1117,12 +1071,10 @@ pub fn neq_dyn_utf8_scalar( )), }, DataType::Utf8 => { - let left = as_string_array(left); - neq_utf8_scalar(left, right) + neq_utf8_scalar(left.as_string::(), right) } DataType::LargeUtf8 => { - let left = as_largestring_array(left); - neq_utf8_scalar(left, right) + neq_utf8_scalar(left.as_string::(), right) } _ => Err(ArrowError::ComputeError( "neq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), @@ -1138,10 +1090,7 @@ pub fn eq_dyn_bool_scalar( right: bool, ) -> Result { let result = match left.data_type() { - DataType::Boolean => { - let left = as_boolean_array(left); - eq_bool_scalar(left, right) - } + DataType::Boolean => eq_bool_scalar(left.as_boolean(), right), _ => Err(ArrowError::ComputeError( "eq_dyn_bool_scalar only supports BooleanArray".to_string(), )), @@ -1156,10 +1105,7 @@ pub fn lt_dyn_bool_scalar( right: bool, ) -> Result { let result = match left.data_type() { - DataType::Boolean => { - let left = as_boolean_array(left); - lt_bool_scalar(left, right) - } + DataType::Boolean => lt_bool_scalar(left.as_boolean(), right), _ => Err(ArrowError::ComputeError( "lt_dyn_bool_scalar only supports BooleanArray".to_string(), )), @@ -1174,10 +1120,7 @@ pub fn gt_dyn_bool_scalar( right: bool, ) -> Result { let result = match left.data_type() { - DataType::Boolean => { - let left = as_boolean_array(left); - gt_bool_scalar(left, right) - } + DataType::Boolean => gt_bool_scalar(left.as_boolean(), right), _ => Err(ArrowError::ComputeError( "gt_dyn_bool_scalar only supports BooleanArray".to_string(), )), @@ -1192,10 +1135,7 @@ pub fn lt_eq_dyn_bool_scalar( right: bool, ) -> Result { let result = match left.data_type() { - DataType::Boolean => { - let left = as_boolean_array(left); - lt_eq_bool_scalar(left, right) - } + DataType::Boolean => lt_eq_bool_scalar(left.as_boolean(), right), _ => Err(ArrowError::ComputeError( "lt_eq_dyn_bool_scalar only supports BooleanArray".to_string(), )), @@ -1210,10 +1150,7 @@ pub fn gt_eq_dyn_bool_scalar( right: bool, ) -> Result { let result = match left.data_type() { - DataType::Boolean => { - let left = as_boolean_array(left); - gt_eq_bool_scalar(left, right) - } + DataType::Boolean => gt_eq_bool_scalar(left.as_boolean(), right), _ => Err(ArrowError::ComputeError( "gt_eq_dyn_bool_scalar only supports BooleanArray".to_string(), )), @@ -1228,10 +1165,7 @@ pub fn neq_dyn_bool_scalar( right: bool, ) -> Result { let result = match left.data_type() { - DataType::Boolean => { - let left = as_boolean_array(left); - neq_bool_scalar(left, right) - } + DataType::Boolean => neq_bool_scalar(left.as_boolean(), right), _ => Err(ArrowError::ComputeError( "neq_dyn_bool_scalar only supports BooleanArray".to_string(), )), @@ -1455,8 +1389,8 @@ fn cmp_primitive_array( where F: Fn(T::Native, T::Native) -> bool, { - let left_array = as_primitive_array::(left); - let right_array = as_primitive_array::(right); + let left_array = left.as_primitive::(); + let right_array = right.as_primitive::(); compare_op(left_array, right_array, op) } @@ -2036,7 +1970,7 @@ where { compare_op( left.downcast_dict::>().unwrap(), - as_primitive_array::(right), + right.as_primitive::(), op, ) } @@ -3046,7 +2980,7 @@ mod tests { fn test_primitive_array_eq_scalar_with_slice() { let a = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); let a = a.slice(1, 3); - let a: &Int32Array = as_primitive_array(&a); + let a: &Int32Array = a.as_primitive(); let a_eq = eq_scalar(a, 2).unwrap(); assert_eq!( a_eq, @@ -3848,7 +3782,7 @@ mod tests { vec![Some("hi"), None, Some("hello"), Some("world"), Some("")], ); let a = a.slice(1, 4); - let a = as_string_array(&a); + let a = a.as_string::(); let a_eq = eq_utf8_scalar(a, "hello").unwrap(); assert_eq!( a_eq, diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index 0f248ee637b0..ab6460e835f9 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -489,7 +489,7 @@ where { // create tuples that are used for sorting let valids = { - let values = as_primitive_array::(values); + let values = values.as_primitive::(); value_indices .into_iter() .map(|index| (index, values.value(index as usize))) @@ -1043,7 +1043,7 @@ pub struct SortColumn { /// # use std::sync::Arc; /// # use arrow_array::{ArrayRef, StringArray, PrimitiveArray}; /// # use arrow_array::types::Int64Type; -/// # use arrow_array::cast::as_primitive_array; +/// # use arrow_array::cast::AsArray; /// # use arrow_ord::sort::{SortColumn, SortOptions, lexsort}; /// /// let sorted_columns = lexsort(&vec![ @@ -1072,7 +1072,7 @@ pub struct SortColumn { /// }, /// ], None).unwrap(); /// -/// assert_eq!(as_primitive_array::(&sorted_columns[0]).value(1), -64); +/// assert_eq!(sorted_columns[0].as_primitive::().value(1), -64); /// assert!(sorted_columns[0].is_null(0)); /// ``` /// diff --git a/arrow-row/src/dictionary.rs b/arrow-row/src/dictionary.rs index bacc116cade7..273b7439d0d1 100644 --- a/arrow-row/src/dictionary.rs +++ b/arrow-row/src/dictionary.rs @@ -45,11 +45,11 @@ pub fn compute_dictionary_mapping( interner.intern(iter) } DataType::Utf8 => { - let iter = as_string_array(values).iter().map(|x| x.map(|x| x.as_bytes())); + let iter = values.as_string::().iter().map(|x| x.map(|x| x.as_bytes())); interner.intern(iter) } DataType::LargeUtf8 => { - let iter = as_largestring_array(values).iter().map(|x| x.map(|x| x.as_bytes())); + let iter = values.as_string::().iter().map(|x| x.map(|x| x.as_bytes())); interner.intern(iter) } _ => unreachable!(), diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 2c1de68c1926..2f0defe5268a 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -52,7 +52,7 @@ //! # use std::sync::Arc; //! # use arrow_row::{RowConverter, SortField}; //! # use arrow_array::{ArrayRef, Int32Array, StringArray}; -//! # use arrow_array::cast::{as_primitive_array, as_string_array}; +//! # use arrow_array::cast::{AsArray, as_string_array}; //! # use arrow_array::types::Int32Type; //! # use arrow_schema::DataType; //! @@ -89,10 +89,10 @@ //! // Convert selection of rows back to arrays //! let selection = [rows.row(0), rows2.row(1), rows.row(2), rows2.row(0)]; //! let converted = converter.convert_rows(selection).unwrap(); -//! let c1 = as_primitive_array::(converted[0].as_ref()); +//! let c1 = converted[0].as_primitive::(); //! assert_eq!(c1.values(), &[-1, 4, 0, 3]); //! -//! let c2 = as_string_array(converted[1].as_ref()); +//! let c2 = converted[1].as_string::(); //! let c2_values: Vec<_> = c2.iter().flatten().collect(); //! assert_eq!(&c2_values, &["a", "f", "c", "e"]); //! ``` @@ -1078,13 +1078,13 @@ fn new_empty_rows(cols: &[ArrayRef], encoders: &[Encoder], config: RowConfig) -> .iter() .zip(lengths.iter_mut()) .for_each(|(slice, length)| *length += variable::encoded_len(slice)), - DataType::Utf8 => as_string_array(array) + DataType::Utf8 => array.as_string::() .iter() .zip(lengths.iter_mut()) .for_each(|(slice, length)| { *length += variable::encoded_len(slice.map(|x| x.as_bytes())) }), - DataType::LargeUtf8 => as_largestring_array(array) + DataType::LargeUtf8 => array.as_string::() .iter() .zip(lengths.iter_mut()) .for_each(|(slice, length)| { @@ -1189,7 +1189,7 @@ fn encode_column( downcast_primitive_array! { column => fixed::encode(out, column, opts), DataType::Null => {} - DataType::Boolean => fixed::encode(out, as_boolean_array(column), opts), + DataType::Boolean => fixed::encode(out, column.as_boolean(), opts), DataType::Binary => { variable::encode(out, as_generic_binary_array::(column).iter(), opts) } @@ -1198,12 +1198,12 @@ fn encode_column( } DataType::Utf8 => variable::encode( out, - as_string_array(column).iter().map(|x| x.map(|x| x.as_bytes())), + column.as_string::().iter().map(|x| x.map(|x| x.as_bytes())), opts, ), DataType::LargeUtf8 => variable::encode( out, - as_largestring_array(column) + column.as_string::() .iter() .map(|x| x.map(|x| x.as_bytes())), opts, diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 35c11970c0f6..6267d8ae0028 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow_array::builder::BooleanBufferBuilder; -use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array}; +use arrow_array::cast::AsArray; use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType}; use arrow_array::*; use arrow_buffer::bit_util; @@ -350,16 +350,16 @@ fn filter_array( Ok(Arc::new(filter_boolean(values, predicate))) } DataType::Utf8 => { - Ok(Arc::new(filter_bytes(as_string_array(values), predicate))) + Ok(Arc::new(filter_bytes(values.as_string::(), predicate))) } DataType::LargeUtf8 => { - Ok(Arc::new(filter_bytes(as_largestring_array(values), predicate))) + Ok(Arc::new(filter_bytes(values.as_string::(), predicate))) } DataType::Binary => { - Ok(Arc::new(filter_bytes(as_generic_binary_array::(values), predicate))) + Ok(Arc::new(filter_bytes(values.as_binary::(), predicate))) } DataType::LargeBinary => { - Ok(Arc::new(filter_bytes(as_generic_binary_array::(values), predicate))) + Ok(Arc::new(filter_bytes(values.as_binary::(), predicate))) } DataType::Dictionary(_, _) => downcast_dictionary_array! { values => Ok(Arc::new(filter_dict(values, predicate))), diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs index 95b694aba732..f274a3ebc30f 100644 --- a/arrow-select/src/interleave.rs +++ b/arrow-select/src/interleave.rs @@ -225,7 +225,7 @@ fn interleave_fallback( mod tests { use super::*; use arrow_array::builder::{Int32Builder, ListBuilder}; - use arrow_array::cast::{as_primitive_array, as_string_array}; + use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; use arrow_array::{Int32Array, ListArray, StringArray}; use arrow_schema::DataType; @@ -237,7 +237,7 @@ mod tests { let c = Int32Array::from_iter_values([8, 9, 10]); let values = interleave(&[&a, &b, &c], &[(0, 3), (0, 3), (2, 2), (2, 0), (1, 1)]).unwrap(); - let v = as_primitive_array::(&values); + let v = values.as_primitive::(); assert_eq!(v.values(), &[4, 4, 10, 8, 6]); } @@ -247,9 +247,7 @@ mod tests { let b = Int32Array::from_iter([Some(1), Some(4), None]); let values = interleave(&[&a, &b], &[(0, 1), (1, 2), (1, 2), (0, 3), (0, 2)]).unwrap(); - let v: Vec<_> = as_primitive_array::(&values) - .into_iter() - .collect(); + let v: Vec<_> = values.as_primitive::().into_iter().collect(); assert_eq!(&v, &[Some(2), None, None, Some(4), Some(3)]) } @@ -267,7 +265,7 @@ mod tests { let b = StringArray::from_iter_values(["hello", "world", "foo"]); let values = interleave(&[&a, &b], &[(0, 2), (0, 2), (1, 0), (1, 1), (0, 1)]).unwrap(); - let v = as_string_array(&values); + let v = values.as_string::(); let values: Vec<_> = v.into_iter().collect(); assert_eq!( &values, diff --git a/arrow-select/src/nullif.rs b/arrow-select/src/nullif.rs index ea0c8e3d526c..a1b9c0e3e183 100644 --- a/arrow-select/src/nullif.rs +++ b/arrow-select/src/nullif.rs @@ -124,7 +124,7 @@ pub fn nullif(left: &dyn Array, right: &BooleanArray) -> Result(&res); + let res = res.as_primitive::(); assert_eq!(&expected, res); } @@ -175,7 +175,7 @@ mod tests { Some(8), // None => keep it None, // true => None ]); - let res = as_primitive_array::(&res); + let res = res.as_primitive::(); assert_eq!(&expected, res) } @@ -201,7 +201,7 @@ mod tests { ]); let a = nullif(&s, &select).unwrap(); - let r: Vec<_> = as_string_array(&a).iter().collect(); + let r: Vec<_> = a.as_string::().iter().collect(); assert_eq!( r, vec![None, None, Some("world"), None, Some("b"), None, None] @@ -209,9 +209,9 @@ mod tests { let s = s.slice(2, 3); let select = select.slice(1, 3); - let select = as_boolean_array(select.as_ref()); + let select = select.as_boolean(); let a = nullif(s.as_ref(), select).unwrap(); - let r: Vec<_> = as_string_array(&a).iter().collect(); + let r: Vec<_> = a.as_string::().iter().collect(); assert_eq!(r, vec![None, Some("a"), None]); } @@ -456,7 +456,7 @@ mod tests { let comp = BooleanArray::from(vec![Some(false), None, Some(true), Some(false), None]); let res = nullif(&a, &comp).unwrap(); - let res = as_primitive_array::(res.as_ref()); + let res = res.as_primitive::(); let expected = Int32Array::from(vec![Some(15), Some(7), None, Some(1), Some(9)]); assert_eq!(res, &expected); @@ -500,7 +500,7 @@ mod tests { for (a_offset, a_length) in a_slices { let a = a.slice(a_offset, a_length); - let a = as_primitive_array::(a.as_ref()); + let a = a.as_primitive::(); for i in 1..65 { let b_start_offset = rng.gen_range(0..i); @@ -510,7 +510,7 @@ mod tests { .map(|_| rng.gen_bool(0.5).then(|| rng.gen_bool(0.5))) .collect(); let b = b.slice(b_start_offset, a_length); - let b = as_boolean_array(b.as_ref()); + let b = b.as_boolean(); test_nullif(a, b); } diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index 421157bdf041..83b58519fdb8 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -20,13 +20,13 @@ use std::sync::Arc; use arrow_array::builder::BufferBuilder; +use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field}; -use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array}; use num::{ToPrimitive, Zero}; /// Take elements by index from [Array], creating a new [Array] from those indexes. @@ -128,24 +128,16 @@ where Ok(Arc::new(take_boolean(values, indices)?)) } DataType::Utf8 => { - Ok(Arc::new(take_bytes(as_string_array(values), indices)?)) + Ok(Arc::new(take_bytes(values.as_string::(), indices)?)) } DataType::LargeUtf8 => { - Ok(Arc::new(take_bytes(as_largestring_array(values), indices)?)) + Ok(Arc::new(take_bytes(values.as_string::(), indices)?)) } DataType::List(_) => { - let values = values - .as_any() - .downcast_ref::>() - .unwrap(); - Ok(Arc::new(take_list::<_, Int32Type>(values, indices)?)) + Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?)) } DataType::LargeList(_) => { - let values = values - .as_any() - .downcast_ref::>() - .unwrap(); - Ok(Arc::new(take_list::<_, Int64Type>(values, indices)?)) + Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?)) } DataType::FixedSizeList(_, length) => { let values = values @@ -193,10 +185,10 @@ where t => unimplemented!("Take not supported for run type {:?}", t) } DataType::Binary => { - Ok(Arc::new(take_bytes(as_generic_binary_array::(values), indices)?)) + Ok(Arc::new(take_bytes(values.as_binary::(), indices)?)) } DataType::LargeBinary => { - Ok(Arc::new(take_bytes(as_generic_binary_array::(values), indices)?)) + Ok(Arc::new(take_bytes(values.as_binary::(), indices)?)) } DataType::FixedSizeBinary(size) => { let values = values @@ -969,7 +961,7 @@ where #[cfg(test)] mod tests { use super::*; - use arrow_array::{builder::*, cast::as_primitive_array}; + use arrow_array::builder::*; use arrow_schema::TimeUnit; fn test_take_decimal_arrays( @@ -2160,7 +2152,7 @@ mod tests { assert_eq!(take_out.run_ends().len(), 7); assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]); - let take_out_values = as_primitive_array::(take_out.values()); + let take_out_values = take_out.values().as_primitive::(); assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]); } diff --git a/arrow-string/src/length.rs b/arrow-string/src/length.rs index a48fc13409f1..f0c09a7ec4d8 100644 --- a/arrow-string/src/length.rs +++ b/arrow-string/src/length.rs @@ -214,7 +214,7 @@ pub fn bit_length(array: &dyn Array) -> Result { #[cfg(test)] mod tests { use super::*; - use arrow_array::cast::as_primitive_array; + use arrow_array::cast::AsArray; fn double_vec(v: Vec) -> Vec { [&v[..], &v[..]].concat() @@ -427,7 +427,7 @@ mod tests { let a = StringArray::from(vec![Some("hello"), Some(" "), Some("world"), None]); let b = a.slice(1, 3); let result = length(b.as_ref()).unwrap(); - let result: &Int32Array = as_primitive_array(&result); + let result: &Int32Array = result.as_primitive(); let expected = Int32Array::from(vec![Some(1), Some(5), None]); assert_eq!(&expected, result); @@ -440,7 +440,7 @@ mod tests { let a = BinaryArray::from(value); let b = a.slice(1, 3); let result = length(b.as_ref()).unwrap(); - let result: &Int32Array = as_primitive_array(&result); + let result: &Int32Array = result.as_primitive(); let expected = Int32Array::from(vec![Some(1), Some(2), None]); assert_eq!(&expected, result); @@ -582,7 +582,7 @@ mod tests { let a = StringArray::from(vec![Some("hello"), Some(" "), Some("world"), None]); let b = a.slice(1, 3); let result = bit_length(b.as_ref()).unwrap(); - let result: &Int32Array = as_primitive_array(&result); + let result: &Int32Array = result.as_primitive(); let expected = Int32Array::from(vec![Some(8), Some(40), None]); assert_eq!(&expected, result); @@ -595,7 +595,7 @@ mod tests { let a = BinaryArray::from(value); let b = a.slice(1, 3); let result = bit_length(b.as_ref()).unwrap(); - let result: &Int32Array = as_primitive_array(&result); + let result: &Int32Array = result.as_primitive(); let expected = Int32Array::from(vec![Some(0), Some(40), None]); assert_eq!(&expected, result); diff --git a/arrow-string/src/like.rs b/arrow-string/src/like.rs index e8ec699969bd..7b6c7d50cac3 100644 --- a/arrow-string/src/like.rs +++ b/arrow-string/src/like.rs @@ -71,13 +71,13 @@ macro_rules! dyn_function { pub fn $fn_name(left: &dyn Array, right: &dyn Array) -> Result { match (left.data_type(), right.data_type()) { (DataType::Utf8, DataType::Utf8) => { - let left = as_string_array(left); - let right = as_string_array(right); + let left = left.as_string::(); + let right = right.as_string::(); $fn_utf8(left, right) } (DataType::LargeUtf8, DataType::LargeUtf8) => { - let left = as_largestring_array(left); - let right = as_largestring_array(right); + let left = left.as_string::(); + let right = right.as_string::(); $fn_utf8(left, right) } #[cfg(feature = "dyn_cmp_dict")] @@ -139,11 +139,11 @@ pub fn $fn_name( ) -> Result { match left.data_type() { DataType::Utf8 => { - let left = as_string_array(left); + let left = left.as_string::(); $fn_scalar(left, right) } DataType::LargeUtf8 => { - let left = as_largestring_array(left); + let left = left.as_string::(); $fn_scalar(left, right) } DataType::Dictionary(_, _) => { diff --git a/arrow/src/lib.rs b/arrow/src/lib.rs index 4b1251ebcd2b..40b09a976178 100644 --- a/arrow/src/lib.rs +++ b/arrow/src/lib.rs @@ -135,6 +135,25 @@ //! } //! ``` //! +//! To facilitate downcasting, the [`AsArray`](crate::array::AsArray) extension trait can be used +//! +//! ```rust +//! # use arrow::array::{Array, Float32Array, AsArray}; +//! # use arrow::array::StringArray; +//! # use arrow::datatypes::DataType; +//! # +//! fn impl_string(array: &StringArray) {} +//! fn impl_f32(array: &Float32Array) {} +//! +//! fn impl_dyn(array: &dyn Array) { +//! match array.data_type() { +//! DataType::Utf8 => impl_string(array.as_string()), +//! DataType::Float32 => impl_f32(array.as_primitive()), +//! _ => unimplemented!() +//! } +//! } +//! ``` +//! //! It is also common to want to write a function that returns one of a number of possible //! array implementations. [`ArrayRef`] is a type-alias for [`Arc`](array::Array) //! which is frequently used for this purpose @@ -207,7 +226,7 @@ //! //! ``` //! # use arrow::compute::gt_scalar; -//! # use arrow_array::cast::as_primitive_array; +//! # use arrow_array::cast::AsArray; //! # use arrow_array::Int32Array; //! # use arrow_array::types::Int32Type; //! # use arrow_select::filter::filter; @@ -216,7 +235,7 @@ //! let filtered = filter(&array, &predicate).unwrap(); //! //! let expected = Int32Array::from_iter(61..100); -//! assert_eq!(&expected, as_primitive_array::(&filtered)); +//! assert_eq!(&expected, filtered.as_primitive::()); //! ``` //! //! As well as some horizontal operations, such as: diff --git a/arrow/src/util/data_gen.rs b/arrow/src/util/data_gen.rs index 7ead5fa61522..0956893a870d 100644 --- a/arrow/src/util/data_gen.rs +++ b/arrow/src/util/data_gen.rs @@ -335,10 +335,8 @@ mod tests { let col_c = struct_array.column_by_name("c").unwrap(); let col_c = col_c.as_any().downcast_ref::().unwrap(); assert_eq!(col_c.len(), size); - let col_c_values = col_c.values(); - assert!(col_c_values.len() > size); - // col_c_values should be a list - let col_c_list = as_list_array(col_c_values); + let col_c_list = col_c.values().as_list::(); + assert!(col_c_list.len() > size); // Its values should be FixedSizeBinary(6) let fsb = col_c_list.values(); assert_eq!(fsb.data_type(), &DataType::FixedSizeBinary(6)); diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index e6693a6cff4a..2d867c9596c7 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -21,7 +21,7 @@ use std::collections::VecDeque; use std::io::Write; use std::sync::Arc; -use arrow_array::cast::as_primitive_array; +use arrow_array::cast::AsArray; use arrow_array::types::Decimal128Type; use arrow_array::{types, Array, ArrayRef, RecordBatch}; use arrow_schema::{DataType as ArrowDataType, IntervalUnit, SchemaRef}; @@ -400,7 +400,8 @@ fn write_leaf( } ArrowDataType::Decimal128(_, _) => { // use the int32 to represent the decimal with low precision - let array = as_primitive_array::(column) + let array = column + .as_primitive::() .unary::<_, types::Int32Type>(|v| v as i32); write_primitive(typed, array.values(), levels)? } @@ -444,7 +445,8 @@ fn write_leaf( } ArrowDataType::Decimal128(_, _) => { // use the int64 to represent the decimal with low precision - let array = as_primitive_array::(column) + let array = column + .as_primitive::() .unary::<_, types::Int64Type>(|v| v as i64); write_primitive(typed, array.values(), levels)? } diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index 99fe650695a0..2d39284c763f 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -840,7 +840,7 @@ mod tests { use crate::file::page_index::index_reader; use crate::file::properties::WriterProperties; use arrow::error::Result as ArrowResult; - use arrow_array::cast::as_primitive_array; + use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; use arrow_array::{Array, ArrayRef, Int32Array, StringArray}; use futures::TryStreamExt; @@ -1355,14 +1355,14 @@ mod tests { // First batch should contain all rows assert_eq!(batch.num_rows(), 3); assert_eq!(batch.num_columns(), 3); - let col2 = as_primitive_array::(batch.column(2)); + let col2 = batch.column(2).as_primitive::(); assert_eq!(col2.values(), &[0, 1, 2]); let batch = &batches[1]; // Second batch should trigger the limit and only have one row assert_eq!(batch.num_rows(), 1); assert_eq!(batch.num_columns(), 3); - let col2 = as_primitive_array::(batch.column(2)); + let col2 = batch.column(2).as_primitive::(); assert_eq!(col2.values(), &[3]); let stream = ParquetRecordBatchStreamBuilder::new(test.clone()) @@ -1381,14 +1381,14 @@ mod tests { // First batch should contain one row assert_eq!(batch.num_rows(), 1); assert_eq!(batch.num_columns(), 3); - let col2 = as_primitive_array::(batch.column(2)); + let col2 = batch.column(2).as_primitive::(); assert_eq!(col2.values(), &[2]); let batch = &batches[1]; // Second batch should contain two rows assert_eq!(batch.num_rows(), 2); assert_eq!(batch.num_columns(), 3); - let col2 = as_primitive_array::(batch.column(2)); + let col2 = batch.column(2).as_primitive::(); assert_eq!(col2.values(), &[3, 4]); let stream = ParquetRecordBatchStreamBuilder::new(test.clone()) @@ -1407,7 +1407,7 @@ mod tests { // First batch should contain two rows assert_eq!(batch.num_rows(), 2); assert_eq!(batch.num_columns(), 3); - let col2 = as_primitive_array::(batch.column(2)); + let col2 = batch.column(2).as_primitive::(); assert_eq!(col2.values(), &[4, 5]); }