Skip to content

Commit

Permalink
fix: Serialize categories of Enum in arrow metadata (#20181)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 6, 2024
1 parent a1bfeef commit abbad69
Show file tree
Hide file tree
Showing 34 changed files with 188 additions and 96 deletions.
19 changes: 17 additions & 2 deletions crates/polars-arrow/src/datatypes/field.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::sync::Arc;

use polars_utils::pl_str::PlSmallStr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use super::{ArrowDataType, Metadata};

/// Metadata key whose presence marks a field as an Enum (see `Field::is_enum`).
///
/// The associated value carries the serialized categories of the Enum.
pub static DTYPE_ENUM_VALUES: &str = "_PL_ENUM_VALUES";

/// Represents Arrow's metadata of a "column".
///
/// A [`Field`] is the closest representation of the traditional "column": a logical type
Expand All @@ -22,7 +26,7 @@ pub struct Field {
/// Its nullability
pub is_nullable: bool,
/// Additional custom (opaque) metadata.
pub metadata: Metadata,
pub metadata: Option<Arc<Metadata>>,
}

/// Support for `ArrowSchema::from_iter([field, ..])`
Expand All @@ -46,11 +50,14 @@ impl Field {
/// Creates a new [`Field`] with metadata.
#[inline]
pub fn with_metadata(self, metadata: Metadata) -> Self {
if metadata.is_empty() {
return self;
}
Self {
name: self.name,
dtype: self.dtype,
is_nullable: self.is_nullable,
metadata,
metadata: Some(Arc::new(metadata)),
}
}

Expand All @@ -59,4 +66,12 @@ impl Field {
pub fn dtype(&self) -> &ArrowDataType {
&self.dtype
}

/// Returns `true` when this field carries metadata with a [`DTYPE_ENUM_VALUES`] entry.
///
/// A field without any metadata is never an Enum.
pub fn is_enum(&self) -> bool {
    match self.metadata.as_deref() {
        Some(entries) => entries.contains_key(DTYPE_ENUM_VALUES),
        None => false,
    }
}
}
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod schema;
use std::collections::BTreeMap;
use std::sync::Arc;

pub use field::Field;
pub use field::{Field, DTYPE_ENUM_VALUES};
pub use physical_type::*;
use polars_utils::pl_str::PlSmallStr;
pub use schema::{ArrowSchema, ArrowSchemaRef};
Expand Down
8 changes: 6 additions & 2 deletions crates/polars-arrow/src/ffi/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ impl ArrowSchema {
None
};

let metadata = &field.metadata;
let metadata = field
.metadata
.as_ref()
.map(|inner| (**inner).clone())
.unwrap_or_default();

let metadata = if let ArrowDataType::Extension(name, _, extension_metadata) = field.dtype()
{
Expand All @@ -102,7 +106,7 @@ impl ArrowSchema {

Some(metadata_to_bytes(&metadata))
} else if !metadata.is_empty() {
Some(metadata_to_bytes(metadata))
Some(metadata_to_bytes(&metadata))
} else {
None
};
Expand Down
12 changes: 7 additions & 5 deletions crates/polars-arrow/src/io/ipc/read/schema.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use arrow_format::ipc::planus::ReadAsRoot;
use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef};
use polars_error::{polars_bail, polars_err, PolarsResult};
Expand Down Expand Up @@ -27,7 +29,7 @@ fn try_unzip_vec<A, B, I: Iterator<Item = PolarsResult<(A, B)>>>(
fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Field, IpcField)> {
let metadata = read_metadata(&ipc_field)?;

let extension = get_extension(&metadata);
let extension = metadata.as_ref().and_then(get_extension);

let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?;

Expand All @@ -39,13 +41,13 @@ fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Fi
),
dtype,
is_nullable: ipc_field.nullable()?,
metadata,
metadata: metadata.map(Arc::new),
};

Ok((field, ipc_field_))
}

fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Metadata> {
fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Option<Metadata>> {
Ok(if let Some(list) = field.custom_metadata()? {
let mut metadata_map = Metadata::new();
for kv in list {
Expand All @@ -54,9 +56,9 @@ fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Metadata>
metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v));
}
}
metadata_map
Some(metadata_map)
} else {
Metadata::default()
None
})
}

