Skip to content

Commit

Permalink
feat: Add dynamic literals to ensure schema correctness (#15832)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Apr 24, 2024
1 parent 3f58c8f commit 71495eb
Show file tree
Hide file tree
Showing 44 changed files with 1,038 additions and 804 deletions.
6 changes: 3 additions & 3 deletions crates/polars-core/src/datatypes/_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ enum SerializableDataType {
#[cfg(feature = "dtype-struct")]
Struct(Vec<Field>),
// some logical types we cannot know statically, e.g. Datetime
Unknown,
Unknown(UnknownKind),
#[cfg(feature = "dtype-categorical")]
Categorical(Option<Wrap<Utf8ViewArray>>, CategoricalOrdering),
#[cfg(feature = "dtype-decimal")]
Expand Down Expand Up @@ -141,7 +141,7 @@ impl From<&DataType> for SerializableDataType {
#[cfg(feature = "dtype-array")]
Array(dt, width) => Self::Array(Box::new(dt.as_ref().into()), *width),
Null => Self::Null,
Unknown => Self::Unknown,
Unknown(kind) => Self::Unknown(*kind),
#[cfg(feature = "dtype-struct")]
Struct(flds) => Self::Struct(flds.clone()),
#[cfg(feature = "dtype-categorical")]
Expand Down Expand Up @@ -185,7 +185,7 @@ impl From<SerializableDataType> for DataType {
#[cfg(feature = "dtype-array")]
Array(dt, width) => Self::Array(Box::new((*dt).into()), width),
Null => Self::Null,
Unknown => Self::Unknown,
Unknown(kind) => Self::Unknown(kind),
#[cfg(feature = "dtype-struct")]
Struct(flds) => Self::Struct(flds),
#[cfg(feature = "dtype-categorical")]
Expand Down
72 changes: 65 additions & 7 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,22 @@ pub type TimeZone = String;
pub static DTYPE_ENUM_KEY: &str = "POLARS.CATEGORICAL_TYPE";
pub static DTYPE_ENUM_VALUE: &str = "ENUM";

#[derive(Clone, Debug, Default)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
#[cfg_attr(
any(feature = "serde", feature = "serde-lazy"),
derive(Serialize, Deserialize)
)]
pub enum UnknownKind {
// Hold the value to determine the concrete size.
Int(i128),
Float,
// Can be Categorical or String
Str,
#[default]
Any,
}

#[derive(Clone, Debug)]
pub enum DataType {
Boolean,
UInt8,
Expand Down Expand Up @@ -59,8 +74,13 @@ pub enum DataType {
#[cfg(feature = "dtype-struct")]
Struct(Vec<Field>),
// some logical types we cannot know statically, e.g. Datetime
#[default]
Unknown,
Unknown(UnknownKind),
}

impl Default for DataType {
fn default() -> Self {
DataType::Unknown(UnknownKind::Any)
}
}

pub trait AsRefDataType {
Expand Down Expand Up @@ -144,7 +164,7 @@ impl DataType {
DataType::List(inner) => inner.is_known(),
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => fields.iter().all(|fld| fld.dtype.is_known()),
DataType::Unknown => false,
DataType::Unknown(_) => false,
_ => true,
}
}
Expand Down Expand Up @@ -208,7 +228,14 @@ impl DataType {

/// Check if this [`DataType`] is a basic numeric type (excludes Decimal).
pub fn is_numeric(&self) -> bool {
self.is_float() || self.is_integer()
self.is_float() || self.is_integer() || self.is_dynamic()
}

pub fn is_dynamic(&self) -> bool {
matches!(
self,
DataType::Unknown(UnknownKind::Int(_) | UnknownKind::Float | UnknownKind::Str)
)
}

/// Check if this [`DataType`] is a boolean
Expand Down Expand Up @@ -382,6 +409,32 @@ impl DataType {
}
}

pub fn is_string(&self) -> bool {
matches!(self, DataType::String | DataType::Unknown(UnknownKind::Str))
}

pub fn is_categorical(&self) -> bool {
#[cfg(feature = "dtype-categorical")]
{
matches!(self, DataType::Categorical(_, _))
}
#[cfg(not(feature = "dtype-categorical"))]
{
false
}
}

pub fn is_enum(&self) -> bool {
#[cfg(feature = "dtype-categorical")]
{
matches!(self, DataType::Enum(_, _))
}
#[cfg(not(feature = "dtype-categorical"))]
{
false
}
}

/// Convert to an Arrow Field
pub fn to_arrow_field(&self, name: &str, pl_flavor: bool) -> ArrowField {
let metadata = match self {
Expand Down Expand Up @@ -490,7 +543,7 @@ impl DataType {
Ok(ArrowDataType::Struct(fields))
},
BinaryOffset => Ok(ArrowDataType::LargeBinary),
Unknown => Ok(ArrowDataType::Unknown),
Unknown(_) => Ok(ArrowDataType::Unknown),
}
}

Expand Down Expand Up @@ -591,7 +644,12 @@ impl Display for DataType {
DataType::Enum(_, _) => "enum",
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => return write!(f, "struct[{}]", fields.len()),
DataType::Unknown => "unknown",
DataType::Unknown(kind) => match kind {
UnknownKind::Any => "unknown",
UnknownKind::Int(_) => "dyn int",
UnknownKind::Float => "dyn float",
UnknownKind::Str => "dyn str",
},
DataType::BinaryOffset => "binary[offset]",
};
f.write_str(s)
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ impl Field {
}
}

impl AsRef<DataType> for Field {
fn as_ref(&self) -> &DataType {
&self.dtype
}
}

impl AsRef<DataType> for DataType {
fn as_ref(&self) -> &DataType {
self
}
}

impl DataType {
pub fn boxed(self) -> Box<DataType> {
Box::new(self)
Expand Down
23 changes: 15 additions & 8 deletions crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ macro_rules! impl_polars_num_datatype {
};
}

macro_rules! impl_polars_datatype {
($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty) => {
macro_rules! impl_polars_datatype2 {
($ca:ident, $dtype:expr, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty) => {
#[derive(Clone, Copy)]
pub struct $ca {}

Expand All @@ -128,12 +128,18 @@ macro_rules! impl_polars_datatype {

#[inline]
fn get_dtype() -> DataType {
DataType::$variant
$dtype
}
}
};
}

macro_rules! impl_polars_datatype {
($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty) => {
impl_polars_datatype2!($ca, DataType::$variant, $arr, $lt, $phys, $zerophys);
};
}

impl_polars_num_datatype!(PolarsIntegerType, UInt8Type, UInt8, u8);
impl_polars_num_datatype!(PolarsIntegerType, UInt16Type, UInt16, u16);
impl_polars_num_datatype!(PolarsIntegerType, UInt32Type, UInt32, u32);
Expand All @@ -145,17 +151,18 @@ impl_polars_num_datatype!(PolarsIntegerType, Int64Type, Int64, i64);
impl_polars_num_datatype!(PolarsFloatType, Float32Type, Float32, f32);
impl_polars_num_datatype!(PolarsFloatType, Float64Type, Float64, f64);
impl_polars_datatype!(DateType, Date, PrimitiveArray<i32>, 'a, i32, i32);
#[cfg(feature = "dtype-decimal")]
impl_polars_datatype!(DecimalType, Unknown, PrimitiveArray<i128>, 'a, i128, i128);
impl_polars_datatype!(DatetimeType, Unknown, PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype!(DurationType, Unknown, PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype!(CategoricalType, Unknown, PrimitiveArray<u32>, 'a, u32, u32);
impl_polars_datatype!(TimeType, Time, PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype!(StringType, String, Utf8ViewArray, 'a, &'a str, Option<&'a str>);
impl_polars_datatype!(BinaryType, Binary, BinaryViewArray, 'a, &'a [u8], Option<&'a [u8]>);
impl_polars_datatype!(BinaryOffsetType, BinaryOffset, BinaryArray<i64>, 'a, &'a [u8], Option<&'a [u8]>);
impl_polars_datatype!(BooleanType, Boolean, BooleanArray, 'a, bool, bool);

#[cfg(feature = "dtype-decimal")]
impl_polars_datatype2!(DecimalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i128>, 'a, i128, i128);
impl_polars_datatype2!(DatetimeType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype2!(DurationType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype2!(CategoricalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<u32>, 'a, u32, u32);

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ListType {}
unsafe impl PolarsDataType for ListType {
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ impl Series {
}
},
Null => new_null(name, &chunks),
Unknown => panic!("uh oh, somehow we don't know the dtype?"),
Unknown(_) => {
panic!("dtype is unknown; consider supplying data-types for all operations")
},
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
Expand Down
87 changes: 84 additions & 3 deletions crates/polars-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use num_traits::Signed;

use super::*;

/// Given two data types, determine the data type that both types can safely be cast to.
Expand Down Expand Up @@ -195,9 +197,9 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
(Time, Float64) => Some(Float64),

// every known type can be casted to a string except binary
(dt, String) if dt != &DataType::Unknown && dt != &DataType::Binary => Some(String),
(dt, String) if !matches!(dt, DataType::Unknown(UnknownKind::Any)) && dt != &DataType::Binary => Some(String),

(dt, String) if dt != &DataType::Unknown => Some(String),
(dt, String) if !matches!(dt, DataType::Unknown(UnknownKind::Any)) => Some(String),

(dt, Null) => Some(dt.clone()),

Expand Down Expand Up @@ -253,7 +255,35 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
let st = get_supertype(inner_left, inner_right)?;
Some(DataType::List(Box::new(st)))
}
(_, Unknown) => Some(Unknown),
#[cfg(feature = "dtype-struct")]
(Struct(inner), right @ Unknown(UnknownKind::Float | UnknownKind::Int(_))) => {
match inner.first() {
Some(inner) => get_supertype(&inner.dtype, right),
None => None
}
},
(dt, Unknown(kind)) => {
match kind {
UnknownKind::Float | UnknownKind::Int(_) if dt.is_float() | dt.is_string() => Some(dt.clone()),
UnknownKind::Float if dt.is_numeric() => Some(Unknown(UnknownKind::Float)),
UnknownKind::Str if dt.is_string() | dt.is_enum() => Some(dt.clone()),
#[cfg(feature = "dtype-categorical")]
UnknownKind::Str if dt.is_categorical() => {
let Categorical(_, ord) = dt else { unreachable!()};
Some(Categorical(None, *ord))
},
dynam if dt.is_null() => Some(Unknown(*dynam)),
UnknownKind::Int(v) if dt.is_numeric() => {
let smallest_fitting_dtype = if dt.is_unsigned_integer() && v.is_positive() {
materialize_dyn_int_pos(*v).dtype()
} else {
materialize_smallest_dyn_int(*v).dtype()
};
get_supertype(dt, &smallest_fitting_dtype)
}
_ => Some(Unknown(UnknownKind::Any))
}
},
#[cfg(feature = "dtype-struct")]
(Struct(fields_a), Struct(fields_b)) => {
super_type_structs(fields_a, fields_b)
Expand Down Expand Up @@ -341,3 +371,54 @@ fn super_type_structs(fields_a: &[Field], fields_b: &[Field]) -> Option<DataType
Some(DataType::Struct(new_fields))
}
}

pub fn materialize_dyn_int(v: i128) -> AnyValue<'static> {
// Try to get the "smallest" fitting value.
// TODO! next breaking go to true smallest.
match i32::try_from(v).ok() {
Some(v) => AnyValue::Int32(v),
None => match i64::try_from(v).ok() {
Some(v) => AnyValue::Int64(v),
None => match u64::try_from(v).ok() {
Some(v) => AnyValue::UInt64(v),
None => AnyValue::Null,
},
},
}
}
fn materialize_dyn_int_pos(v: i128) -> AnyValue<'static> {
// Try to get the "smallest" fitting value.
// TODO! next breaking go to true smallest.
match u8::try_from(v).ok() {
Some(v) => AnyValue::UInt8(v),
None => match u16::try_from(v).ok() {
Some(v) => AnyValue::UInt16(v),
None => match u32::try_from(v).ok() {
Some(v) => AnyValue::UInt32(v),
None => match u64::try_from(v).ok() {
Some(v) => AnyValue::UInt64(v),
None => AnyValue::Null,
},
},
},
}
}

fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> {
match i8::try_from(v).ok() {
Some(v) => AnyValue::Int8(v),
None => match i16::try_from(v).ok() {
Some(v) => AnyValue::Int16(v),
None => match i32::try_from(v).ok() {
Some(v) => AnyValue::Int32(v),
None => match i64::try_from(v).ok() {
Some(v) => AnyValue::Int64(v),
None => match u64::try_from(v).ok() {
Some(v) => AnyValue::UInt64(v),
None => AnyValue::Null,
},
},
},
},
}
}
6 changes: 6 additions & 0 deletions crates/polars-lazy/src/physical_plan/expressions/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ impl PhysicalExpr for LiteralExpr {
.into_time()
.into_series(),
Series(series) => series.deref().clone(),
lv @ (Int(_) | Float(_) | StrCat(_)) => polars_core::prelude::Series::from_any_values(
LITERAL_NAME,
&[lv.to_any_value().unwrap()],
false,
)
.unwrap(),
};
Ok(s)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ pub(crate) fn insert_streaming_nodes(
.iter()
.all(|fld| allowed_dtype(fld.data_type(), string_cache)),
// We need to be able to sink to disk or produce the aggregate return dtype.
DataType::Unknown => false,
DataType::Unknown(_) => false,
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(_, _) => false,
_ => true,
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/tests/optimization_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,12 @@ pub fn test_predicate_block_cast() -> PolarsResult<()> {
let lf1 = df
.clone()
.lazy()
.with_column(col("value").cast(DataType::Int16) * lit(0.1f32))
.with_column(col("value").cast(DataType::Int16) * lit(0.1).cast(DataType::Float32))
.filter(col("value").lt(lit(2.5f32)));

let lf2 = df
.lazy()
.select([col("value").cast(DataType::Int16) * lit(0.1f32)])
.select([col("value").cast(DataType::Int16) * lit(0.1).cast(DataType::Float32)])
.filter(col("value").lt(lit(2.5f32)));

for lf in [lf1, lf2] {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ fn test_simplify_expr() {

let plan = df
.lazy()
.select(&[lit(1.0f32) + lit(1.0f32) + col("sepal_width")])
.select(&[lit(1.0) + lit(1.0) + col("sepal_width")])
.logical_plan;

let mut expr_arena = Arena::new();
Expand All @@ -564,7 +564,7 @@ fn test_simplify_expr() {
.unwrap();
let plan = node_to_lp(lp_top, &expr_arena, &mut lp_arena);
assert!(
matches!(plan, DslPlan::Select{ expr, ..} if matches!(&expr[0], Expr::BinaryExpr{left, ..} if **left == Expr::Literal(LiteralValue::Float32(2.0))))
matches!(plan, DslPlan::Select{ expr, ..} if matches!(&expr[0], Expr::BinaryExpr{left, ..} if **left == Expr::Literal(LiteralValue::Float(2.0))))
);
}

Expand Down
Loading

0 comments on commit 71495eb

Please sign in to comment.