Skip to content

Commit

Permalink
Add initial support for Utf8View and BinaryView types (#10925)
Browse files Browse the repository at this point in the history
* add view types

* Add slt tests

* comment out failing test

* update vendored code

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
XiangpengHao and alamb authored Jun 17, 2024
1 parent 1cb0057 commit f373a86
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 19 deletions.
93 changes: 78 additions & 15 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,14 @@ pub enum ScalarValue {
UInt64(Option<u64>),
/// utf-8 encoded string.
Utf8(Option<String>),
/// utf-8 encoded string backed by arrow's `Utf8View` (string view) type.
Utf8View(Option<String>),
/// utf-8 encoded string representing a LargeString's arrow type.
LargeUtf8(Option<String>),
/// binary
Binary(Option<Vec<u8>>),
/// binary data backed by arrow's `BinaryView` (binary view) type.
BinaryView(Option<Vec<u8>>),
/// fixed size binary
FixedSizeBinary(i32, Option<Vec<u8>>),
/// large binary
Expand Down Expand Up @@ -345,10 +349,14 @@ impl PartialEq for ScalarValue {
(UInt64(_), _) => false,
(Utf8(v1), Utf8(v2)) => v1.eq(v2),
(Utf8(_), _) => false,
(Utf8View(v1), Utf8View(v2)) => v1.eq(v2),
(Utf8View(_), _) => false,
(LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2),
(LargeUtf8(_), _) => false,
(Binary(v1), Binary(v2)) => v1.eq(v2),
(Binary(_), _) => false,
(BinaryView(v1), BinaryView(v2)) => v1.eq(v2),
(BinaryView(_), _) => false,
(FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.eq(v2),
(FixedSizeBinary(_, _), _) => false,
(LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2),
Expand Down Expand Up @@ -470,8 +478,12 @@ impl PartialOrd for ScalarValue {
(Utf8(_), _) => None,
(LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2),
(LargeUtf8(_), _) => None,
(Utf8View(v1), Utf8View(v2)) => v1.partial_cmp(v2),
(Utf8View(_), _) => None,
(Binary(v1), Binary(v2)) => v1.partial_cmp(v2),
(Binary(_), _) => None,
(BinaryView(v1), BinaryView(v2)) => v1.partial_cmp(v2),
(BinaryView(_), _) => None,
(FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.partial_cmp(v2),
(FixedSizeBinary(_, _), _) => None,
(LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
Expand Down Expand Up @@ -667,11 +679,10 @@ impl std::hash::Hash for ScalarValue {
UInt16(v) => v.hash(state),
UInt32(v) => v.hash(state),
UInt64(v) => v.hash(state),
Utf8(v) => v.hash(state),
LargeUtf8(v) => v.hash(state),
Binary(v) => v.hash(state),
FixedSizeBinary(_, v) => v.hash(state),
LargeBinary(v) => v.hash(state),
Utf8(v) | LargeUtf8(v) | Utf8View(v) => v.hash(state),
Binary(v) | FixedSizeBinary(_, v) | LargeBinary(v) | BinaryView(v) => {
v.hash(state)
}
List(arr) => {
hash_nested_array(arr.to_owned() as ArrayRef, state);
}
Expand Down Expand Up @@ -1107,7 +1118,9 @@ impl ScalarValue {
ScalarValue::Float64(_) => DataType::Float64,
ScalarValue::Utf8(_) => DataType::Utf8,
ScalarValue::LargeUtf8(_) => DataType::LargeUtf8,
ScalarValue::Utf8View(_) => DataType::Utf8View,
ScalarValue::Binary(_) => DataType::Binary,
ScalarValue::BinaryView(_) => DataType::BinaryView,
ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz),
ScalarValue::LargeBinary(_) => DataType::LargeBinary,
ScalarValue::List(arr) => arr.data_type().to_owned(),
Expand Down Expand Up @@ -1310,11 +1323,13 @@ impl ScalarValue {
ScalarValue::UInt16(v) => v.is_none(),
ScalarValue::UInt32(v) => v.is_none(),
ScalarValue::UInt64(v) => v.is_none(),
ScalarValue::Utf8(v) => v.is_none(),
ScalarValue::LargeUtf8(v) => v.is_none(),
ScalarValue::Binary(v) => v.is_none(),
ScalarValue::FixedSizeBinary(_, v) => v.is_none(),
ScalarValue::LargeBinary(v) => v.is_none(),
ScalarValue::Utf8(v)
| ScalarValue::Utf8View(v)
| ScalarValue::LargeUtf8(v) => v.is_none(),
ScalarValue::Binary(v)
| ScalarValue::BinaryView(v)
| ScalarValue::FixedSizeBinary(_, v)
| ScalarValue::LargeBinary(v) => v.is_none(),
// arr.len() should be 1 for a list scalar, but we don't seem to
// enforce that anywhere, so we still check against array length.
ScalarValue::List(arr) => arr.len() == arr.null_count(),
Expand Down Expand Up @@ -2002,6 +2017,12 @@ impl ScalarValue {
}
None => new_null_array(&DataType::Utf8, size),
},
ScalarValue::Utf8View(e) => match e {
Some(value) => {
Arc::new(StringViewArray::from_iter_values(repeat(value).take(size)))
}
None => new_null_array(&DataType::Utf8View, size),
},
ScalarValue::LargeUtf8(e) => match e {
Some(value) => {
Arc::new(LargeStringArray::from_iter_values(repeat(value).take(size)))
Expand All @@ -2018,6 +2039,16 @@ impl ScalarValue {
Arc::new(repeat(None::<&str>).take(size).collect::<BinaryArray>())
}
},
ScalarValue::BinaryView(e) => match e {
Some(value) => Arc::new(
repeat(Some(value.as_slice()))
.take(size)
.collect::<BinaryViewArray>(),
),
None => {
Arc::new(repeat(None::<&str>).take(size).collect::<BinaryViewArray>())
}
},
ScalarValue::FixedSizeBinary(s, e) => match e {
Some(value) => Arc::new(
FixedSizeBinaryArray::try_from_sparse_iter_with_size(
Expand Down Expand Up @@ -2361,10 +2392,14 @@ impl ScalarValue {
DataType::LargeBinary => {
typed_cast!(array, index, LargeBinaryArray, LargeBinary)?
}
DataType::BinaryView => {
typed_cast!(array, index, BinaryViewArray, BinaryView)?
}
DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?,
DataType::LargeUtf8 => {
typed_cast!(array, index, LargeStringArray, LargeUtf8)?
}
DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?,
DataType::List(_) => {
let list_array = array.as_list::<i32>();
let nested_array = list_array.value(index);
Expand Down Expand Up @@ -2652,12 +2687,18 @@ impl ScalarValue {
ScalarValue::Utf8(val) => {
eq_array_primitive!(array, index, StringArray, val)?
}
ScalarValue::Utf8View(val) => {
eq_array_primitive!(array, index, StringViewArray, val)?
}
ScalarValue::LargeUtf8(val) => {
eq_array_primitive!(array, index, LargeStringArray, val)?
}
ScalarValue::Binary(val) => {
eq_array_primitive!(array, index, BinaryArray, val)?
}
ScalarValue::BinaryView(val) => {
eq_array_primitive!(array, index, BinaryViewArray, val)?
}
ScalarValue::FixedSizeBinary(_, val) => {
eq_array_primitive!(array, index, FixedSizeBinaryArray, val)?
}
Expand Down Expand Up @@ -2790,7 +2831,9 @@ impl ScalarValue {
| ScalarValue::DurationMillisecond(_)
| ScalarValue::DurationMicrosecond(_)
| ScalarValue::DurationNanosecond(_) => 0,
ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => {
ScalarValue::Utf8(s)
| ScalarValue::LargeUtf8(s)
| ScalarValue::Utf8View(s) => {
s.as_ref().map(|s| s.capacity()).unwrap_or_default()
}
ScalarValue::TimestampSecond(_, s)
Expand All @@ -2801,7 +2844,8 @@ impl ScalarValue {
}
ScalarValue::Binary(b)
| ScalarValue::FixedSizeBinary(_, b)
| ScalarValue::LargeBinary(b) => {
| ScalarValue::LargeBinary(b)
| ScalarValue::BinaryView(b) => {
b.as_ref().map(|b| b.capacity()).unwrap_or_default()
}
ScalarValue::List(arr) => arr.get_array_memory_size(),
Expand Down Expand Up @@ -3068,7 +3112,9 @@ impl TryFrom<&DataType> for ScalarValue {
}
DataType::Utf8 => ScalarValue::Utf8(None),
DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
DataType::Utf8View => ScalarValue::Utf8View(None),
DataType::Binary => ScalarValue::Binary(None),
DataType::BinaryView => ScalarValue::BinaryView(None),
DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None),
DataType::LargeBinary => ScalarValue::LargeBinary(None),
DataType::Date32 => ScalarValue::Date32(None),
Expand Down Expand Up @@ -3190,11 +3236,13 @@ impl fmt::Display for ScalarValue {
ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?,
ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?,
ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?,
ScalarValue::Utf8(e) => format_option!(f, e)?,
ScalarValue::LargeUtf8(e) => format_option!(f, e)?,
ScalarValue::Utf8(e)
| ScalarValue::LargeUtf8(e)
| ScalarValue::Utf8View(e) => format_option!(f, e)?,
ScalarValue::Binary(e)
| ScalarValue::FixedSizeBinary(_, e)
| ScalarValue::LargeBinary(e) => match e {
| ScalarValue::LargeBinary(e)
| ScalarValue::BinaryView(e) => match e {
Some(l) => write!(
f,
"{}",
Expand Down Expand Up @@ -3318,10 +3366,14 @@ impl fmt::Debug for ScalarValue {
}
ScalarValue::Utf8(None) => write!(f, "Utf8({self})"),
ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{self}\")"),
ScalarValue::Utf8View(None) => write!(f, "Utf8View({self})"),
ScalarValue::Utf8View(Some(_)) => write!(f, "Utf8View(\"{self}\")"),
ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({self})"),
ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{self}\")"),
ScalarValue::Binary(None) => write!(f, "Binary({self})"),
ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{self}\")"),
ScalarValue::BinaryView(None) => write!(f, "BinaryView({self})"),
ScalarValue::BinaryView(Some(_)) => write!(f, "BinaryView(\"{self}\")"),
ScalarValue::FixedSizeBinary(size, None) => {
write!(f, "FixedSizeBinary({size}, {self})")
}
Expand Down Expand Up @@ -5393,6 +5445,17 @@ mod tests {
ScalarValue::Utf8(None),
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
);

// TODO: re-enable the checks below once https://github.com/apache/arrow-rs/issues/5893 is resolved
/*
check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View);
check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View);
check_scalar_cast(
ScalarValue::from("larger than 12 bytes string"),
DataType::Utf8View,
);
*/
}

// mimics how casting works on scalar values by `casting` `scalar` to `desired_type`
Expand Down
4 changes: 4 additions & 0 deletions datafusion/functions/src/core/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,9 @@ impl<'a> Tokenizer<'a> {

"Utf8" => Token::SimpleType(DataType::Utf8),
"LargeUtf8" => Token::SimpleType(DataType::LargeUtf8),
"Utf8View" => Token::SimpleType(DataType::Utf8View),
"Binary" => Token::SimpleType(DataType::Binary),
"BinaryView" => Token::SimpleType(DataType::BinaryView),
"LargeBinary" => Token::SimpleType(DataType::LargeBinary),

"Float16" => Token::SimpleType(DataType::Float16),
Expand Down Expand Up @@ -772,11 +774,13 @@ mod test {
DataType::Interval(IntervalUnit::DayTime),
DataType::Interval(IntervalUnit::MonthDayNano),
DataType::Binary,
DataType::BinaryView,
DataType::FixedSizeBinary(0),
DataType::FixedSizeBinary(1234),
DataType::FixedSizeBinary(-432),
DataType::LargeBinary,
DataType::Utf8,
DataType::Utf8View,
DataType::LargeUtf8,
DataType::Decimal128(7, 12),
DataType::Decimal256(6, 13),
Expand Down
4 changes: 4 additions & 0 deletions datafusion/proto-common/proto/datafusion_common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ message ScalarValue{
bool bool_value = 1;
string utf8_value = 2;
string large_utf8_value = 3;
string utf8_view_value = 23;
int32 int8_value = 4;
int32 int16_value = 5;
int32 int32_value = 6;
Expand Down Expand Up @@ -281,6 +282,7 @@ message ScalarValue{
ScalarDictionaryValue dictionary_value = 27;
bytes binary_value = 28;
bytes large_binary_value = 29;
bytes binary_view_value = 22;
ScalarTime64Value time64_value = 30;
IntervalDayTimeValue interval_daytime_value = 25;
IntervalMonthDayNanoValue interval_month_day_nano = 31;
Expand Down Expand Up @@ -318,8 +320,10 @@ message ArrowType{
EmptyMessage FLOAT32 = 12 ;
EmptyMessage FLOAT64 = 13 ;
EmptyMessage UTF8 = 14 ;
EmptyMessage UTF8_VIEW = 35;
EmptyMessage LARGE_UTF8 = 32;
EmptyMessage BINARY = 15 ;
EmptyMessage BINARY_VIEW = 34;
int32 FIXED_SIZE_BINARY = 16 ;
EmptyMessage LARGE_BINARY = 31;
EmptyMessage DATE32 = 17 ;
Expand Down
4 changes: 4 additions & 0 deletions datafusion/proto-common/src/from_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType {
arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32,
arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64,
arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8,
arrow_type::ArrowTypeEnum::Utf8View(_) => DataType::Utf8View,
arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8,
arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary,
arrow_type::ArrowTypeEnum::BinaryView(_) => DataType::BinaryView,
arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => {
DataType::FixedSizeBinary(*size)
}
Expand Down Expand Up @@ -361,6 +363,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Ok(match value {
Value::BoolValue(v) => Self::Boolean(Some(*v)),
Value::Utf8Value(v) => Self::Utf8(Some(v.to_owned())),
Value::Utf8ViewValue(v) => Self::Utf8View(Some(v.to_owned())),
Value::LargeUtf8Value(v) => Self::LargeUtf8(Some(v.to_owned())),
Value::Int8Value(v) => Self::Int8(Some(*v as i8)),
Value::Int16Value(v) => Self::Int16(Some(*v as i16)),
Expand Down Expand Up @@ -571,6 +574,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Self::Dictionary(Box::new(index_type), Box::new(value))
}
Value::BinaryValue(v) => Self::Binary(Some(v.clone())),
Value::BinaryViewValue(v) => Self::BinaryView(Some(v.clone())),
Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())),
Value::IntervalDaytimeValue(v) => Self::IntervalDayTime(Some(
IntervalDayTimeType::make_value(v.days, v.milliseconds),
Expand Down
Loading

0 comments on commit f373a86

Please sign in to comment.