diff --git a/crates/polars-arrow/src/array/fixed_size_list/data.rs b/crates/polars-arrow/src/array/fixed_size_list/data.rs index de9bc1b882c2..c1f353db691a 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/data.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/data.rs @@ -18,6 +18,7 @@ impl Arrow2Arrow for FixedSizeListArray { fn from_data(data: &ArrayData) -> Self { let dtype: ArrowDataType = data.data_type().clone().into(); + let length = data.len() - data.offset(); let size = match dtype { ArrowDataType::FixedSizeList(_, size) => size, _ => unreachable!("must be FixedSizeList type"), @@ -28,6 +29,7 @@ impl Arrow2Arrow for FixedSizeListArray { Self { size, + length, dtype, values, validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), diff --git a/crates/polars-arrow/src/array/fixed_size_list/ffi.rs b/crates/polars-arrow/src/array/fixed_size_list/ffi.rs index 29cf7957cf6c..297d7ae8e5f2 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/ffi.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/ffi.rs @@ -1,4 +1,4 @@ -use polars_error::PolarsResult; +use polars_error::{polars_ensure, PolarsResult}; use super::FixedSizeListArray; use crate::array::ffi::{FromFfi, ToFfi}; @@ -31,11 +31,19 @@ unsafe impl ToFfi for FixedSizeListArray { impl FromFfi for FixedSizeListArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { let dtype = array.dtype().clone(); + let (_, width) = FixedSizeListArray::try_child_and_size(&dtype)?; let validity = unsafe { array.validity() }?; - let child = unsafe { array.child(0)? }; + let child = unsafe { array.child(0) }?; let values = ffi::try_from(child)?; - let mut fsl = Self::try_new(dtype, values, validity)?; + let length = if values.len() == 0 { + 0 + } else { + polars_ensure!(width > 0, InvalidOperation: "Zero-width array with values"); + values.len() / width + }; + + let mut fsl = Self::try_new(dtype, length, values, validity)?; fsl.slice(array.offset(), array.length()); Ok(fsl) } diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index 37cc7e2ad781..4f1622819813 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -10,7 +10,7 @@ mod iterator; mod mutable; pub use mutable::*; -use polars_error::{polars_bail, PolarsResult}; +use polars_error::{polars_bail, polars_ensure, PolarsResult}; use polars_utils::pl_str::PlSmallStr; /// The Arrow's equivalent to an immutable `Vec>` where `T` is an Arrow type. @@ -18,6 +18,7 @@ use polars_utils::pl_str::PlSmallStr; #[derive(Clone)] pub struct FixedSizeListArray { size: usize, // this is redundant with `dtype`, but useful to not have to deconstruct the dtype. + length: usize, // invariant: this is values.len() / size if size > 0 dtype: ArrowDataType, values: Box, validity: Option, @@ -34,6 +35,7 @@ impl FixedSizeListArray { /// * the validity's length is not equal to `values.len() / size`. pub fn try_new( dtype: ArrowDataType, + length: usize, values: Box, validity: Option, ) -> PolarsResult { @@ -45,34 +47,61 @@ impl FixedSizeListArray { polars_bail!(ComputeError: "FixedSizeListArray's child's DataType must match. However, the expected DataType is {child_dtype:?} while it got {values_dtype:?}.") } - if values.len() % size != 0 { - polars_bail!(ComputeError: - "values (of len {}) must be a multiple of size ({}) in FixedSizeListArray.", - values.len(), - size - ) - } - let len = values.len() / size; + polars_ensure!(size == 0 || values.len() % size == 0, ComputeError: + "values (of len {}) must be a multiple of size ({}) in FixedSizeListArray.", + values.len(), + size + ); + + polars_ensure!(size == 0 || values.len() / size == length, ComputeError: + "length of values ({}) is not equal to given length ({}) in FixedSizeListArray({size}).", + values.len() / size, + length, + ); + polars_ensure!(size != 0 || values.len() == 0, ComputeError: + "zero width FixedSizeListArray has values (length = {}).", + values.len(), + ); if validity .as_ref() - .map_or(false, |validity| validity.len() != len) + .map_or(false, |validity| validity.len() != length) { polars_bail!(ComputeError: "validity mask length must be equal to the number of values divided by size") } Ok(Self { size, + length, dtype, values, validity, }) } + #[inline] + fn has_invariants(&self) -> bool { + let has_valid_length = (self.size == 0 && self.values().len() == 0) + || (self.size > 0 + && self.values().len() % self.size() == 0 + && self.values().len() / self.size() == self.length); + let has_valid_validity = self + .validity + .as_ref() + .map_or(true, |v| v.len() == self.length); + + has_valid_length && has_valid_validity + } + /// Alias to `Self::try_new(...).unwrap()` #[track_caller] - pub fn new(dtype: ArrowDataType, values: Box, validity: Option) -> Self { - Self::try_new(dtype, values, validity).unwrap() + pub fn new( + dtype: ArrowDataType, + length: usize, + values: Box, + validity: Option, + ) -> Self { + Self::try_new(dtype, length, values, validity).unwrap() } /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. @@ -83,7 +112,7 @@ impl FixedSizeListArray { /// Returns a new empty [`FixedSizeListArray`]. pub fn new_empty(dtype: ArrowDataType) -> Self { let values = new_empty_array(Self::get_child_and_size(&dtype).0.dtype().clone()); - Self::new(dtype, values, None) + Self::new(dtype, 0, values, None) } /// Returns a new null [`FixedSizeListArray`]. @@ -91,7 +120,7 @@ impl FixedSizeListArray { let (field, size) = Self::get_child_and_size(&dtype); let values = new_null_array(field.dtype().clone(), length * size); - Self::new(dtype, values, Some(Bitmap::new_zeroed(length))) + Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length))) } } @@ -124,6 +153,7 @@ impl FixedSizeListArray { .filter(|bitmap| bitmap.unset_bits() > 0); self.values .slice_unchecked(offset * self.size, length * self.size); + self.length = length; } impl_sliced!(); @@ -136,7 +166,8 @@ impl FixedSizeListArray { /// Returns the length of this array #[inline] pub fn len(&self) -> usize { - self.values.len() / self.size + debug_assert!(self.has_invariants()); + self.length } /// The optional validity. @@ -184,12 +215,7 @@ impl FixedSizeListArray { impl FixedSizeListArray { pub(crate) fn try_child_and_size(dtype: &ArrowDataType) -> PolarsResult<(&Field, usize)> { match dtype.to_logical_type() { - ArrowDataType::FixedSizeList(child, size) => { - if *size == 0 { - polars_bail!(ComputeError: "FixedSizeBinaryArray expects a positive size") - } - Ok((child.as_ref(), *size)) - }, + ArrowDataType::FixedSizeList(child, size) => Ok((child.as_ref(), *size)), _ => polars_bail!(ComputeError: "FixedSizeListArray expects DataType::FixedSizeList"), } } @@ -233,12 +259,14 @@ impl Splitable for FixedSizeListArray { ( Self { dtype: self.dtype.clone(), + length: offset, values: lhs_values, validity: lhs_validity, size, }, Self { dtype: self.dtype.clone(), + length: self.length - offset, values: rhs_values, validity: rhs_validity, size, diff --git a/crates/polars-arrow/src/array/fixed_size_list/mutable.rs b/crates/polars-arrow/src/array/fixed_size_list/mutable.rs index 04802e59bd67..b3a32be0802c 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mutable.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mutable.rs @@ -14,6 +14,7 @@ use crate::datatypes::{ArrowDataType, Field}; pub struct MutableFixedSizeListArray { dtype: ArrowDataType, size: usize, + length: usize, values: M, validity: Option, } @@ -22,6 +23,7 @@ impl From> for FixedSizeListArray fn from(mut other: MutableFixedSizeListArray) -> Self { FixedSizeListArray::new( other.dtype, + other.length, other.values.as_box(), other.validity.map(|x| x.into()), ) @@ -53,12 +55,19 @@ impl MutableFixedSizeListArray { }; Self { size, + length: 0, dtype, values, validity: None, } } + #[inline] + fn has_valid_invariants(&self) -> bool { + (self.size == 0 && self.values().len() == 0) + || (self.size > 0 && self.values.len() / self.size == self.length) + } + /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. pub const fn size(&self) -> usize { self.size @@ -66,7 +75,8 @@ impl MutableFixedSizeListArray { /// The length of this array pub fn len(&self) -> usize { - self.values.len() / self.size + debug_assert!(self.has_valid_invariants()); + self.length } /// The inner values @@ -74,11 +84,6 @@ impl MutableFixedSizeListArray { &self.values } - /// The values as a mutable reference - pub fn mut_values(&mut self) -> &mut M { - &mut self.values - } - fn init_validity(&mut self) { let len = self.values.len() / self.size; @@ -98,6 +103,10 @@ impl MutableFixedSizeListArray { if let Some(validity) = &mut self.validity { validity.push(true) } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); + Ok(()) } @@ -108,6 +117,9 @@ impl MutableFixedSizeListArray { if let Some(validity) = &mut self.validity { validity.push(true) } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); } #[inline] @@ -117,6 +129,9 @@ impl MutableFixedSizeListArray { Some(validity) => validity.push(false), None => self.init_validity(), } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); } /// Reserves `additional` slots. @@ -138,7 +153,8 @@ impl MutableFixedSizeListArray { impl MutableArray for MutableFixedSizeListArray { fn len(&self) -> usize { - self.values.len() / self.size + debug_assert!(self.has_valid_invariants()); + self.length } fn validity(&self) -> Option<&MutableBitmap> { @@ -148,6 +164,7 @@ impl MutableArray for MutableFixedSizeListArray { fn as_box(&mut self) -> Box { FixedSizeListArray::new( self.dtype.clone(), + self.length, self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), ) @@ -157,6 +174,7 @@ impl MutableArray for MutableFixedSizeListArray { fn as_arc(&mut self) -> Arc { FixedSizeListArray::new( self.dtype.clone(), + self.length, self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), ) @@ -185,6 +203,9 @@ impl MutableArray for MutableFixedSizeListArray { } else { self.init_validity() } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); } fn reserve(&mut self, additional: usize) { @@ -206,6 +227,9 @@ where for items in iter { self.try_push(items)?; } + + debug_assert!(self.has_valid_invariants()); + Ok(()) } } @@ -223,6 +247,9 @@ where } else { self.push_null(); } + + debug_assert!(self.has_valid_invariants()); + Ok(()) } } @@ -243,6 +270,8 @@ where } else { self.push_null(); } + + debug_assert!(self.has_valid_invariants()); } } @@ -253,6 +282,11 @@ where fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { extend_validity(self.len(), &mut self.validity, &other.validity); - self.values.try_extend_from_self(&other.values) + self.values.try_extend_from_self(&other.values)?; + self.length += other.len(); + + debug_assert!(self.has_valid_invariants()); + + Ok(()) } } diff --git a/crates/polars-arrow/src/array/growable/fixed_size_list.rs b/crates/polars-arrow/src/array/growable/fixed_size_list.rs index c15202084006..5fedb9a4d254 100644 --- a/crates/polars-arrow/src/array/growable/fixed_size_list.rs +++ b/crates/polars-arrow/src/array/growable/fixed_size_list.rs @@ -6,7 +6,6 @@ use super::{make_growable, Growable}; use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity}; use crate::array::{Array, FixedSizeListArray}; use crate::bitmap::MutableBitmap; -use crate::datatypes::ArrowDataType; /// Concrete [`Growable`] for the [`FixedSizeListArray`]. pub struct GrowableFixedSizeList<'a> { @@ -14,6 +13,7 @@ pub struct GrowableFixedSizeList<'a> { validity: Option, values: Box + 'a>, size: usize, + length: usize, } impl<'a> GrowableFixedSizeList<'a> { @@ -33,24 +33,25 @@ impl<'a> GrowableFixedSizeList<'a> { use_validity = true; }; - let size = - if let ArrowDataType::FixedSizeList(_, size) = &arrays[0].dtype().to_logical_type() { - *size - } else { - unreachable!("`GrowableFixedSizeList` expects `DataType::FixedSizeList`") - }; + let size = arrays[0].size(); let inner = arrays .iter() - .map(|array| array.values().as_ref()) + .map(|array| { + debug_assert_eq!(array.size(), size); + array.values().as_ref() + }) .collect::>(); let values = make_growable(&inner, use_validity, 0); + assert_eq!(values.len(), 0); + Self { arrays, values, validity: prepare_validity(use_validity, capacity), size, + length: 0, } } @@ -60,6 +61,7 @@ impl<'a> GrowableFixedSizeList<'a> { FixedSizeListArray::new( self.arrays[0].dtype().clone(), + self.length, values, validity.map(|v| v.into()), ) @@ -71,16 +73,24 @@ impl<'a> Growable<'a> for GrowableFixedSizeList<'a> { let array = *self.arrays.get_unchecked_release(index); extend_validity(&mut self.validity, array, start, len); + self.length += len; + let start_length = self.values.len(); self.values .extend(index, start * self.size, len * self.size); + debug_assert!(self.size == 0 || (self.values.len() - start_length) / self.size == len); } unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) { let array = *self.arrays.get_unchecked_release(index); extend_validity_copies(&mut self.validity, array, start, len, copies); + self.length += len * copies; + let start_length = self.values.len(); self.values .extend_copies(index, start * self.size, len * self.size, copies); + debug_assert!( + self.size == 0 || (self.values.len() - start_length) / self.size == len * copies + ); } fn extend_validity(&mut self, additional: usize) { @@ -88,11 +98,12 @@ impl<'a> Growable<'a> for GrowableFixedSizeList<'a> { if let Some(validity) = &mut self.validity { validity.extend_constant(additional, false); } + self.length += additional; } #[inline] fn len(&self) -> usize { - self.values.len() / self.size + self.length } fn as_arc(&mut self) -> Arc { @@ -111,6 +122,7 @@ impl<'a> From> for FixedSizeListArray { Self::new( val.arrays[0].dtype().clone(), + val.length, values, val.validity.map(|v| v.into()), ) diff --git a/crates/polars-arrow/src/array/static_array.rs b/crates/polars-arrow/src/array/static_array.rs index 3cfbc870e141..a79b6a909fe6 100644 --- a/crates/polars-arrow/src/array/static_array.rs +++ b/crates/polars-arrow/src/array/static_array.rs @@ -1,8 +1,8 @@ use bytemuck::Zeroable; use polars_utils::no_call_const; +use super::growable::{Growable, GrowableFixedSizeList}; use crate::array::binview::BinaryViewValueIter; -use crate::array::growable::{Growable, GrowableFixedSizeList}; use crate::array::static_array_collect::ArrayFromIterDtype; use crate::array::{ Array, ArrayValuesIter, BinaryArray, BinaryValueIter, BinaryViewArray, BooleanArray, @@ -394,7 +394,7 @@ impl StaticArray for FixedSizeListArray { } fn full(length: usize, value: Self::ValueT<'_>, dtype: ArrowDataType) -> Self { - let singular_arr = FixedSizeListArray::new(dtype, value, None); + let singular_arr = FixedSizeListArray::new(dtype, 1, value, None); let mut arr = GrowableFixedSizeList::new(vec![&singular_arr], false, length); unsafe { arr.extend_copies(0, 0, 1, length) } arr.into() diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index 4dd8857b95c3..27f93eb07356 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -174,25 +174,20 @@ fn cast_list_to_fixed_size_list( ) -> PolarsResult { let null_cnt = list.null_count(); let new_values = if null_cnt == 0 { - let offsets = list.offsets().buffer().iter(); - let expected = - (list.offsets().first().to_usize()..list.len()).map(|ix| O::from_as_usize(ix * size)); - - match offsets - .zip(expected) - .find(|(actual, expected)| *actual != expected) - { - Some(_) => polars_bail!(ComputeError: - "not all elements have the specified width {size}" - ), - None => { - let sliced_values = list.values().sliced( - list.offsets().first().to_usize(), - list.offsets().range().to_usize(), - ); - cast(sliced_values.as_ref(), inner.dtype(), options)? - }, + let start_offset = list.offsets().first().to_usize(); + let offsets = list.offsets().buffer(); + + let mut is_valid = true; + for (i, offset) in offsets.iter().enumerate() { + is_valid &= offset.to_usize() == start_offset + i * size; } + + polars_ensure!(is_valid, ComputeError: "not all elements have the specified width {size}"); + + let sliced_values = list + .values() + .sliced(start_offset, list.offsets().range().to_usize()); + cast(sliced_values.as_ref(), inner.dtype(), options)? } else { let offsets = list.offsets().as_slice(); // Check the lengths of each list are equal to the fixed size. @@ -232,8 +227,10 @@ fn cast_list_to_fixed_size_list( cast(take_values.as_ref(), inner.dtype(), options)? }; + FixedSizeListArray::try_new( ArrowDataType::FixedSizeList(Box::new(inner.clone()), size), + list.len(), new_values, list.validity().cloned(), ) diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 6ef9687f146e..8f2226c709e6 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -567,6 +567,17 @@ impl ArrowDataType { pub fn is_view(&self) -> bool { matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView) } + + pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType { + ArrowDataType::FixedSizeList( + Box::new(Field::new( + PlSmallStr::from_static("item"), + self, + is_nullable, + )), + size, + ) + } } impl From for ArrowDataType { diff --git a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs index eac68f9fda54..fdfa13574e3d 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs @@ -1,7 +1,7 @@ use std::collections::VecDeque; use std::io::{Read, Seek}; -use polars_error::{polars_err, PolarsResult}; +use polars_error::{polars_ensure, polars_err, PolarsResult}; use super::super::super::IpcField; use super::super::deserialize::{read, skip}; @@ -41,6 +41,7 @@ pub fn read_fixed_size_list( )?; let (field, size) = FixedSizeListArray::get_child_and_size(&dtype); + polars_ensure!(size > 0, nyi = "Cannot read zero sized arrays from IPC"); let limit = limit.map(|x| x.saturating_mul(size)); @@ -59,7 +60,7 @@ pub fn read_fixed_size_list( version, scratch, )?; - FixedSizeListArray::try_new(dtype, values, validity) + FixedSizeListArray::try_new(dtype, values.len() / size, values, validity) } pub fn skip_fixed_size_list( diff --git a/crates/polars-arrow/src/legacy/array/fixed_size_list.rs b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs index 99382b0b6407..813f357e2137 100644 --- a/crates/polars-arrow/src/legacy/array/fixed_size_list.rs +++ b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs @@ -10,6 +10,7 @@ use crate::legacy::kernels::concatenate::concatenate_owned_unchecked; pub struct AnonymousBuilder { arrays: Vec, validity: Option, + length: usize, pub width: usize, } @@ -19,6 +20,7 @@ impl AnonymousBuilder { arrays: Vec::with_capacity(capacity), validity: None, width, + length: 0, } } pub fn is_empty(&self) -> bool { @@ -32,6 +34,8 @@ impl AnonymousBuilder { if let Some(validity) = &mut self.validity { validity.push(true) } + + self.length += 1; } pub fn push_null(&mut self) { @@ -41,6 +45,8 @@ impl AnonymousBuilder { Some(validity) => validity.push(false), None => self.init_validity(), } + + self.length += 1; } fn init_validity(&mut self) { @@ -82,6 +88,7 @@ impl AnonymousBuilder { let dtype = FixedSizeListArray::default_datatype(inner_dtype.clone(), self.width); Ok(FixedSizeListArray::new( dtype, + self.length, values, self.validity.map(|validity| validity.into()), )) diff --git a/crates/polars-arrow/src/legacy/array/mod.rs b/crates/polars-arrow/src/legacy/array/mod.rs index bbb876283470..f15ac1811f96 100644 --- a/crates/polars-arrow/src/legacy/array/mod.rs +++ b/crates/polars-arrow/src/legacy/array/mod.rs @@ -212,11 +212,23 @@ pub fn convert_inner_type(array: &dyn Array, dtype: &ArrowDataType) -> Box { + let width = *width; + let array = array.as_any().downcast_ref::().unwrap(); let inner = array.values(); + let length = if width == array.size() { + array.len() + } else { + assert!(array.values().len() > 0 || width != 0); + if width == 0 { + 0 + } else { + array.values().len() / width + } + }; let new_values = convert_inner_type(inner.as_ref(), field.dtype()); - let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), *width); - FixedSizeListArray::new(dtype, new_values, array.validity().cloned()).boxed() + let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), width); + FixedSizeListArray::new(dtype, length, new_values, array.validity().cloned()).boxed() }, ArrowDataType::Struct(fields) => { let array = array.as_any().downcast_ref::().unwrap(); diff --git a/crates/polars-compute/src/comparisons/array.rs b/crates/polars-compute/src/comparisons/array.rs index da120f27553b..23d43887a280 100644 --- a/crates/polars-compute/src/comparisons/array.rs +++ b/crates/polars-compute/src/comparisons/array.rs @@ -48,6 +48,10 @@ impl TotalEqKernel for FixedSizeListArray { return Bitmap::new_with_value(false, self.len()); } + if *self_width == 0 { + return Bitmap::new_with_value(true, self.len()); + } + let inner = array_tot_eq_missing_kernel(self.values().as_ref(), other.values().as_ref()); agg_array_bitmap(inner, self.size(), |zeroes| zeroes == 0) @@ -69,6 +73,10 @@ impl TotalEqKernel for FixedSizeListArray { return Bitmap::new_with_value(true, self.len()); } + if *self_width == 0 { + return Bitmap::new_with_value(false, self.len()); + } + let inner = array_tot_ne_missing_kernel(self.values().as_ref(), other.values().as_ref()); agg_array_bitmap(inner, self.size(), |zeroes| zeroes < self.size()) diff --git a/crates/polars-core/src/chunked_array/array/mod.rs b/crates/polars-core/src/chunked_array/array/mod.rs index 59bdd92b67cc..3e0e47a7e86a 100644 --- a/crates/polars-core/src/chunked_array/array/mod.rs +++ b/crates/polars-core/src/chunked_array/array/mod.rs @@ -74,7 +74,8 @@ impl ArrayChunked { out.dtype().to_arrow(CompatLevel::newest()), ca.width(), ); - let arr = FixedSizeListArray::new(inner_dtype, values, arr.validity().cloned()); + let arr = + FixedSizeListArray::new(inner_dtype, arr.len(), values, arr.validity().cloned()); Ok(arr) }); diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index ea758742169e..1b8228d4ea69 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -664,7 +664,7 @@ fn cast_fixed_size_list( let new_values = new_inner.array_ref(0).clone(); let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), ca.width()); - let new_arr = FixedSizeListArray::new(dtype, new_values, arr.validity().cloned()); + let new_arr = FixedSizeListArray::new(dtype, ca.len(), new_values, arr.validity().cloned()); Ok((Box::new(new_arr), inner_dtype)) } diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index bf5c748eeed1..33e984b94e0f 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -71,6 +71,7 @@ fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataTy let arrow_dtype = FixedSizeListArray::default_datatype(ArrowDataType::UInt32, width); let new_array = FixedSizeListArray::new( arrow_dtype, + values_arr.len(), cat.array_ref(0).clone(), list_arr.validity().cloned(), ); diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 956d055a52c2..cd79349bfcd8 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -591,10 +591,9 @@ impl DataType { Duration(unit) => Ok(ArrowDataType::Duration(unit.to_arrow())), Time => Ok(ArrowDataType::Time64(ArrowTimeUnit::Nanosecond)), #[cfg(feature = "dtype-array")] - Array(dt, size) => Ok(ArrowDataType::FixedSizeList( - Box::new(dt.to_arrow_field(PlSmallStr::from_static("item"), compat_level)), - *size, - )), + Array(dt, size) => Ok(dt + .try_to_arrow(compat_level)? + .to_fixed_size_list(*size, true)), List(dt) => Ok(ArrowDataType::LargeList(Box::new( dt.to_arrow_field(PlSmallStr::from_static("item"), compat_level), ))), diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 727faf0768c8..1dea44ee393a 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -189,6 +189,84 @@ impl Column { } } + // # Try to Chunked Arrays + pub fn try_bool(&self) -> Option<&BooleanChunked> { + self.as_materialized_series().try_bool() + } + pub fn try_i8(&self) -> Option<&Int8Chunked> { + self.as_materialized_series().try_i8() + } + pub fn try_i16(&self) -> Option<&Int16Chunked> { + self.as_materialized_series().try_i16() + } + pub fn try_i32(&self) -> Option<&Int32Chunked> { + self.as_materialized_series().try_i32() + } + pub fn try_i64(&self) -> Option<&Int64Chunked> { + self.as_materialized_series().try_i64() + } + pub fn try_u8(&self) -> Option<&UInt8Chunked> { + self.as_materialized_series().try_u8() + } + pub fn try_u16(&self) -> Option<&UInt16Chunked> { + self.as_materialized_series().try_u16() + } + pub fn try_u32(&self) -> Option<&UInt32Chunked> { + self.as_materialized_series().try_u32() + } + pub fn try_u64(&self) -> Option<&UInt64Chunked> { + self.as_materialized_series().try_u64() + } + pub fn try_f32(&self) -> Option<&Float32Chunked> { + self.as_materialized_series().try_f32() + } + pub fn try_f64(&self) -> Option<&Float64Chunked> { + self.as_materialized_series().try_f64() + } + pub fn try_str(&self) -> Option<&StringChunked> { + self.as_materialized_series().try_str() + } + pub fn try_list(&self) -> Option<&ListChunked> { + self.as_materialized_series().try_list() + } + pub fn try_binary(&self) -> Option<&BinaryChunked> { + self.as_materialized_series().try_binary() + } + pub fn try_idx(&self) -> Option<&IdxCa> { + self.as_materialized_series().try_idx() + } + pub fn try_binary_offset(&self) -> Option<&BinaryOffsetChunked> { + self.as_materialized_series().try_binary_offset() + } + #[cfg(feature = "dtype-datetime")] + pub fn try_datetime(&self) -> Option<&DatetimeChunked> { + self.as_materialized_series().try_datetime() + } + #[cfg(feature = "dtype-struct")] + pub fn try_struct(&self) -> Option<&StructChunked> { + self.as_materialized_series().try_struct() + } + #[cfg(feature = "dtype-decimal")] + pub fn try_decimal(&self) -> Option<&DecimalChunked> { + self.as_materialized_series().try_decimal() + } + #[cfg(feature = "dtype-array")] + pub fn try_array(&self) -> Option<&ArrayChunked> { + self.as_materialized_series().try_array() + } + #[cfg(feature = "dtype-categorical")] + pub fn try_categorical(&self) -> Option<&CategoricalChunked> { + self.as_materialized_series().try_categorical() + } + #[cfg(feature = "dtype-date")] + pub fn try_date(&self) -> Option<&DateChunked> { + self.as_materialized_series().try_date() + } + #[cfg(feature = "dtype-duration")] + pub fn try_duration(&self) -> Option<&DurationChunked> { + self.as_materialized_series().try_duration() + } + // # To Chunked Arrays pub fn bool(&self) -> PolarsResult<&BooleanChunked> { self.as_materialized_series().bool() diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index ce473a4d60fb..7f61f99895f4 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -517,37 +517,33 @@ unsafe fn to_physical_and_dtype( to_physical_and_dtype(out, md) }, #[cfg(feature = "dtype-array")] - #[allow(unused_variables)] ArrowDataType::FixedSizeList(field, size) => { - feature_gated!("dtype-array", { - let values = arrays - .iter() - .map(|arr| { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.values().clone() - }) - .collect::>(); + let values = arrays + .iter() + .map(|arr| { + let arr = arr.as_any().downcast_ref::().unwrap(); + arr.values().clone() + }) + .collect::>(); - let (converted_values, dtype) = - to_physical_and_dtype(values, Some(&field.metadata)); + let (converted_values, dtype) = to_physical_and_dtype(values, Some(&field.metadata)); - let arrays = arrays - .iter() - .zip(converted_values) - .map(|(arr, values)| { - let arr = arr.as_any().downcast_ref::().unwrap(); - - let dtype = - FixedSizeListArray::default_datatype(values.dtype().clone(), *size); - Box::from(FixedSizeListArray::new( - dtype, - values, - arr.validity().cloned(), - )) as ArrayRef - }) - .collect(); - (arrays, DataType::Array(Box::new(dtype), *size)) - }) + let arrays = arrays + .iter() + .zip(converted_values) + .map(|(arr, values)| { + let arr = arr.as_any().downcast_ref::().unwrap(); + + let dtype = FixedSizeListArray::default_datatype(values.dtype().clone(), *size); + Box::from(FixedSizeListArray::new( + dtype, + arr.len(), + values, + arr.validity().cloned(), + )) as ArrayRef + }) + .collect(); + (arrays, DataType::Array(Box::new(dtype), *size)) }, ArrowDataType::LargeList(field) => { let values = arrays diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index a41c5822283c..ce9bcffba2f0 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -825,6 +825,17 @@ impl Series { unsafe { Ok(self.take_unchecked(&idx)) } } + pub fn try_idx(&self) -> Option<&IdxCa> { + #[cfg(feature = "bigidx")] + { + self.try_u64() + } + #[cfg(not(feature = "bigidx"))] + { + self.try_u32() + } + } + pub fn idx(&self) -> PolarsResult<&IdxCa> { #[cfg(feature = "bigidx")] { diff --git a/crates/polars-core/src/series/ops/downcast.rs b/crates/polars-core/src/series/ops/downcast.rs index ce57e42c610c..2189fc319b5e 100644 --- a/crates/polars-core/src/series/ops/downcast.rs +++ b/crates/polars-core/src/series/ops/downcast.rs @@ -1,36 +1,190 @@ use crate::prelude::*; use crate::series::implementations::null::NullChunked; -macro_rules! unpack_chunked { - ($series:expr, $expected:pat => $ca:ty, $name:expr) => { +macro_rules! unpack_chunked_err { + ($series:expr => $name:expr) => { + polars_err!(SchemaMismatch: "invalid series dtype: expected `{}`, got `{}`", $name, $series.dtype()) + }; +} + +macro_rules! try_unpack_chunked { + ($series:expr, $expected:pat => $ca:ty) => { match $series.dtype() { $expected => { // Check downcast in debug compiles #[cfg(debug_assertions)] { - Ok($series.as_ref().as_any().downcast_ref::<$ca>().unwrap()) + Some($series.as_ref().as_any().downcast_ref::<$ca>().unwrap()) } #[cfg(not(debug_assertions))] unsafe { - Ok(&*($series.as_ref() as *const dyn SeriesTrait as *const $ca)) + Some(&*($series.as_ref() as *const dyn SeriesTrait as *const $ca)) } }, - dt => polars_bail!( - SchemaMismatch: "invalid series dtype: expected `{}`, got `{}`", $name, dt, - ), + _ => None, } }; } impl Series { + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int8]` + pub fn try_i8(&self) -> Option<&Int8Chunked> { + try_unpack_chunked!(self, DataType::Int8 => Int8Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int16]` + pub fn try_i16(&self) -> Option<&Int16Chunked> { + try_unpack_chunked!(self, DataType::Int16 => Int16Chunked) + } + + /// Unpack to [`ChunkedArray`] + /// ``` + /// # use polars_core::prelude::*; + /// let s = Series::new("foo".into(), [1i32 ,2, 3]); + /// let s_squared: Series = s.i32() + /// .unwrap() + /// .into_iter() + /// .map(|opt_v| { + /// match opt_v { + /// Some(v) => Some(v * v), + /// None => None, // null value + /// } + /// }).collect(); + /// ``` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int32]` + pub fn try_i32(&self) -> Option<&Int32Chunked> { + try_unpack_chunked!(self, DataType::Int32 => Int32Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int64]` + pub fn try_i64(&self) -> Option<&Int64Chunked> { + try_unpack_chunked!(self, DataType::Int64 => Int64Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float32]` + pub fn try_f32(&self) -> Option<&Float32Chunked> { + try_unpack_chunked!(self, DataType::Float32 => Float32Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float64]` + pub fn try_f64(&self) -> Option<&Float64Chunked> { + try_unpack_chunked!(self, DataType::Float64 => Float64Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt8]` + pub fn try_u8(&self) -> Option<&UInt8Chunked> { + try_unpack_chunked!(self, DataType::UInt8 => UInt8Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt16]` + pub fn try_u16(&self) -> Option<&UInt16Chunked> { + try_unpack_chunked!(self, DataType::UInt16 => UInt16Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt32]` + pub fn try_u32(&self) -> Option<&UInt32Chunked> { + try_unpack_chunked!(self, DataType::UInt32 => UInt32Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt64]` + pub fn try_u64(&self) -> Option<&UInt64Chunked> { + try_unpack_chunked!(self, DataType::UInt64 => UInt64Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Boolean]` + pub fn try_bool(&self) -> Option<&BooleanChunked> { + try_unpack_chunked!(self, DataType::Boolean => BooleanChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::String]` + pub fn try_str(&self) -> Option<&StringChunked> { + try_unpack_chunked!(self, DataType::String => StringChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Binary]` + pub fn try_binary(&self) -> Option<&BinaryChunked> { + try_unpack_chunked!(self, DataType::Binary => BinaryChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Binary]` + pub fn try_binary_offset(&self) -> Option<&BinaryOffsetChunked> { + try_unpack_chunked!(self, DataType::BinaryOffset => BinaryOffsetChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Time]` + #[cfg(feature = "dtype-time")] + pub fn try_time(&self) -> Option<&TimeChunked> { + try_unpack_chunked!(self, DataType::Time => TimeChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Date]` + #[cfg(feature = "dtype-date")] + pub fn try_date(&self) -> Option<&DateChunked> { + try_unpack_chunked!(self, DataType::Date => DateChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Datetime]` + #[cfg(feature = "dtype-datetime")] + pub fn try_datetime(&self) -> Option<&DatetimeChunked> { + try_unpack_chunked!(self, DataType::Datetime(_, _) => DatetimeChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Duration]` + #[cfg(feature = "dtype-duration")] + pub fn try_duration(&self) -> Option<&DurationChunked> { + try_unpack_chunked!(self, DataType::Duration(_) => DurationChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Decimal]` + #[cfg(feature = "dtype-decimal")] + pub fn try_decimal(&self) -> Option<&DecimalChunked> { + try_unpack_chunked!(self, DataType::Decimal(_, _) => DecimalChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype list + pub fn try_list(&self) -> Option<&ListChunked> { + try_unpack_chunked!(self, DataType::List(_) => ListChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Array]` + #[cfg(feature = "dtype-array")] + pub fn try_array(&self) -> Option<&ArrayChunked> { + try_unpack_chunked!(self, DataType::Array(_, _) => ArrayChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Categorical]` + #[cfg(feature = "dtype-categorical")] + pub fn try_categorical(&self) -> Option<&CategoricalChunked> { + try_unpack_chunked!(self, DataType::Categorical(_, _) | DataType::Enum(_, _) => CategoricalChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Struct]` + #[cfg(feature = "dtype-struct")] + pub fn try_struct(&self) -> Option<&StructChunked> { + #[cfg(debug_assertions)] + { + if let DataType::Struct(_) = self.dtype() { + let any = self.as_any(); + assert!(any.is::()); + } + } + try_unpack_chunked!(self, DataType::Struct(_) => StructChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Null]` + pub fn try_null(&self) -> Option<&NullChunked> { + try_unpack_chunked!(self, DataType::Null => NullChunked) + } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int8]` pub fn i8(&self) -> PolarsResult<&Int8Chunked> { - unpack_chunked!(self, DataType::Int8 => Int8Chunked, "Int8") + self.try_i8() + .ok_or_else(|| unpack_chunked_err!(self => "Int8")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int16]` pub fn i16(&self) -> PolarsResult<&Int16Chunked> { - unpack_chunked!(self, DataType::Int16 => Int16Chunked, "Int16") + self.try_i16() + .ok_or_else(|| unpack_chunked_err!(self => "Int16")) } /// Unpack to [`ChunkedArray`] @@ -49,109 +203,129 @@ impl Series { /// ``` /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int32]` pub fn i32(&self) -> PolarsResult<&Int32Chunked> { - unpack_chunked!(self, DataType::Int32 => Int32Chunked, "Int32") + self.try_i32() + .ok_or_else(|| unpack_chunked_err!(self => "Int32")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int64]` pub fn i64(&self) -> PolarsResult<&Int64Chunked> { - unpack_chunked!(self, DataType::Int64 => Int64Chunked, "Int64") + self.try_i64() + .ok_or_else(|| unpack_chunked_err!(self => "Int64")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float32]` pub fn f32(&self) -> PolarsResult<&Float32Chunked> { - unpack_chunked!(self, DataType::Float32 => Float32Chunked, "Float32") + self.try_f32() + .ok_or_else(|| unpack_chunked_err!(self => "Float32")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float64]` pub fn f64(&self) -> PolarsResult<&Float64Chunked> { - unpack_chunked!(self, DataType::Float64 => Float64Chunked, "Float64") + self.try_f64() + .ok_or_else(|| unpack_chunked_err!(self => "Float64")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt8]` pub fn u8(&self) -> PolarsResult<&UInt8Chunked> { - unpack_chunked!(self, DataType::UInt8 => UInt8Chunked, "UInt8") + self.try_u8() + .ok_or_else(|| unpack_chunked_err!(self => "UInt8")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt16]` pub fn u16(&self) -> PolarsResult<&UInt16Chunked> { - unpack_chunked!(self, DataType::UInt16 => UInt16Chunked, "UInt16") + self.try_u16() + .ok_or_else(|| unpack_chunked_err!(self => "UInt16")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt32]` pub fn u32(&self) -> PolarsResult<&UInt32Chunked> { - unpack_chunked!(self, DataType::UInt32 => UInt32Chunked, "UInt32") + self.try_u32() + .ok_or_else(|| unpack_chunked_err!(self => "UInt32")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt64]` pub fn u64(&self) -> PolarsResult<&UInt64Chunked> { - unpack_chunked!(self, DataType::UInt64 => UInt64Chunked, "UInt64") + self.try_u64() + .ok_or_else(|| unpack_chunked_err!(self => "UInt64")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Boolean]` pub fn bool(&self) -> PolarsResult<&BooleanChunked> { - unpack_chunked!(self, DataType::Boolean => BooleanChunked, "Boolean") + self.try_bool() + .ok_or_else(|| unpack_chunked_err!(self => "Boolean")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::String]` pub fn str(&self) -> PolarsResult<&StringChunked> { - unpack_chunked!(self, DataType::String => StringChunked, "String") + self.try_str() + .ok_or_else(|| unpack_chunked_err!(self => "String")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Binary]` pub fn binary(&self) -> PolarsResult<&BinaryChunked> { - unpack_chunked!(self, DataType::Binary => BinaryChunked, "Binary") + self.try_binary() + .ok_or_else(|| unpack_chunked_err!(self => "Binary")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Binary]` pub fn binary_offset(&self) -> PolarsResult<&BinaryOffsetChunked> { - unpack_chunked!(self, DataType::BinaryOffset => BinaryOffsetChunked, "BinaryOffset") + self.try_binary_offset() + .ok_or_else(|| unpack_chunked_err!(self => "BinaryOffset")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Time]` #[cfg(feature = "dtype-time")] pub fn time(&self) -> PolarsResult<&TimeChunked> { - unpack_chunked!(self, DataType::Time => TimeChunked, "Time") + self.try_time() + .ok_or_else(|| unpack_chunked_err!(self => "Time")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Date]` #[cfg(feature = "dtype-date")] pub fn date(&self) -> PolarsResult<&DateChunked> { - unpack_chunked!(self, DataType::Date => DateChunked, "Date") + self.try_date() + .ok_or_else(|| unpack_chunked_err!(self => "Date")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Datetime]` #[cfg(feature = "dtype-datetime")] pub fn datetime(&self) -> PolarsResult<&DatetimeChunked> { - unpack_chunked!(self, DataType::Datetime(_, _) => DatetimeChunked, "Datetime") + self.try_datetime() + .ok_or_else(|| unpack_chunked_err!(self => "Datetime")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Duration]` #[cfg(feature = "dtype-duration")] pub fn duration(&self) -> PolarsResult<&DurationChunked> { - unpack_chunked!(self, DataType::Duration(_) => DurationChunked, "Duration") + self.try_duration() + .ok_or_else(|| unpack_chunked_err!(self => "Duration")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Decimal]` #[cfg(feature = "dtype-decimal")] pub fn decimal(&self) -> PolarsResult<&DecimalChunked> { - unpack_chunked!(self, DataType::Decimal(_, _) => DecimalChunked, "Decimal") + self.try_decimal() + .ok_or_else(|| unpack_chunked_err!(self => "Decimal")) } /// Unpack to [`ChunkedArray`] of dtype list pub fn list(&self) -> PolarsResult<&ListChunked> { - unpack_chunked!(self, DataType::List(_) => ListChunked, "List") + self.try_list() + .ok_or_else(|| unpack_chunked_err!(self => "List")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Array]` #[cfg(feature = "dtype-array")] pub fn array(&self) -> PolarsResult<&ArrayChunked> { - unpack_chunked!(self, DataType::Array(_, _) => ArrayChunked, "FixedSizeList") + self.try_array() + .ok_or_else(|| unpack_chunked_err!(self => "FixedSizeList")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Categorical]` #[cfg(feature = "dtype-categorical")] pub fn categorical(&self) -> PolarsResult<&CategoricalChunked> { - unpack_chunked!(self, DataType::Categorical(_, _) | DataType::Enum(_, _) => CategoricalChunked, "Enum | Categorical") + self.try_categorical() + .ok_or_else(|| unpack_chunked_err!(self => "Enum | Categorical")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Struct]` @@ -164,11 +338,14 @@ impl Series { assert!(any.is::()); } } - unpack_chunked!(self, DataType::Struct(_) => StructChunked, "Struct") + + self.try_struct() + .ok_or_else(|| unpack_chunked_err!(self => "Struct")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Null]` pub fn null(&self) -> PolarsResult<&NullChunked> { - unpack_chunked!(self, DataType::Null => NullChunked, "Null") + self.try_null() + .ok_or_else(|| unpack_chunked_err!(self => "Null")) } } diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs index fdc6b6091058..544754755e6e 100644 --- a/crates/polars-core/src/series/ops/reshape.rs +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -96,63 +96,90 @@ impl Series { let mut total_dim_size = 1; let mut num_infers = 0; - for (index, &dim) in dimensions.iter().enumerate() { + for &dim in dimensions { match dim { - ReshapeDimension::Infer => { - polars_ensure!( - num_infers == 0, - InvalidOperation: "can only specify one inferred dimension" - ); - num_infers += 1; - }, - ReshapeDimension::Specified(dim) => { - let dim = dim.get(); - - if dim > 0 { - total_dim_size *= dim as usize - } else { - polars_ensure!( - index == 0, - InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}", - format_tuple!(dimensions) - ); - total_dim_size = 0; - // We can early exit here, as empty arrays will error with multiple dimensions, - // and non-empty arrays will error when the first dimension is zero. - break; - } - }, + ReshapeDimension::Infer => num_infers += 1, + ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize, } } + polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension"); + if size == 0 { - if dimensions.len() > 1 || (num_infers == 0 && total_dim_size != 0) { - polars_bail!(InvalidOperation: "cannot reshape empty array into shape {}", format_tuple!(dimensions)) - } - } else if total_dim_size == 0 { - polars_bail!(InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dimensions)) - } else { polars_ensure!( - size % total_dim_size == 0, - InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) + num_infers > 0 || total_dim_size == 0, + InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}", + format_tuple!(dimensions), ); + + let mut prev_arrow_dtype = leaf_array + .dtype() + .to_physical() + .to_arrow(CompatLevel::newest()); + let mut prev_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array.chunks()[0].clone(); + + // @NOTE: We need to collect the iterator here because it is lazily processed. + let mut current_length = dimensions[0].get_or_infer(0); + let len_iter = dimensions[1..] + .iter() + .map(|d| { + let length = current_length as usize; + current_length *= d.get_or_infer(0); + length + }) + .collect::>(); + + // We pop the outer dimension as that is the height of the series. + for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() { + // Infer dimension if needed + let dim = dim.get_or_infer(0); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize); + + prev_array = + FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None) + .boxed(); + } + + return Ok(unsafe { + Series::from_chunks_and_dtype_unchecked( + leaf_array.name().clone(), + vec![prev_array], + &prev_dtype, + ) + }); } + polars_ensure!( + total_dim_size > 0, + InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", + format_tuple!(dimensions) + ); + + polars_ensure!( + size % total_dim_size == 0, + InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) + ); + let leaf_array = leaf_array.rechunk(); + let mut prev_arrow_dtype = leaf_array + .dtype() + .to_physical() + .to_arrow(CompatLevel::newest()); let mut prev_dtype = leaf_array.dtype().clone(); let mut prev_array = leaf_array.chunks()[0].clone(); // We pop the outer dimension as that is the height of the series. - for idx in (1..dimensions.len()).rev() { + for dim in dimensions[1..].iter().rev() { // Infer dimension if needed - let dim = dimensions[idx].get_or_infer_with(|| { - debug_assert!(num_infers > 0); - (size / total_dim_size) as u64 - }); + let dim = dim.get_or_infer((size / total_dim_size) as u64); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize); prev_array = FixedSizeListArray::new( - prev_dtype.to_arrow(CompatLevel::newest()), + prev_arrow_dtype.clone(), + prev_array.len() / dim as usize, prev_array, None, ) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs index 520f7f8596e1..3bc1beb30973 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs @@ -45,7 +45,7 @@ pub fn create_list( nested: &mut NestedState, values: Box, ) -> Box { - let (mut offsets, validity) = nested.pop().unwrap(); + let (length, mut offsets, validity) = nested.pop().unwrap(); let validity = validity.and_then(freeze_validity); match dtype.to_logical_type() { ArrowDataType::List(_) => { @@ -75,7 +75,7 @@ pub fn create_list( )) }, ArrowDataType::FixedSizeList(_, _) => { - Box::new(FixedSizeListArray::new(dtype, values, validity)) + Box::new(FixedSizeListArray::new(dtype, length, values, validity)) }, _ => unreachable!(), } @@ -87,7 +87,7 @@ pub fn create_map( nested: &mut NestedState, values: Box, ) -> Box { - let (mut offsets, validity) = nested.pop().unwrap(); + let (_, mut offsets, validity) = nested.pop().unwrap(); match dtype.to_logical_type() { ArrowDataType::Map(_, _) => { offsets.push(values.len() as i64); diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs index 114eeef67341..b5b083f8b882 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs @@ -403,7 +403,7 @@ pub fn columns_to_iter_recursive( let (mut nested, last_array) = field_to_nested_array(init.clone(), &mut columns, &mut types, last_field)?; debug_assert!(matches!(nested.last().unwrap(), NestedContent::Struct)); - let (_, struct_validity) = nested.pop().unwrap(); + let (_, _, struct_validity) = nested.pop().unwrap(); let mut field_arrays = Vec::>::with_capacity(fields.len()); field_arrays.push(last_array); @@ -416,7 +416,7 @@ pub fn columns_to_iter_recursive( { debug_assert!(matches!(_nested.last().unwrap(), NestedContent::Struct)); debug_assert_eq!( - _nested.pop().unwrap().1.and_then(freeze_validity), + _nested.pop().unwrap().2.and_then(freeze_validity), struct_validity.clone().and_then(freeze_validity), ); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs index ad542cf05753..ab769848ca92 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs @@ -87,12 +87,17 @@ impl Nested { } } - fn take(mut self) -> (Vec, Option) { + fn take(mut self) -> (usize, Vec, Option) { if !matches!(self.content, NestedContent::Primitive) { if let Some(validity) = self.validity.as_mut() { validity.extend_constant(self.num_valids, true); validity.extend_constant(self.num_invalids, false); } + + debug_assert!(self + .validity + .as_ref() + .map_or(true, |v| v.len() == self.length)); } self.num_valids = 0; @@ -101,11 +106,11 @@ impl Nested { match self.content { NestedContent::Primitive => { debug_assert!(self.validity.map_or(true, |validity| validity.is_empty())); - (Vec::new(), None) + (self.length, Vec::new(), None) }, - NestedContent::List { offsets } => (offsets, self.validity), - NestedContent::FixedSizeList { .. } => (Vec::new(), self.validity), - NestedContent::Struct => (Vec::new(), self.validity), + NestedContent::List { offsets } => (self.length, offsets, self.validity), + NestedContent::FixedSizeList { .. } => (self.length, Vec::new(), self.validity), + NestedContent::Struct => (self.length, Vec::new(), self.validity), } } @@ -254,7 +259,7 @@ impl NestedState { Self { nested } } - pub fn pop(&mut self) -> Option<(Vec, Option)> { + pub fn pop(&mut self) -> Option<(usize, Vec, Option)> { Some(self.nested.pop()?.take()) } diff --git a/crates/polars-parquet/src/arrow/read/statistics/list.rs b/crates/polars-parquet/src/arrow/read/statistics/list.rs index 54f308c94f4d..baea27289124 100644 --- a/crates/polars-parquet/src/arrow/read/statistics/list.rs +++ b/crates/polars-parquet/src/arrow/read/statistics/list.rs @@ -64,6 +64,7 @@ impl MutableArray for DynMutableListArray { }, ArrowDataType::FixedSizeList(field, _) => Box::new(FixedSizeListArray::new( ArrowDataType::FixedSizeList(field.clone(), inner.len()), + 1, inner, None, )), diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 9159c2b6b3ee..9225dfa4cae0 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -392,7 +392,7 @@ pub(super) fn concat(s: &mut [Column]) -> PolarsResult> { let mut first = std::mem::take(&mut s[0]); let other = &s[1..]; - let mut first_ca = match first.list().ok() { + let mut first_ca = match first.try_list() { Some(ca) => ca, None => { first = first diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs index 5e3e4174f667..db0c4ccfd802 100644 --- a/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs @@ -12,6 +12,7 @@ fn data() -> FixedSizeListArray { Box::new(Field::new("a".into(), values.dtype().clone(), true)), 2, ), + 2, values.boxed(), Some([true, false].into()), ) @@ -87,6 +88,7 @@ fn wrong_size() { Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), 2 ), + 2, values.boxed(), None ) @@ -95,12 +97,13 @@ fn wrong_size() { #[test] fn wrong_len() { - let values = Int32Array::from_slice([10, 20, 0]); + let values = Int32Array::from_slice([10, 20, 0, 0]); assert!(FixedSizeListArray::try_new( ArrowDataType::FixedSizeList( Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), 2 ), + 2, values.boxed(), Some([true, false, false].into()), // it should be 2 ) @@ -109,11 +112,12 @@ fn wrong_len() { #[test] fn wrong_dtype() { - let values = Int32Array::from_slice([10, 20, 0]); + let values = Int32Array::from_slice([10, 20, 0, 0]); assert!(FixedSizeListArray::try_new( ArrowDataType::Binary, + 2, values.boxed(), - Some([true, false, false].into()), // it should be 2 + Some([true, false, false, false].into()), ) .is_err()); } diff --git a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs index 45e19d194a46..075e5179e1ca 100644 --- a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs +++ b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs @@ -27,6 +27,6 @@ fn fixed_size_list() { 3, ); let values = Box::new(Float32Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); - let a = FixedSizeListArray::new(dtype, values, None); + let a = FixedSizeListArray::new(dtype, 2, values, None); assert_eq!(6 * std::mem::size_of::(), estimated_bytes_size(&a)); } diff --git a/py-polars/tests/unit/constructors/test_series.py b/py-polars/tests/unit/constructors/test_series.py index c31a5b48ce68..9c6346bf5395 100644 --- a/py-polars/tests/unit/constructors/test_series.py +++ b/py-polars/tests/unit/constructors/test_series.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from datetime import date, datetime, timedelta from typing import TYPE_CHECKING, Any @@ -9,7 +8,6 @@ import pytest import polars as pl -from polars.exceptions import InvalidOperationError from polars.testing.asserts.series import assert_series_equal if TYPE_CHECKING: @@ -157,11 +155,10 @@ def test_series_init_pandas_timestamp_18127() -> None: def test_series_init_np_2d_zero_zero_shape() -> None: arr = np.array([]).reshape(0, 0) - with pytest.raises( - InvalidOperationError, - match=re.escape("cannot reshape empty array into shape (0, 0)"), - ): - pl.Series(arr) + assert_series_equal( + pl.Series("a", arr), + pl.Series("a", [], pl.Array(pl.Float64, 0)), + ) def test_list_null_constructor_schema() -> None: diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py index 6c4f240803bf..b578266b0c6f 100644 --- a/py-polars/tests/unit/datatypes/test_array.py +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -342,3 +342,44 @@ def test_array_invalid_physical_type_18920() -> None: expected = expected_s.to_frame().with_columns(pl.col.x.list.to_array(2)) assert_frame_equal(df, expected) + + +@pytest.mark.parametrize( + "fn", + [ + "__add__", + "__sub__", + "__mul__", + "__truediv__", + "__mod__", + "__eq__", + "__ne__", + ], +) +def test_zero_width_array(fn: str) -> None: + series_f = getattr(pl.Series, fn) + expr_f = getattr(pl.Expr, fn) + + values = [ + [ + [[]], + [None], + ], + [ + [[], []], + [None, []], + [[], None], + [None, None], + ], + ] + + for vs in values: + for lhs in vs: + for rhs in vs: + a = pl.Series("a", lhs, pl.Array(pl.Int8, 0)) + b = pl.Series("b", rhs, pl.Array(pl.Int8, 0)) + + series_f(a, b) + + df = pl.concat([a.to_frame(), b.to_frame()], how="horizontal") + df.select(c=expr_f(pl.col.a, pl.col.b)) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index b53766ae2c2c..39e515bf31d8 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -876,9 +876,8 @@ def test_parquet_array_dtype_nulls() -> None: ([[1, 2, 3]], pl.Array(pl.Int64, 3)), ([[1, None, 3], None, [1, 2, None]], pl.Array(pl.Int64, 3)), ([[1, 2], None, [None, 3]], pl.Array(pl.Int64, 2)), - # @TODO: Enable when zero-width arrays are enabled - # ([[], [], []], pl.Array(pl.Int64, 0)), - # ([[], None, []], pl.Array(pl.Int64, 0)), + ([[], [], []], pl.Array(pl.Int64, 0)), + ([[], None, []], pl.Array(pl.Int64, 0)), ( [[[1, 5, 2], [42, 13, 37]], [[1, 2, 3], [5, 2, 3]], [[1, 2, 1], [3, 1, 3]]], pl.Array(pl.Array(pl.Int8, 3), 2), @@ -924,7 +923,7 @@ def test_parquet_array_dtype_nulls() -> None: [[]], [[None]], [[[None], None]], - [[[None], []]], + [[[None], [None]]], [[[[None]], [[[1]]]]], [[[[[None]]]]], [[[[[1]]]]], @@ -940,12 +939,6 @@ def test_complex_types(series: list[Any], dtype: pl.DataType) -> None: test_round_trip(df) -@pytest.mark.xfail -def test_placeholder_zero_array() -> None: - # @TODO: if this does not fail anymore please enable the upper test-cases - pl.Series([[]], dtype=pl.Array(pl.Int8, 0)) - - @pytest.mark.write_disk def test_parquet_array_statistics(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) diff --git a/py-polars/tests/unit/operations/test_reshape.py b/py-polars/tests/unit/operations/test_reshape.py index 12ddfed628c5..d6d000769637 100644 --- a/py-polars/tests/unit/operations/test_reshape.py +++ b/py-polars/tests/unit/operations/test_reshape.py @@ -66,7 +66,7 @@ def test_reshape_invalid_zero_dimension() -> None: with pytest.raises( InvalidOperationError, match=re.escape( - f"cannot reshape array into shape containing a zero dimension after the first: {display_shape(shape)}" + f"cannot reshape non-empty array into shape containing a zero dimension: {display_shape(shape)}" ), ): s.reshape(shape) @@ -100,24 +100,14 @@ def test_reshape_empty_valid_1d(shape: tuple[int, ...]) -> None: assert_series_equal(out, s) -@pytest.mark.parametrize("shape", [(0, 1), (1, -1), (-1, 1)]) -def test_reshape_empty_invalid_2d(shape: tuple[int, ...]) -> None: - s = pl.Series("a", [], dtype=pl.Int64) - with pytest.raises( - InvalidOperationError, - match=re.escape( - f"cannot reshape empty array into shape {display_shape(shape)}" - ), - ): - s.reshape(shape) - - @pytest.mark.parametrize("shape", [(1,), (2,)]) def test_reshape_empty_invalid_1d(shape: tuple[int, ...]) -> None: s = pl.Series("a", [], dtype=pl.Int64) with pytest.raises( InvalidOperationError, - match=re.escape(f"cannot reshape empty array into shape ({shape[0]})"), + match=re.escape( + f"cannot reshape empty array into shape without zero dimension: ({shape[0]})" + ), ): s.reshape(shape) @@ -131,3 +121,23 @@ def test_array_ndarray_reshape() -> None: n = n[0] s = s[0] assert (n[0] == s[0].to_numpy()).all() + + +@pytest.mark.parametrize( + "shape", + [ + (0, 1), + (1, 0), + (-1, 10, 20, 10), + (-1, 1, 0), + (10, 1, 0), + (10, 0, 1, 0), + (10, 0, 1), + (42, 2, 3, 4, 0, 2, 3, 4), + (42, 1, 1, 1, 0), + ], +) +def test_reshape_empty(shape: tuple[int, ...]) -> None: + s = pl.Series("a", [], dtype=pl.Int64) + expected_len = max(shape[0], 0) + assert s.reshape(shape).len() == expected_len