Expand Down
4 changes: 3 additions & 1 deletion crates/polars-arrow/src/io/ipc/write/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ pub(crate) fn serialize_field(field: &Field, ipc_field: &IpcField) -> arrow_form
None
};

write_metadata(&field.metadata, &mut kv_vec);
if let Some(metadata) = &field.metadata {
write_metadata(metadata, &mut kv_vec);
}

let custom_metadata = if !kv_vec.is_empty() {
Some(kv_vec)
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl ArrayChunked {
self.name().clone(),
vec![(*arr.values()).clone()],
&field.dtype,
Some(&field.metadata),
field.metadata.as_deref(),
)
.unwrap()
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl<'a> AnonymousListBuilder<'a> {
let arr = slf.builder.finish(inner_dtype_physical.as_ref()).unwrap();

let list_dtype_logical = match inner_dtype {
None => DataType::from(arr.dtype()),
None => DataType::from_arrow_dtype(arr.dtype()),
Some(dt) => DataType::List(Box::new(dt)),
};

Expand Down Expand Up @@ -147,7 +147,7 @@ impl ListBuilderTrait for AnonymousOwnedListBuilder {
let arr = slf.builder.finish(inner_dtype_physical.as_ref()).unwrap();

let list_dtype_logical = match inner_dtype {
None => DataType::from_arrow(arr.dtype(), false),
None => DataType::from_arrow_dtype(arr.dtype()),
Some(dt) => DataType::List(Box::new(dt)),
};

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/builder/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl<T: ViewType + ?Sized> BinViewChunkedBuilder<T> {
pub fn new(name: PlSmallStr, capacity: usize) -> Self {
Self {
chunk_builder: MutableBinaryViewArray::with_capacity(capacity),
field: Arc::new(Field::new(name, DataType::from(&T::DATA_TYPE))),
field: Arc::new(Field::new(name, DataType::from_arrow_dtype(&T::DATA_TYPE))),
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::*;
fn from_chunks_list_dtype(chunks: &mut Vec<ArrayRef>, dtype: DataType) -> DataType {
// ensure we don't get List<null>
let dtype = if let Some(arr) = chunks.get(0) {
arr.dtype().into()
DataType::from_arrow_dtype(arr.dtype())
} else {
dtype
};
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ pub(crate) mod test {
fn cast() {
let a = get_chunked_array();
let b = a.cast(&DataType::Int64).unwrap();
assert_eq!(b.dtype(), &ArrowDataType::Int64)
assert_eq!(b.dtype(), &DataType::Int64)
}

fn assert_slice_equal<T>(ca: &ChunkedArray<T>, eq: &[T::Native])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use polars_utils::format_pl_smallstr;
use crate::prelude::*;
use crate::PROCESS_ID;

pub const EXTENSION_NAME: &str = "POLARS_EXTENSION_TYPE";
static POLARS_ALLOW_EXTENSION: AtomicBool = AtomicBool::new(false);

/// Control whether extension types may be created.
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1779,7 +1779,7 @@ mod test {
];

for (dt_a, dt_p) in dtypes {
let dt: DataType = (&dt_a).into();
let dt = DataType::from_arrow_dtype(&dt_a);

assert_eq!(dt_p, dt);
}
Expand Down
58 changes: 43 additions & 15 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::BTreeMap;

use arrow::datatypes::{Metadata, DTYPE_ENUM_VALUES};
#[cfg(feature = "dtype-array")]
use polars_utils::format_tuple;
use polars_utils::itertools::Itertools;
Expand All @@ -11,8 +12,32 @@ use crate::utils::materialize_dyn_int;

pub type TimeZone = PlSmallStr;

pub static DTYPE_ENUM_KEY: &str = "POLARS.CATEGORICAL_TYPE";
pub static DTYPE_ENUM_VALUE: &str = "ENUM";
/// Metadata value that, when stored under `PL_KEY`, makes `MetaDataExt::maintain_type` return `true`.
static MAINTAIN_PL_TYPE: &str = "maintain_type";
/// Metadata key looked up by `MetaDataExt::maintain_type`.
static PL_KEY: &str = "pl";

/// Polars-specific queries over Arrow field [`Metadata`].
///
/// Both methods are provided as defaults; implementors only need [`IntoMetadata`].
pub trait MetaDataExt: IntoMetadata {
    /// Returns `true` when the metadata has an entry under `DTYPE_ENUM_VALUES`.
    fn is_enum(&self) -> bool {
        self.into_metadata_ref().contains_key(DTYPE_ENUM_VALUES)
    }

    /// Returns `true` when the entry under `PL_KEY` equals `MAINTAIN_PL_TYPE`.
    fn maintain_type(&self) -> bool {
        matches!(
            self.into_metadata_ref().get(PL_KEY),
            Some(value) if value.as_str() == MAINTAIN_PL_TYPE
        )
    }
}

// Empty on purpose: every `MetaDataExt` method has a default body in the trait.
impl MetaDataExt for Metadata {}
/// Gives borrowed access to a [`Metadata`] map; the supertrait that `MetaDataExt` builds on.
pub trait IntoMetadata {
/// Returns a reference to the metadata held by `self`.
// The method borrows `self` instead of consuming it, which the `into_*` naming lint
// would otherwise flag.
#[allow(clippy::wrong_self_convention)]
fn into_metadata_ref(&self) -> &Metadata;
}

/// A [`Metadata`] map is its own metadata.
impl IntoMetadata for Metadata {
// Identity conversion: hands back `self` unchanged.
fn into_metadata_ref(&self) -> &Metadata {
self
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
#[cfg_attr(
Expand Down Expand Up @@ -87,6 +112,7 @@ pub enum DataType {
// This is ignored with comparisons, hashing etc.
#[cfg(feature = "dtype-categorical")]
Categorical(Option<Arc<RevMapping>>, CategoricalOrdering),
// It is an Option, so that matching Enum/Categoricals can take the same guards.
#[cfg(feature = "dtype-categorical")]
Enum(Option<Arc<RevMapping>>, CategoricalOrdering),
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -556,13 +582,22 @@ impl DataType {
pub fn to_arrow_field(&self, name: PlSmallStr, compat_level: CompatLevel) -> ArrowField {
let metadata = match self {
#[cfg(feature = "dtype-categorical")]
DataType::Enum(_, _) => Some(BTreeMap::from([(
DTYPE_ENUM_KEY.into(),
DTYPE_ENUM_VALUE.into(),
)])),
DataType::Enum(Some(revmap), _) => {
let cats = revmap.get_categories();
let mut encoded = String::with_capacity(cats.len() * 10);
for cat in cats.values_iter() {
encoded.push_str(itoa::Buffer::new().format(cat.len()));
encoded.push(';');
encoded.push_str(cat);
}
Some(BTreeMap::from([(
PlSmallStr::from_static(DTYPE_ENUM_VALUES),
PlSmallStr::from_string(encoded),
)]))
},
DataType::BinaryOffset => Some(BTreeMap::from([(
PlSmallStr::from_static("pl"),
PlSmallStr::from_static("maintain_type"),
PlSmallStr::from_static(PL_KEY),
PlSmallStr::from_static(MAINTAIN_PL_TYPE),
)])),
_ => None,
};
Expand Down Expand Up @@ -774,13 +809,6 @@ impl DataType {
}
}

impl PartialEq<ArrowDataType> for DataType {
fn eq(&self, other: &ArrowDataType) -> bool {
let dt: DataType = other.into();
self == &dt
}
}

impl Display for DataType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let s = match self {
Expand Down
53 changes: 41 additions & 12 deletions crates/polars-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use arrow::datatypes::{Metadata, DTYPE_ENUM_VALUES};
use polars_utils::pl_str::PlSmallStr;

use super::*;
/// Name of the Arrow extension type that `DataType::from_arrow` matches `ArrowDataType::Extension` against.
pub static EXTENSION_NAME: &str = "POLARS_EXTENSION_TYPE";

/// Characterizes the name and the [`DataType`] of a column.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -130,7 +132,15 @@ impl DataType {
Box::new(self)
}

pub fn from_arrow(dt: &ArrowDataType, bin_to_view: bool) -> DataType {
/// Converts an Arrow field to a [`DataType`], passing the field's metadata (if any) along.
///
/// Delegates to [`DataType::from_arrow`] with `bin_to_view = true`.
pub fn from_arrow_field(field: &ArrowField) -> DataType {
Self::from_arrow(&field.dtype, true, field.metadata.as_deref())
}

/// Converts a bare Arrow dtype to a [`DataType`] without any field metadata.
///
/// Delegates to [`DataType::from_arrow`] with `bin_to_view = true` and `md = None`, so a
/// top-level dictionary dtype is never recognised as an Enum here.
pub fn from_arrow_dtype(dt: &ArrowDataType) -> DataType {
Self::from_arrow(dt, true, None)
}

pub fn from_arrow(dt: &ArrowDataType, bin_to_view: bool, md: Option<&Metadata>) -> DataType {
match dt {
ArrowDataType::Null => DataType::Null,
ArrowDataType::UInt8 => DataType::UInt8,
Expand All @@ -145,15 +155,40 @@ impl DataType {
ArrowDataType::Float32 => DataType::Float32,
ArrowDataType::Float64 => DataType::Float64,
#[cfg(feature = "dtype-array")]
ArrowDataType::FixedSizeList(f, size) => DataType::Array(DataType::from_arrow(f.dtype(), bin_to_view).boxed(), *size),
ArrowDataType::LargeList(f) | ArrowDataType::List(f) => DataType::List(DataType::from_arrow(f.dtype(), bin_to_view).boxed()),
ArrowDataType::FixedSizeList(f, size) => DataType::Array(DataType::from_arrow_field(f).boxed(), *size),
ArrowDataType::LargeList(f) | ArrowDataType::List(f) => DataType::List(DataType::from_arrow_field(f).boxed()),
ArrowDataType::Date32 => DataType::Date,
ArrowDataType::Timestamp(tu, tz) => DataType::Datetime(tu.into(), DataType::canonical_timezone(tz)),
ArrowDataType::Duration(tu) => DataType::Duration(tu.into()),
ArrowDataType::Date64 => DataType::Datetime(TimeUnit::Milliseconds, None),
ArrowDataType::Time64(_) | ArrowDataType::Time32(_) => DataType::Time,
#[cfg(feature = "dtype-categorical")]
ArrowDataType::Dictionary(_, _, _) => DataType::Categorical(None,Default::default()),
ArrowDataType::Dictionary(_, _, _) => {
if md.map(|md| md.is_enum()).unwrap_or(false) {
let md = md.unwrap();
let encoded = md.get(DTYPE_ENUM_VALUES).unwrap();
let mut encoded = encoded.as_str();
let mut cats = MutableBinaryViewArray::<str>::new();

// Data is encoded as <len in ascii><sep ';'><payload>
// We know thus that len is only [0-9] and the first ';' doesn't belong to the
// payload.
while let Some(pos) = encoded.find(';') {
let (len, remainder) = encoded.split_at(pos);
// Split off ';'
encoded = &remainder[1..];
let len = len.parse::<usize>().unwrap();

let (value, remainder) = encoded.split_at(len);
cats.push_value(value);
encoded = remainder;
}
DataType::Enum(Some(Arc::new(RevMapping::build_local(cats.into()))), Default::default())

} else {
DataType::Categorical(None,Default::default())
}
},
#[cfg(feature = "dtype-struct")]
ArrowDataType::Struct(fields) => {
DataType::Struct(fields.iter().map(|fld| fld.into()).collect())
Expand All @@ -162,7 +197,7 @@ impl DataType {
ArrowDataType::Struct(_) => {
panic!("activate the 'dtype-struct' feature to handle struct data types")
}
ArrowDataType::Extension(name, _, _) if name.as_str() == "POLARS_EXTENSION_TYPE" => {
ArrowDataType::Extension(name, _, _) if name.as_str() == EXTENSION_NAME => {
#[cfg(feature = "object")]
{
DataType::Object("object", None)
Expand Down Expand Up @@ -190,14 +225,8 @@ impl DataType {
}
}

impl From<&ArrowDataType> for DataType {
fn from(dt: &ArrowDataType) -> Self {
Self::from_arrow(dt, true)
}
}

impl From<&ArrowField> for Field {
fn from(f: &ArrowField) -> Self {
Field::new(f.name.clone(), f.dtype().into())
Field::new(f.name.clone(), DataType::from_arrow_field(f))
}
}
Loading

0 comments on commit abbad69

Please sign in to comment.