From 01217d214aac8e5fd4266b1e04d30ad22cd552a4 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 27 Dec 2023 18:07:41 +0800 Subject: [PATCH] feat(rust): Impl serde for array dtype (#13168) --- crates/polars-arrow/src/array/null.rs | 2 +- .../src/legacy/array/fixed_size_list.rs | 33 ++++++++- crates/polars-arrow/src/legacy/array/list.rs | 44 +----------- crates/polars-arrow/src/legacy/array/mod.rs | 57 ++++++++++++++- .../chunked_array/builder/fixed_size_list.rs | 25 ++++++- crates/polars-core/src/datatypes/_serde.rs | 6 ++ crates/polars-core/src/datatypes/dtype.rs | 2 + crates/polars-core/src/serde/chunked_array.rs | 2 + crates/polars-core/src/serde/series.rs | 34 +++++++++ py-polars/src/series/construction.rs | 37 ++++++---- py-polars/tests/unit/io/test_json.py | 69 ++++++++++++++++++- py-polars/tests/unit/test_serde.py | 14 ++++ 12 files changed, 261 insertions(+), 64 deletions(-) diff --git a/crates/polars-arrow/src/array/null.rs b/crates/polars-arrow/src/array/null.rs index 3768e5e9a0fa..82269e3c2066 100644 --- a/crates/polars-arrow/src/array/null.rs +++ b/crates/polars-arrow/src/array/null.rs @@ -21,7 +21,7 @@ impl NullArray { /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. pub fn try_new(data_type: ArrowDataType, length: usize) -> PolarsResult { if data_type.to_physical_type() != PhysicalType::Null { - polars_bail!(ComputeError: "NullArray can only be initialized with a DataType whose physical type is Boolean"); + polars_bail!(ComputeError: "NullArray can only be initialized with a DataType whose physical type is Null"); } Ok(Self { data_type, length }) diff --git a/crates/polars-arrow/src/legacy/array/fixed_size_list.rs b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs index b9ad4dbfefa0..06c41b75e3e1 100644 --- a/crates/polars-arrow/src/legacy/array/fixed_size_list.rs +++ b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs @@ -1,8 +1,9 @@ use polars_error::PolarsResult; -use crate::array::{ArrayRef, FixedSizeListArray}; +use crate::array::{ArrayRef, FixedSizeListArray, NullArray}; use crate::bitmap::MutableBitmap; use crate::datatypes::ArrowDataType; +use crate::legacy::array::{convert_inner_type, is_nested_null}; use crate::legacy::kernels::concatenate::concatenate_owned_unchecked; #[derive(Default)] @@ -34,6 +35,8 @@ impl AnonymousBuilder { } pub fn push_null(&mut self) { + self.arrays + .push(NullArray::new(ArrowDataType::Null, self.width).boxed()); match &mut self.validity { Some(validity) => validity.push(false), None => self.init_validity(), @@ -48,8 +51,32 @@ impl AnonymousBuilder { } pub fn finish(self, inner_dtype: Option<&ArrowDataType>) -> PolarsResult { - let values = concatenate_owned_unchecked(&self.arrays)?; - let inner_dtype = inner_dtype.unwrap_or_else(|| self.arrays[0].data_type()); + let mut inner_dtype = inner_dtype.unwrap_or_else(|| self.arrays[0].data_type()); + + if is_nested_null(inner_dtype) { + for arr in &self.arrays { + if !is_nested_null(arr.data_type()) { + inner_dtype = arr.data_type(); + break; + } + } + }; + + // convert nested null arrays to the correct dtype. + let arrays = self + .arrays + .iter() + .map(|arr| { + if is_nested_null(arr.data_type()) { + convert_inner_type(&**arr, inner_dtype) + } else { + arr.to_boxed() + } + }) + .collect::>(); + + let values = concatenate_owned_unchecked(&arrays)?; + let data_type = FixedSizeListArray::default_datatype(inner_dtype.clone(), self.width); Ok(FixedSizeListArray::new( data_type, diff --git a/crates/polars-arrow/src/legacy/array/list.rs b/crates/polars-arrow/src/legacy/array/list.rs index 9bd2434513d1..7dcb6e5de3a2 100644 --- a/crates/polars-arrow/src/legacy/array/list.rs +++ b/crates/polars-arrow/src/legacy/array/list.rs @@ -1,9 +1,10 @@ use polars_error::PolarsResult; -use crate::array::{new_null_array, Array, ArrayRef, ListArray, NullArray, StructArray}; +use crate::array::{new_null_array, Array, ArrayRef, ListArray, NullArray}; use crate::bitmap::MutableBitmap; use crate::compute::concatenate; use crate::datatypes::ArrowDataType; +use crate::legacy::array::is_nested_null; use crate::legacy::kernels::concatenate::concatenate_owned_unchecked; use crate::legacy::prelude::*; use crate::offset::Offsets; @@ -162,44 +163,3 @@ impl<'a> AnonymousBuilder<'a> { )) } } - -fn is_nested_null(data_type: &ArrowDataType) -> bool { - match data_type { - ArrowDataType::Null => true, - ArrowDataType::LargeList(field) => is_nested_null(field.data_type()), - ArrowDataType::Struct(fields) => { - fields.iter().all(|field| is_nested_null(field.data_type())) - }, - _ => false, - } -} - -/// Cast null arrays to inner type and ensure that all offsets remain correct -pub fn convert_inner_type(array: &dyn Array, dtype: &ArrowDataType) -> Box { - match dtype { - ArrowDataType::LargeList(field) => { - let array = array.as_any().downcast_ref::().unwrap(); - let inner = array.values(); - let new_values = convert_inner_type(inner.as_ref(), field.data_type()); - let dtype = LargeListArray::default_datatype(new_values.data_type().clone()); - LargeListArray::new( - dtype, - array.offsets().clone(), - new_values, - array.validity().cloned(), - ) - .boxed() - }, - ArrowDataType::Struct(fields) => { - let array = array.as_any().downcast_ref::().unwrap(); - let inner = array.values(); - let new_values = inner - .iter() - .zip(fields) - .map(|(arr, field)| convert_inner_type(arr.as_ref(), field.data_type())) - .collect::>(); - StructArray::new(dtype.clone(), new_values, array.validity().cloned()).boxed() - }, - _ => new_null_array(dtype.clone(), array.len()), - } -} diff --git a/crates/polars-arrow/src/legacy/array/mod.rs b/crates/polars-arrow/src/legacy/array/mod.rs index 5e472ceb194b..594766e89929 100644 --- a/crates/polars-arrow/src/legacy/array/mod.rs +++ b/crates/polars-arrow/src/legacy/array/mod.rs @@ -1,4 +1,7 @@ -use crate::array::{Array, BinaryArray, BooleanArray, ListArray, PrimitiveArray, Utf8Array}; +use crate::array::{ + new_null_array, Array, BinaryArray, BooleanArray, FixedSizeListArray, ListArray, + PrimitiveArray, StructArray, Utf8Array, +}; use crate::bitmap::MutableBitmap; use crate::datatypes::ArrowDataType; use crate::legacy::utils::CustomIterTools; @@ -16,6 +19,8 @@ pub mod utf8; pub use slice::*; +use crate::legacy::prelude::LargeListArray; + macro_rules! iter_to_values { ($iterator:expr, $validity:expr, $offsets:expr, $length_so_far:expr) => {{ $iterator @@ -206,3 +211,53 @@ pub trait PolarsArray: Array { } impl PolarsArray for A {} + +fn is_nested_null(data_type: &ArrowDataType) -> bool { + match data_type { + ArrowDataType::Null => true, + ArrowDataType::LargeList(field) => is_nested_null(field.data_type()), + ArrowDataType::FixedSizeList(field, _) => is_nested_null(field.data_type()), + ArrowDataType::Struct(fields) => { + fields.iter().all(|field| is_nested_null(field.data_type())) + }, + _ => false, + } +} + +/// Cast null arrays to inner type and ensure that all offsets remain correct +pub fn convert_inner_type(array: &dyn Array, dtype: &ArrowDataType) -> Box { + match dtype { + ArrowDataType::LargeList(field) => { + let array = array.as_any().downcast_ref::().unwrap(); + let inner = array.values(); + let new_values = convert_inner_type(inner.as_ref(), field.data_type()); + let dtype = LargeListArray::default_datatype(new_values.data_type().clone()); + LargeListArray::new( + dtype, + array.offsets().clone(), + new_values, + array.validity().cloned(), + ) + .boxed() + }, + ArrowDataType::FixedSizeList(field, width) => { + let array = array.as_any().downcast_ref::().unwrap(); + let inner = array.values(); + let new_values = convert_inner_type(inner.as_ref(), field.data_type()); + let dtype = + FixedSizeListArray::default_datatype(new_values.data_type().clone(), *width); + FixedSizeListArray::new(dtype, new_values, array.validity().cloned()).boxed() + }, + ArrowDataType::Struct(fields) => { + let array = array.as_any().downcast_ref::().unwrap(); + let inner = array.values(); + let new_values = inner + .iter() + .zip(fields) + .map(|(arr, field)| convert_inner_type(arr.as_ref(), field.data_type())) + .collect::>(); + StructArray::new(dtype.clone(), new_values, array.validity().cloned()).boxed() + }, + _ => new_null_array(dtype.clone(), array.len()), + } +} diff --git a/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs index e08fcefb22fb..802dd5e5e1c2 100644 --- a/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs +++ b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs @@ -12,16 +12,25 @@ pub(crate) struct FixedSizeListNumericBuilder { inner: Option>>, width: usize, name: SmartString, + logical_dtype: DataType, } impl FixedSizeListNumericBuilder { - pub(crate) fn new(name: &str, width: usize, capacity: usize) -> Self { + /// SAFETY + /// The caller must ensure that the physical numerical type match logical type. + pub(crate) unsafe fn new( + name: &str, + width: usize, + capacity: usize, + logical_dtype: DataType, + ) -> Self { let mp = MutablePrimitiveArray::::with_capacity(capacity * width); let inner = Some(MutableFixedSizeListArray::new(mp, width)); Self { inner, width, name: name.into(), + logical_dtype, } } } @@ -68,7 +77,14 @@ impl FixedSizeListBuilder for FixedSizeListNumericBuilder { fn finish(&mut self) -> ArrayChunked { let arr: FixedSizeListArray = self.inner.take().unwrap().into(); - ChunkedArray::with_chunk(self.name.as_str(), arr) + // SAFETY: physical type matches the logical + unsafe { + ChunkedArray::from_chunks_and_dtype( + self.name.as_str(), + vec![Box::new(arr)], + DataType::Array(Box::new(self.logical_dtype.clone()), self.width), + ) + } } } @@ -124,7 +140,10 @@ pub(crate) fn get_fixed_size_list_builder( let builder = if phys_dtype.is_numeric() { with_match_physical_numeric_type!(phys_dtype, |$T| { - Box::new(FixedSizeListNumericBuilder::<$T>::new(name, width, capacity)) as Box + // SAFETY: physical type match logical type + unsafe { + Box::new(FixedSizeListNumericBuilder::<$T>::new(name, width, capacity,inner_type_logical.clone())) as Box + } }) } else { Box::new(AnonymousOwnedFixedSizeListBuilder::new( diff --git a/crates/polars-core/src/datatypes/_serde.rs b/crates/polars-core/src/datatypes/_serde.rs index 67a0f4d013f7..657c81dfdb76 100644 --- a/crates/polars-core/src/datatypes/_serde.rs +++ b/crates/polars-core/src/datatypes/_serde.rs @@ -53,6 +53,8 @@ pub enum SerializableDataType { /// A 64-bit time representing elapsed time since midnight in the given TimeUnit. Time, List(Box), + #[cfg(feature = "dtype-array")] + Array(Box, usize), Null, #[cfg(feature = "dtype-struct")] Struct(Vec), @@ -86,6 +88,8 @@ impl From<&DataType> for SerializableDataType { Duration(tu) => Self::Duration(*tu), Time => Self::Time, List(dt) => Self::List(Box::new(dt.as_ref().into())), + #[cfg(feature = "dtype-array")] + Array(dt, width) => Self::Array(Box::new(dt.as_ref().into()), *width), Null => Self::Null, Unknown => Self::Unknown, #[cfg(feature = "dtype-struct")] @@ -120,6 +124,8 @@ impl From for DataType { Duration(tu) => Self::Duration(tu), Time => Self::Time, List(dt) => Self::List(Box::new((*dt).into())), + #[cfg(feature = "dtype-array")] + Array(dt, width) => Self::Array(Box::new((*dt).into()), width), Null => Self::Null, Unknown => Self::Unknown, #[cfg(feature = "dtype-struct")] diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 5590cddbc4a1..ae37b4c74aa3 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -145,6 +145,8 @@ impl DataType { #[cfg(feature = "dtype-categorical")] Categorical(_, _) => UInt32, List(dt) => List(Box::new(dt.to_physical())), + #[cfg(feature = "dtype-array")] + Array(dt, width) => Array(Box::new(dt.to_physical()), *width), #[cfg(feature = "dtype-struct")] Struct(fields) => { let new_fields = fields diff --git a/crates/polars-core/src/serde/chunked_array.rs b/crates/polars-core/src/serde/chunked_array.rs index 5cff96b2cea4..b39166e9dfed 100644 --- a/crates/polars-core/src/serde/chunked_array.rs +++ b/crates/polars-core/src/serde/chunked_array.rs @@ -133,6 +133,8 @@ impl_serialize!(StringChunked); impl_serialize!(BooleanChunked); impl_serialize!(ListChunked); impl_serialize!(BinaryChunked); +#[cfg(feature = "dtype-array")] +impl_serialize!(ArrayChunked); #[cfg(feature = "dtype-categorical")] impl Serialize for CategoricalChunked { diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs index 46d020aed79d..74d69194fd49 100644 --- a/crates/polars-core/src/serde/series.rs +++ b/crates/polars-core/src/serde/series.rs @@ -4,6 +4,8 @@ use std::fmt::Formatter; use serde::de::{MapAccess, Visitor}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +#[cfg(feature = "dtype-array")] +use crate::chunked_array::builder::get_fixed_size_list_builder; use crate::chunked_array::builder::AnonymousListBuilder; use crate::chunked_array::Settings; use crate::prelude::*; @@ -25,6 +27,11 @@ impl Serialize for Series { let ca = self.list().unwrap(); ca.serialize(serializer) }, + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => { + let ca = self.array().unwrap(); + ca.serialize(serializer) + }, DataType::Boolean => { let ca = self.bool().unwrap(); ca.serialize(serializer) @@ -213,6 +220,33 @@ impl<'de> Deserialize<'de> for Series { } Ok(lb.finish().into_series()) }, + #[cfg(feature = "dtype-array")] + DataType::Array(inner, width) => { + let values: Vec> = map.next_value()?; + let mut builder = + get_fixed_size_list_builder(&inner, values.len(), width, &name) + .map_err(|e| { + de::Error::custom(format!( + "could not get supported list builder: {e}" + )) + })?; + for value in &values { + if let Some(s) = value { + // we only have one chunk per series as we serialize it in this way. + let arr = &s.chunks()[0]; + // safety, we are within bounds + unsafe { + builder.push_unchecked(arr.as_ref(), 0); + } + } else { + // safety, we are within bounds + unsafe { + builder.push_null(); + } + } + } + Ok(builder.finish().into_series()) + }, DataType::Binary => { let values: Vec>> = map.next_value()?; Ok(Series::new(&name, values)) diff --git a/py-polars/src/series/construction.rs b/py-polars/src/series/construction.rs index 155c9b7c6b2c..61f99ab7f9dd 100644 --- a/py-polars/src/series/construction.rs +++ b/py-polars/src/series/construction.rs @@ -258,19 +258,30 @@ impl PySeries { Ok(series.into()) } else { let val = vec_extract_wrapped(val); - let series = Series::new(name, &val); - match series.dtype() { - DataType::List(list_inner) => { - let series = series - .cast(&DataType::Array( - Box::new(inner.map(|dt| dt.0).unwrap_or(*list_inner.clone())), - width, - )) - .map_err(PyPolarsErr::from)?; - Ok(series.into()) - }, - _ => Err(PyValueError::new_err("could not create Array from input")), - } + return if let Some(inner) = inner { + let series = Series::from_any_values_and_dtype( + name, + val.as_ref(), + &DataType::Array(Box::new(inner.0), width), + true, + ) + .map_err(PyPolarsErr::from)?; + Ok(series.into()) + } else { + let series = Series::new(name, &val); + match series.dtype() { + DataType::List(list_inner) => { + let series = series + .cast(&DataType::Array( + Box::new(inner.map(|dt| dt.0).unwrap_or(*list_inner.clone())), + width, + )) + .map_err(PyPolarsErr::from)?; + Ok(series.into()) + }, + _ => Err(PyValueError::new_err("could not create Array from input")), + } + }; } } diff --git a/py-polars/tests/unit/io/test_json.py b/py-polars/tests/unit/io/test_json.py index cd564a990aef..5efeec63d20b 100644 --- a/py-polars/tests/unit/io/test_json.py +++ b/py-polars/tests/unit/io/test_json.py @@ -1,10 +1,11 @@ from __future__ import annotations +import datetime import io import json from collections import OrderedDict from io import BytesIO -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest @@ -289,6 +290,72 @@ def test_write_json_duration() -> None: ) +@pytest.mark.parametrize( + ("data", "dtype"), + [ + ([[1, 2, 3], [None, None, None], [1, None, 3]], pl.Array(pl.Int32(), width=3)), + ([["a", "b"], [None, None]], pl.Array(pl.Utf8, width=2)), + ([[True, False, None], [None, None, None]], pl.Array(pl.Utf8, width=3)), + ( + [[[1, 2, 3], [4, None]], None, [[None, None, 2]]], + pl.List(pl.Array(pl.Int32(), width=3)), + ), + ( + [ + [datetime.datetime(1991, 1, 1), datetime.datetime(1991, 1, 1), None], + [None, None, None], + ], + pl.Array(pl.Datetime, width=3), + ), + ], +) +def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None: + df = pl.DataFrame({"foo": data}, schema={"foo": dtype}) + buf = io.StringIO() + df.write_json(buf) + buf.seek(0) + deserialized_df = pl.read_json(buf) + assert_frame_equal(deserialized_df, df) + + +@pytest.mark.parametrize( + ("data", "dtype"), + [ + ( + [ + [ + datetime.datetime(1997, 10, 1), + datetime.datetime(2000, 1, 2, 10, 30, 1), + ], + [None, None], + ], + pl.Array(pl.Datetime, width=2), + ), + ( + [[datetime.date(1997, 10, 1), datetime.date(2000, 1, 1)], [None, None]], + pl.Array(pl.Date, width=2), + ), + ( + [ + [datetime.timedelta(seconds=1), datetime.timedelta(seconds=10)], + [None, None], + ], + pl.Array(pl.Duration, width=2), + ), + ], +) +def test_write_read_json_array_logical_inner_type( + data: Any, dtype: pl.DataType +) -> None: + df = pl.DataFrame({"foo": data}, schema={"foo": dtype}) + buf = io.StringIO() + df.write_json(buf) + buf.seek(0) + deserialized_df = pl.read_json(buf) + assert deserialized_df.dtypes == df.dtypes + assert deserialized_df.to_dict(as_series=False) == df.to_dict(as_series=False) + + def test_json_null_infer() -> None: json = BytesIO( bytes( diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py index 0183a2c272f6..08d7cca15f5a 100644 --- a/py-polars/tests/unit/test_serde.py +++ b/py-polars/tests/unit/test_serde.py @@ -189,3 +189,17 @@ def test_serde_categorical_series_10586() -> None: def test_serde_keep_dtype_empty_list() -> None: s = pl.Series([{"a": None}], dtype=pl.Struct([pl.Field("a", pl.List(pl.Utf8))])) assert s.dtype == pickle.loads(pickle.dumps(s)).dtype + + +def test_serde_array_dtype() -> None: + s = pl.Series( + [[1, 2, 3], [None, None, None], [1, None, 3]], + dtype=pl.Array(pl.Int32(), width=3), + ) + assert_series_equal(pickle.loads(pickle.dumps(s)), s) + + nested_s = pl.Series( + [[[1, 2, 3], [4, None]], None, [[None, None, 2]]], + dtype=pl.List(pl.Array(pl.Int32(), width=3)), + ) + assert_series_equal(pickle.loads(pickle.dumps(nested_s)), nested_s)