Skip to content

Commit

Permalink
refactor(rust): Fix and extend AnyValue comparison (#18534)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Sep 4, 2024
1 parent f39f1c7 commit b89e772
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 42 deletions.
252 changes: 211 additions & 41 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use arrow::legacy::trusted_len::TrustedLenPush;
use arrow::types::PrimitiveType;
use polars_utils::format_pl_smallstr;
use polars_utils::itertools::Itertools;
#[cfg(feature = "dtype-struct")]
use polars_utils::slice::GetSaferUnchecked;
#[cfg(feature = "dtype-categorical")]
Expand Down Expand Up @@ -66,27 +67,26 @@ pub enum AnyValue<'a> {
/// A 64-bit time representing the elapsed time since midnight in nanoseconds
#[cfg(feature = "dtype-time")]
Time(i64),
#[cfg(feature = "dtype-categorical")]
// If syncptr is_null the data is in the rev-map
// otherwise it is in the array pointer
#[cfg(feature = "dtype-categorical")]
Categorical(u32, &'a RevMapping, SyncPtr<Utf8ViewArray>),
#[cfg(feature = "dtype-categorical")]
Enum(u32, &'a RevMapping, SyncPtr<Utf8ViewArray>),
/// Nested type, contains arrays that are filled with one of the datatypes.
List(Series),
#[cfg(feature = "dtype-array")]
Array(Series, usize),
#[cfg(feature = "object")]
/// Can be used to fmt and implements Any, so can be downcasted to the proper value type.
#[cfg(feature = "object")]
Object(&'a dyn PolarsObjectSafe),
#[cfg(feature = "object")]
ObjectOwned(OwnedObject),
#[cfg(feature = "dtype-struct")]
// 3 pointers and thus not larger than string/vec
// - idx in the `&StructArray`
// - The array itself
// - The fields
#[cfg(feature = "dtype-struct")]
Struct(usize, &'a StructArray, &'a [Field]),
#[cfg(feature = "dtype-struct")]
StructOwned(Box<(Vec<AnyValue<'a>>, Vec<Field>)>),
Expand Down Expand Up @@ -940,6 +940,23 @@ impl AnyValue<'_> {
pub fn eq_missing(&self, other: &Self, null_equal: bool) -> bool {
use AnyValue::*;
match (self, other) {
// Map to borrowed.
(StringOwned(l), r) => AnyValue::String(l.as_str()) == *r,
(BinaryOwned(l), r) => AnyValue::Binary(l.as_slice()) == *r,
#[cfg(feature = "object")]
(ObjectOwned(l), r) => AnyValue::Object(&*l.0) == *r,
(l, StringOwned(r)) => *l == AnyValue::String(r.as_str()),
(l, BinaryOwned(r)) => *l == AnyValue::Binary(r.as_slice()),
#[cfg(feature = "object")]
(l, ObjectOwned(r)) => *l == AnyValue::Object(&*r.0),

// Comparison with null.
(Null, Null) => null_equal,
(Null, _) => false,
(_, Null) => false,

// Equality between equal types.
(Boolean(l), Boolean(r)) => *l == *r,
(UInt8(l), UInt8(r)) => *l == *r,
(UInt16(l), UInt16(r)) => *l == *r,
(UInt32(l), UInt32(r)) => *l == *r,
Expand All @@ -951,15 +968,7 @@ impl AnyValue<'_> {
(Float32(l), Float32(r)) => l.to_total_ord() == r.to_total_ord(),
(Float64(l), Float64(r)) => l.to_total_ord() == r.to_total_ord(),
(String(l), String(r)) => l == r,
(String(l), StringOwned(r)) => *l == r.as_str(),
(StringOwned(l), String(r)) => l.as_str() == *r,
(StringOwned(l), StringOwned(r)) => l == r,
(Boolean(l), Boolean(r)) => *l == *r,
(Binary(l), Binary(r)) => l == r,
(BinaryOwned(l), BinaryOwned(r)) => l == r,
(Binary(l), BinaryOwned(r)) => l == r,
(BinaryOwned(l), Binary(r)) => l == r,
(Null, Null) => null_equal,
#[cfg(feature = "dtype-time")]
(Time(l), Time(r)) => *l == *r,
#[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))]
Expand All @@ -970,47 +979,81 @@ impl AnyValue<'_> {
},
(List(l), List(r)) => l == r,
#[cfg(feature = "dtype-categorical")]
(Categorical(idx_l, rev_l, _), Categorical(idx_r, rev_r, _)) => match (rev_l, rev_r) {
(RevMapping::Global(_, _, id_l), RevMapping::Global(_, _, id_r)) => {
id_l == id_r && idx_l == idx_r
},
(RevMapping::Local(_, id_l), RevMapping::Local(_, id_r)) => {
id_l == id_r && idx_l == idx_r
},
_ => false,
(Categorical(idx_l, rev_l, ptr_l), Categorical(idx_r, rev_r, ptr_r)) => {
if !same_revmap(rev_l, *ptr_l, rev_r, *ptr_r) {
// We can't support this because our Hash impl directly hashes the index. If you
// add support for this we must change the Hash impl.
unimplemented!(
"comparing categoricals with different revmaps is not supported"
);
}

idx_l == idx_r
},
#[cfg(feature = "dtype-categorical")]
(Enum(idx_l, _, _), Enum(idx_r, _, _)) => idx_l == idx_r,
(Enum(idx_l, rev_l, ptr_l), Enum(idx_r, rev_r, ptr_r)) => {
// We can't support this because our Hash impl directly hashes the index. If you
// add support for this we must change the Hash impl.
if !same_revmap(rev_l, *ptr_l, rev_r, *ptr_r) {
unimplemented!("comparing enums with different revmaps is not supported");
}

idx_l == idx_r
},
#[cfg(feature = "dtype-duration")]
(Duration(l, tu_l), Duration(r, tu_r)) => l == r && tu_l == tu_r,
#[cfg(feature = "dtype-struct")]
(StructOwned(l), StructOwned(r)) => {
let l = &*l.0;
let r = &*r.0;
l == r
let l_av = &*l.0;
let r_av = &*r.0;
l_av == r_av
},
// TODO! add structowned with idx and arced structarray
#[cfg(feature = "dtype-struct")]
(StructOwned(l), Struct(idx, arr, fields)) => {
let fields_left = &*l.0;
let avs = struct_to_avs_static(*idx, arr, fields);
fields_left == avs
l.0.iter()
.eq_by_(struct_av_iter(*idx, arr, fields), |lv, rv| *lv == rv)
},
#[cfg(feature = "dtype-struct")]
(Struct(idx, arr, fields), StructOwned(r)) => {
let fields_right = &*r.0;
let avs = struct_to_avs_static(*idx, arr, fields);
fields_right == avs
struct_av_iter(*idx, arr, fields).eq_by_(r.0.iter(), |lv, rv| lv == *rv)
},
#[cfg(feature = "dtype-struct")]
(Struct(l_idx, l_arr, l_fields), Struct(r_idx, r_arr, r_fields)) => {
struct_av_iter(*l_idx, l_arr, l_fields).eq(struct_av_iter(*r_idx, r_arr, r_fields))
},
#[cfg(feature = "dtype-decimal")]
(Decimal(v_l, scale_l), Decimal(v_r, scale_r)) => {
// Decimal equality here requires that both value and scale be equal, eg
// 1.2 at scale 1, and 1.20 at scale 2, are not equal.
*v_l == *v_r && *scale_l == *scale_r
(Decimal(l_v, l_s), Decimal(r_v, r_s)) => {
// l_v / 10**l_s == r_v / 10**r_s
if l_s == r_s && l_v == r_v || *l_v == 0 && *r_v == 0 {
true
} else if l_s < r_s {
// l_v * 10**(r_s - l_s) == r_v
if let Some(lhs) = (|| {
let exp = i128::checked_pow(10, (r_s - l_s).try_into().ok()?)?;
l_v.checked_mul(exp)
})() {
lhs == *r_v
} else {
false
}
} else {
// l_v == r_v * 10**(l_s - r_s)
if let Some(rhs) = (|| {
let exp = i128::checked_pow(10, (l_s - r_s).try_into().ok()?)?;
r_v.checked_mul(exp)
})() {
*l_v == rhs
} else {
false
}
}
},
#[cfg(feature = "object")]
(Object(l), Object(r)) => l == r,
_ => false,

(_, _) => {
unimplemented!("ordering for mixed dtypes is not supported")
},
}
}
}
Expand All @@ -1027,6 +1070,23 @@ impl PartialOrd for AnyValue<'_> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
use AnyValue::*;
match (self, &other) {
// Map to borrowed.
(StringOwned(l), r) => AnyValue::String(l.as_str()).partial_cmp(r),
(BinaryOwned(l), r) => AnyValue::Binary(l.as_slice()).partial_cmp(r),
#[cfg(feature = "object")]
(ObjectOwned(l), r) => AnyValue::Object(&*l.0).partial_cmp(r),
(l, StringOwned(r)) => l.partial_cmp(&AnyValue::String(r.as_str())),
(l, BinaryOwned(r)) => l.partial_cmp(&AnyValue::Binary(r.as_slice())),
#[cfg(feature = "object")]
(l, ObjectOwned(r)) => l.partial_cmp(&AnyValue::Object(&*r.0)),

// Comparison with null.
(Null, Null) => Some(Ordering::Equal),
(Null, _) => Some(Ordering::Less),
(_, Null) => Some(Ordering::Greater),

// Comparison between equal types.
(Boolean(l), Boolean(r)) => l.partial_cmp(r),
(UInt8(l), UInt8(r)) => l.partial_cmp(r),
(UInt16(l), UInt16(r)) => l.partial_cmp(r),
(UInt32(l), UInt32(r)) => l.partial_cmp(r),
Expand All @@ -1035,12 +1095,90 @@ impl PartialOrd for AnyValue<'_> {
(Int16(l), Int16(r)) => l.partial_cmp(r),
(Int32(l), Int32(r)) => l.partial_cmp(r),
(Int64(l), Int64(r)) => l.partial_cmp(r),
(Float32(l), Float32(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()),
(Float64(l), Float64(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()),
_ => match (self.as_borrowed(), other.as_borrowed()) {
(String(l), String(r)) => l.partial_cmp(r),
(Binary(l), Binary(r)) => l.partial_cmp(r),
_ => None,
(Float32(l), Float32(r)) => Some(l.tot_cmp(r)),
(Float64(l), Float64(r)) => Some(l.tot_cmp(r)),
(String(l), String(r)) => l.partial_cmp(r),
(Binary(l), Binary(r)) => l.partial_cmp(r),
#[cfg(feature = "dtype-date")]
(Date(l), Date(r)) => l.partial_cmp(r),
#[cfg(feature = "dtype-datetime")]
(Datetime(lt, lu, lz), Datetime(rt, ru, rz)) => {
if lu != ru || lz != rz {
unimplemented!(
"comparing datetimes with different units or timezones is not supported"
);
}

lt.partial_cmp(rt)
},
#[cfg(feature = "dtype-duration")]
(Duration(lt, lu), Duration(rt, ru)) => {
if lu != ru {
unimplemented!("comparing durations with different units is not supported");
}

lt.partial_cmp(rt)
},
#[cfg(feature = "dtype-time")]
(Time(l), Time(r)) => l.partial_cmp(r),
#[cfg(feature = "dtype-categorical")]
(Categorical(..), Categorical(..)) => {
unimplemented!(
"can't order categoricals as AnyValues, dtype for ordering is needed"
)
},
#[cfg(feature = "dtype-categorical")]
(Enum(..), Enum(..)) => {
unimplemented!("can't order enums as AnyValues, dtype for ordering is needed")
},
(List(_), List(_)) => {
unimplemented!("ordering for List dtype is not supported")
},
#[cfg(feature = "dtype-array")]
(Array(..), Array(..)) => {
unimplemented!("ordering for Array dtype is not supported")
},
#[cfg(feature = "object")]
(Object(_), Object(_)) => {
unimplemented!("ordering for Object dtype is not supported")
},
#[cfg(feature = "dtype-struct")]
(StructOwned(_), StructOwned(_))
| (StructOwned(_), Struct(..))
| (Struct(..), StructOwned(_))
| (Struct(..), Struct(..)) => {
unimplemented!("ordering for Struct dtype is not supported")
},
#[cfg(feature = "dtype-decimal")]
(Decimal(l_v, l_s), Decimal(r_v, r_s)) => {
// l_v / 10**l_s <=> r_v / 10**r_s
if l_s == r_s && l_v == r_v || *l_v == 0 && *r_v == 0 {
Some(Ordering::Equal)
} else if l_s < r_s {
// l_v * 10**(r_s - l_s) <=> r_v
if let Some(lhs) = (|| {
let exp = i128::checked_pow(10, (r_s - l_s).try_into().ok()?)?;
l_v.checked_mul(exp)
})() {
lhs.partial_cmp(r_v)
} else {
Some(Ordering::Greater)
}
} else {
// l_v <=> r_v * 10**(l_s - r_s)
if let Some(rhs) = (|| {
let exp = i128::checked_pow(10, (l_s - r_s).try_into().ok()?)?;
r_v.checked_mul(exp)
})() {
l_v.partial_cmp(&rhs)
} else {
Some(Ordering::Less)
}
}
},

(_, _) => {
unimplemented!("ordering for mixed dtypes is not supported")
},
}
}
Expand Down Expand Up @@ -1069,6 +1207,38 @@ fn struct_to_avs_static(idx: usize, arr: &StructArray, fields: &[Field]) -> Vec<
avs
}

#[cfg(feature = "dtype-categorical")]
fn same_revmap(
rev_l: &RevMapping,
ptr_l: SyncPtr<Utf8ViewArray>,
rev_r: &RevMapping,
ptr_r: SyncPtr<Utf8ViewArray>,
) -> bool {
if ptr_l.is_null() && ptr_r.is_null() {
match (rev_l, rev_r) {
(RevMapping::Global(_, _, id_l), RevMapping::Global(_, _, id_r)) => id_l == id_r,
(RevMapping::Local(_, id_l), RevMapping::Local(_, id_r)) => id_l == id_r,
_ => false,
}
} else {
ptr_l == ptr_r
}
}

#[cfg(feature = "dtype-struct")]
fn struct_av_iter<'a>(
idx: usize,
arr: &'a StructArray,
fields: &'a [Field],
) -> impl Iterator<Item = AnyValue<'a>> {
let arrs = arr.values();
(0..arrs.len()).map(move |i| unsafe {
let arr = &**arrs.get_unchecked_release(i);
let field = fields.get_unchecked_release(i);
arr_to_any_value(arr, idx, &field.dtype)
})
}

pub trait GetAnyValue {
/// # Safety
///
Expand Down
Loading

0 comments on commit b89e772

Please sign in to comment.