From 30b91df25639892ea021e6bb17fc8983ecd086cc Mon Sep 17 00:00:00 2001 From: Xiangpeng Hao Date: Sat, 15 Jun 2024 11:28:04 -0400 Subject: [PATCH 1/4] add view types --- datafusion/common/src/scalar/mod.rs | 89 +++++++++++++++---- datafusion/functions/src/core/arrow_cast.rs | 4 + .../proto/datafusion_common.proto | 4 + datafusion/proto-common/src/from_proto/mod.rs | 4 + .../proto-common/src/generated/pbjson.rs | 55 ++++++++++++ .../proto-common/src/generated/prost.rs | 12 ++- datafusion/proto-common/src/to_proto/mod.rs | 10 +++ datafusion/sql/src/unparser/expr.rs | 8 ++ 8 files changed, 169 insertions(+), 17 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8073b21cdde0..96bf4216d9a1 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -221,10 +221,14 @@ pub enum ScalarValue { UInt64(Option), /// utf-8 encoded string. Utf8(Option), + /// utf-8 encoded string but from view types. + Utf8View(Option), /// utf-8 encoded string representing a LargeString's arrow type. LargeUtf8(Option), /// binary Binary(Option>), + /// binary but from view types. + BinaryView(Option>), /// fixed size binary FixedSizeBinary(i32, Option>), /// large binary @@ -345,10 +349,14 @@ impl PartialEq for ScalarValue { (UInt64(_), _) => false, (Utf8(v1), Utf8(v2)) => v1.eq(v2), (Utf8(_), _) => false, + (Utf8View(v1), Utf8View(v2)) => v1.eq(v2), + (Utf8View(_), _) => false, (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2), (LargeUtf8(_), _) => false, (Binary(v1), Binary(v2)) => v1.eq(v2), (Binary(_), _) => false, + (BinaryView(v1), BinaryView(v2)) => v1.eq(v2), + (BinaryView(_), _) => false, (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.eq(v2), (FixedSizeBinary(_, _), _) => false, (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), @@ -470,8 +478,12 @@ impl PartialOrd for ScalarValue { (Utf8(_), _) => None, (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), (LargeUtf8(_), _) => None, + (Utf8View(v1), Utf8View(v2)) => v1.partial_cmp(v2), + (Utf8View(_), _) => None, (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), (Binary(_), _) => None, + (BinaryView(v1), BinaryView(v2)) => v1.partial_cmp(v2), + (BinaryView(_), _) => None, (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.partial_cmp(v2), (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), @@ -667,11 +679,10 @@ impl std::hash::Hash for ScalarValue { UInt16(v) => v.hash(state), UInt32(v) => v.hash(state), UInt64(v) => v.hash(state), - Utf8(v) => v.hash(state), - LargeUtf8(v) => v.hash(state), - Binary(v) => v.hash(state), - FixedSizeBinary(_, v) => v.hash(state), - LargeBinary(v) => v.hash(state), + Utf8(v) | LargeUtf8(v) | Utf8View(v) => v.hash(state), + Binary(v) | FixedSizeBinary(_, v) | LargeBinary(v) | BinaryView(v) => { + v.hash(state) + } List(arr) => { hash_nested_array(arr.to_owned() as ArrayRef, state); } @@ -1107,7 +1118,9 @@ impl ScalarValue { ScalarValue::Float64(_) => DataType::Float64, ScalarValue::Utf8(_) => DataType::Utf8, ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, + ScalarValue::Utf8View(_) => DataType::Utf8View, ScalarValue::Binary(_) => DataType::Binary, + ScalarValue::BinaryView(_) => DataType::BinaryView, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, ScalarValue::List(arr) => arr.data_type().to_owned(), @@ -1310,11 +1323,13 @@ impl ScalarValue { ScalarValue::UInt16(v) => v.is_none(), ScalarValue::UInt32(v) => v.is_none(), ScalarValue::UInt64(v) => v.is_none(), - ScalarValue::Utf8(v) => v.is_none(), - ScalarValue::LargeUtf8(v) => v.is_none(), - ScalarValue::Binary(v) => v.is_none(), - ScalarValue::FixedSizeBinary(_, v) => v.is_none(), - ScalarValue::LargeBinary(v) => v.is_none(), + ScalarValue::Utf8(v) + | ScalarValue::Utf8View(v) + | ScalarValue::LargeUtf8(v) => v.is_none(), + ScalarValue::Binary(v) + | ScalarValue::BinaryView(v) + | ScalarValue::FixedSizeBinary(_, v) + | ScalarValue::LargeBinary(v) => v.is_none(), // arr.len() should be 1 for a list scalar, but we don't seem to // enforce that anywhere, so we still check against array length. ScalarValue::List(arr) => arr.len() == arr.null_count(), @@ -2002,6 +2017,12 @@ impl ScalarValue { } None => new_null_array(&DataType::Utf8, size), }, + ScalarValue::Utf8View(e) => match e { + Some(value) => { + Arc::new(StringViewArray::from_iter_values(repeat(value).take(size))) + } + None => new_null_array(&DataType::Utf8View, size), + }, ScalarValue::LargeUtf8(e) => match e { Some(value) => { Arc::new(LargeStringArray::from_iter_values(repeat(value).take(size))) @@ -2018,6 +2039,16 @@ impl ScalarValue { Arc::new(repeat(None::<&str>).take(size).collect::()) } }, + ScalarValue::BinaryView(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::(), + ), + None => { + Arc::new(repeat(None::<&str>).take(size).collect::()) + } + }, ScalarValue::FixedSizeBinary(s, e) => match e { Some(value) => Arc::new( FixedSizeBinaryArray::try_from_sparse_iter_with_size( @@ -2361,10 +2392,14 @@ impl ScalarValue { DataType::LargeBinary => { typed_cast!(array, index, LargeBinaryArray, LargeBinary)? } + DataType::BinaryView => { + typed_cast!(array, index, BinaryViewArray, BinaryView)? + } DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, DataType::LargeUtf8 => { typed_cast!(array, index, LargeStringArray, LargeUtf8)? } + DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?, DataType::List(_) => { let list_array = array.as_list::(); let nested_array = list_array.value(index); @@ -2652,12 +2687,18 @@ impl ScalarValue { ScalarValue::Utf8(val) => { eq_array_primitive!(array, index, StringArray, val)? } + ScalarValue::Utf8View(val) => { + eq_array_primitive!(array, index, StringViewArray, val)? + } ScalarValue::LargeUtf8(val) => { eq_array_primitive!(array, index, LargeStringArray, val)? } ScalarValue::Binary(val) => { eq_array_primitive!(array, index, BinaryArray, val)? } + ScalarValue::BinaryView(val) => { + eq_array_primitive!(array, index, BinaryViewArray, val)? + } ScalarValue::FixedSizeBinary(_, val) => { eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? } @@ -2790,7 +2831,9 @@ impl ScalarValue { | ScalarValue::DurationMillisecond(_) | ScalarValue::DurationMicrosecond(_) | ScalarValue::DurationNanosecond(_) => 0, - ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => { + ScalarValue::Utf8(s) + | ScalarValue::LargeUtf8(s) + | ScalarValue::Utf8View(s) => { s.as_ref().map(|s| s.capacity()).unwrap_or_default() } ScalarValue::TimestampSecond(_, s) @@ -2801,7 +2844,8 @@ impl ScalarValue { } ScalarValue::Binary(b) | ScalarValue::FixedSizeBinary(_, b) - | ScalarValue::LargeBinary(b) => { + | ScalarValue::LargeBinary(b) + | ScalarValue::BinaryView(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } ScalarValue::List(arr) => arr.get_array_memory_size(), @@ -3068,7 +3112,9 @@ impl TryFrom<&DataType> for ScalarValue { } DataType::Utf8 => ScalarValue::Utf8(None), DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), + DataType::Utf8View => ScalarValue::Utf8View(None), DataType::Binary => ScalarValue::Binary(None), + DataType::BinaryView => ScalarValue::BinaryView(None), DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), DataType::LargeBinary => ScalarValue::LargeBinary(None), DataType::Date32 => ScalarValue::Date32(None), @@ -3190,11 +3236,13 @@ impl fmt::Display for ScalarValue { ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?, ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?, ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, - ScalarValue::Utf8(e) => format_option!(f, e)?, - ScalarValue::LargeUtf8(e) => format_option!(f, e)?, + ScalarValue::Utf8(e) + | ScalarValue::LargeUtf8(e) + | ScalarValue::Utf8View(e) => format_option!(f, e)?, ScalarValue::Binary(e) | ScalarValue::FixedSizeBinary(_, e) - | ScalarValue::LargeBinary(e) => match e { + | ScalarValue::LargeBinary(e) + | ScalarValue::BinaryView(e) => match e { Some(l) => write!( f, "{}", @@ -3318,10 +3366,14 @@ impl fmt::Debug for ScalarValue { } ScalarValue::Utf8(None) => write!(f, "Utf8({self})"), ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{self}\")"), + ScalarValue::Utf8View(None) => write!(f, "Utf8View({self})"), + ScalarValue::Utf8View(Some(_)) => write!(f, "Utf8View(\"{self}\")"), ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({self})"), ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{self}\")"), ScalarValue::Binary(None) => write!(f, "Binary({self})"), ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{self}\")"), + ScalarValue::BinaryView(None) => write!(f, "BinaryView({self})"), + ScalarValue::BinaryView(Some(_)) => write!(f, "BinaryView(\"{self}\")"), ScalarValue::FixedSizeBinary(size, None) => { write!(f, "FixedSizeBinary({size}, {self})") } @@ -5393,6 +5445,13 @@ mod tests { ScalarValue::Utf8(None), DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), ); + + check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View); + check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View); + check_scalar_cast( + ScalarValue::from("larger than 12 bytes string"), + DataType::Utf8View, + ); } // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index d641389e0ae3..9c410d4e18e8 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -564,7 +564,9 @@ impl<'a> Tokenizer<'a> { "Utf8" => Token::SimpleType(DataType::Utf8), "LargeUtf8" => Token::SimpleType(DataType::LargeUtf8), + "Utf8View" => Token::SimpleType(DataType::Utf8View), "Binary" => Token::SimpleType(DataType::Binary), + "BinaryView" => Token::SimpleType(DataType::BinaryView), "LargeBinary" => Token::SimpleType(DataType::LargeBinary), "Float16" => Token::SimpleType(DataType::Float16), @@ -772,11 +774,13 @@ mod test { DataType::Interval(IntervalUnit::DayTime), DataType::Interval(IntervalUnit::MonthDayNano), DataType::Binary, + DataType::BinaryView, DataType::FixedSizeBinary(0), DataType::FixedSizeBinary(1234), DataType::FixedSizeBinary(-432), DataType::LargeBinary, DataType::Utf8, + DataType::Utf8View, DataType::LargeUtf8, DataType::Decimal128(7, 12), DataType::Decimal256(6, 13), diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 29a348283f46..e523ef1a5e93 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -248,6 +248,7 @@ message ScalarValue{ bool bool_value = 1; string utf8_value = 2; string large_utf8_value = 3; + string utf8_view_value = 23; int32 int8_value = 4; int32 int16_value = 5; int32 int32_value = 6; @@ -281,6 +282,7 @@ message ScalarValue{ ScalarDictionaryValue dictionary_value = 27; bytes binary_value = 28; bytes large_binary_value = 29; + bytes binary_view_value = 22; ScalarTime64Value time64_value = 30; IntervalDayTimeValue interval_daytime_value = 25; IntervalMonthDayNanoValue interval_month_day_nano = 31; @@ -318,8 +320,10 @@ message ArrowType{ EmptyMessage FLOAT32 = 12 ; EmptyMessage FLOAT64 = 13 ; EmptyMessage UTF8 = 14 ; + EmptyMessage UTF8_VIEW = 35; EmptyMessage LARGE_UTF8 = 32; EmptyMessage BINARY = 15 ; + EmptyMessage BINARY_VIEW = 34; int32 FIXED_SIZE_BINARY = 16 ; EmptyMessage LARGE_BINARY = 31; EmptyMessage DATE32 = 17 ; diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index 25c1502ee75b..be87123fb13f 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -224,8 +224,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32, arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64, arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8, + arrow_type::ArrowTypeEnum::Utf8View(_) => DataType::Utf8View, arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8, arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary, + arrow_type::ArrowTypeEnum::BinaryView(_) => DataType::BinaryView, arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => { DataType::FixedSizeBinary(*size) } @@ -361,6 +363,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Ok(match value { Value::BoolValue(v) => Self::Boolean(Some(*v)), Value::Utf8Value(v) => Self::Utf8(Some(v.to_owned())), + Value::Utf8ViewValue(v) => Self::Utf8View(Some(v.to_owned())), Value::LargeUtf8Value(v) => Self::LargeUtf8(Some(v.to_owned())), Value::Int8Value(v) => Self::Int8(Some(*v as i8)), Value::Int16Value(v) => Self::Int16(Some(*v as i16)), @@ -571,6 +574,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Self::Dictionary(Box::new(index_type), Box::new(value)) } Value::BinaryValue(v) => Self::Binary(Some(v.clone())), + Value::BinaryViewValue(v) => Self::BinaryView(Some(v.clone())), Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())), Value::IntervalDaytimeValue(v) => Self::IntervalDayTime(Some( IntervalDayTimeType::make_value(v.days, v.milliseconds), diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 6f8409b82afe..ead29d9b92e0 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -125,12 +125,18 @@ impl serde::Serialize for ArrowType { arrow_type::ArrowTypeEnum::Utf8(v) => { struct_ser.serialize_field("UTF8", v)?; } + arrow_type::ArrowTypeEnum::Utf8View(v) => { + struct_ser.serialize_field("UTF8VIEW", v)?; + } arrow_type::ArrowTypeEnum::LargeUtf8(v) => { struct_ser.serialize_field("LARGEUTF8", v)?; } arrow_type::ArrowTypeEnum::Binary(v) => { struct_ser.serialize_field("BINARY", v)?; } + arrow_type::ArrowTypeEnum::BinaryView(v) => { + struct_ser.serialize_field("BINARYVIEW", v)?; + } arrow_type::ArrowTypeEnum::FixedSizeBinary(v) => { struct_ser.serialize_field("FIXEDSIZEBINARY", v)?; } @@ -216,9 +222,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "FLOAT32", "FLOAT64", "UTF8", + "UTF8_VIEW", + "UTF8VIEW", "LARGE_UTF8", "LARGEUTF8", "BINARY", + "BINARY_VIEW", + "BINARYVIEW", "FIXED_SIZE_BINARY", "FIXEDSIZEBINARY", "LARGE_BINARY", @@ -258,8 +268,10 @@ impl<'de> serde::Deserialize<'de> for ArrowType { Float32, Float64, Utf8, + Utf8View, LargeUtf8, Binary, + BinaryView, FixedSizeBinary, LargeBinary, Date32, @@ -312,8 +324,10 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "FLOAT32" => Ok(GeneratedField::Float32), "FLOAT64" => Ok(GeneratedField::Float64), "UTF8" => Ok(GeneratedField::Utf8), + "UTF8VIEW" | "UTF8_VIEW" => Ok(GeneratedField::Utf8View), "LARGEUTF8" | "LARGE_UTF8" => Ok(GeneratedField::LargeUtf8), "BINARY" => Ok(GeneratedField::Binary), + "BINARYVIEW" | "BINARY_VIEW" => Ok(GeneratedField::BinaryView), "FIXEDSIZEBINARY" | "FIXED_SIZE_BINARY" => Ok(GeneratedField::FixedSizeBinary), "LARGEBINARY" | "LARGE_BINARY" => Ok(GeneratedField::LargeBinary), "DATE32" => Ok(GeneratedField::Date32), @@ -449,6 +463,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType { return Err(serde::de::Error::duplicate_field("UTF8")); } arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Utf8) +; + } + GeneratedField::Utf8View => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("UTF8VIEW")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Utf8View) ; } GeneratedField::LargeUtf8 => { @@ -463,6 +484,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType { return Err(serde::de::Error::duplicate_field("BINARY")); } arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Binary) +; + } + GeneratedField::BinaryView => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("BINARYVIEW")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::BinaryView) ; } GeneratedField::FixedSizeBinary => { @@ -6255,6 +6283,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::LargeUtf8Value(v) => { struct_ser.serialize_field("largeUtf8Value", v)?; } + scalar_value::Value::Utf8ViewValue(v) => { + struct_ser.serialize_field("utf8ViewValue", v)?; + } scalar_value::Value::Int8Value(v) => { struct_ser.serialize_field("int8Value", v)?; } @@ -6348,6 +6379,10 @@ impl serde::Serialize for ScalarValue { #[allow(clippy::needless_borrow)] struct_ser.serialize_field("largeBinaryValue", pbjson::private::base64::encode(&v).as_str())?; } + scalar_value::Value::BinaryViewValue(v) => { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("binaryViewValue", pbjson::private::base64::encode(&v).as_str())?; + } scalar_value::Value::Time64Value(v) => { struct_ser.serialize_field("time64Value", v)?; } @@ -6383,6 +6418,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "utf8Value", "large_utf8_value", "largeUtf8Value", + "utf8_view_value", + "utf8ViewValue", "int8_value", "int8Value", "int16_value", @@ -6439,6 +6476,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "binaryValue", "large_binary_value", "largeBinaryValue", + "binary_view_value", + "binaryViewValue", "time64_value", "time64Value", "interval_daytime_value", @@ -6457,6 +6496,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { BoolValue, Utf8Value, LargeUtf8Value, + Utf8ViewValue, Int8Value, Int16Value, Int32Value, @@ -6485,6 +6525,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { DictionaryValue, BinaryValue, LargeBinaryValue, + BinaryViewValue, Time64Value, IntervalDaytimeValue, IntervalMonthDayNano, @@ -6515,6 +6556,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "boolValue" | "bool_value" => Ok(GeneratedField::BoolValue), "utf8Value" | "utf8_value" => Ok(GeneratedField::Utf8Value), "largeUtf8Value" | "large_utf8_value" => Ok(GeneratedField::LargeUtf8Value), + "utf8ViewValue" | "utf8_view_value" => Ok(GeneratedField::Utf8ViewValue), "int8Value" | "int8_value" => Ok(GeneratedField::Int8Value), "int16Value" | "int16_value" => Ok(GeneratedField::Int16Value), "int32Value" | "int32_value" => Ok(GeneratedField::Int32Value), @@ -6543,6 +6585,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "dictionaryValue" | "dictionary_value" => Ok(GeneratedField::DictionaryValue), "binaryValue" | "binary_value" => Ok(GeneratedField::BinaryValue), "largeBinaryValue" | "large_binary_value" => Ok(GeneratedField::LargeBinaryValue), + "binaryViewValue" | "binary_view_value" => Ok(GeneratedField::BinaryViewValue), "time64Value" | "time64_value" => Ok(GeneratedField::Time64Value), "intervalDaytimeValue" | "interval_daytime_value" => Ok(GeneratedField::IntervalDaytimeValue), "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), @@ -6595,6 +6638,12 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeUtf8Value); } + GeneratedField::Utf8ViewValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("utf8ViewValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Utf8ViewValue); + } GeneratedField::Int8Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("int8Value")); @@ -6772,6 +6821,12 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { } value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::LargeBinaryValue(x.0)); } + GeneratedField::BinaryViewValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryViewValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::BinaryViewValue(x.0)); + } GeneratedField::Time64Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("time64Value")); diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index ff17a40738b5..b306f3212a2f 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -326,7 +326,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -345,6 +345,8 @@ pub mod scalar_value { Utf8Value(::prost::alloc::string::String), #[prost(string, tag = "3")] LargeUtf8Value(::prost::alloc::string::String), + #[prost(string, tag = "23")] + Utf8ViewValue(::prost::alloc::string::String), #[prost(int32, tag = "4")] Int8Value(i32), #[prost(int32, tag = "5")] @@ -402,6 +404,8 @@ pub mod scalar_value { BinaryValue(::prost::alloc::vec::Vec), #[prost(bytes, tag = "29")] LargeBinaryValue(::prost::alloc::vec::Vec), + #[prost(bytes, tag = "22")] + BinaryViewValue(::prost::alloc::vec::Vec), #[prost(message, tag = "30")] Time64Value(super::ScalarTime64Value), #[prost(message, tag = "25")] @@ -440,7 +444,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 32, 15, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -482,10 +486,14 @@ pub mod arrow_type { Float64(super::EmptyMessage), #[prost(message, tag = "14")] Utf8(super::EmptyMessage), + #[prost(message, tag = "35")] + Utf8View(super::EmptyMessage), #[prost(message, tag = "32")] LargeUtf8(super::EmptyMessage), #[prost(message, tag = "15")] Binary(super::EmptyMessage), + #[prost(message, tag = "34")] + BinaryView(super::EmptyMessage), #[prost(int32, tag = "16")] FixedSizeBinary(i32), #[prost(message, tag = "31")] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 8e7ee9a7d6fa..a3dc826a79ca 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -347,6 +347,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::LargeUtf8Value(s.to_owned()) }) } + ScalarValue::Utf8View(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Utf8ViewValue(s.to_owned()) + }) + } ScalarValue::List(arr) => { encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) } @@ -461,6 +466,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::BinaryValue(s.to_owned()) }) } + ScalarValue::BinaryView(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::BinaryViewValue(s.to_owned()) + }) + } ScalarValue::LargeBinary(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::LargeBinaryValue(s.to_owned()) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 12c48054f1a7..1d197f3a0d8a 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -709,12 +709,20 @@ impl Unparser<'_> { ast::Value::SingleQuotedString(str.to_string()), )), ScalarValue::Utf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Utf8View(Some(str)) => Ok(ast::Expr::Value( + ast::Value::SingleQuotedString(str.to_string()), + )), + ScalarValue::Utf8View(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::LargeUtf8(Some(str)) => Ok(ast::Expr::Value( ast::Value::SingleQuotedString(str.to_string()), )), ScalarValue::LargeUtf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Binary(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::BinaryView(Some(_)) => { + not_impl_err!("Unsupported scalar: {v:?}") + } + ScalarValue::BinaryView(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::FixedSizeBinary(..) => { not_impl_err!("Unsupported scalar: {v:?}") } From 2c894404900daacaa6c913590fdcd1bca472bb24 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 17 Jun 2024 14:35:28 -0400 Subject: [PATCH 2/4] Add slt tests --- datafusion/sqllogictest/test_files/arrow_typeof.slt | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index c928b96e0321..ab4ff9e2ce92 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -422,3 +422,13 @@ query ? select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)'); ---- [1, 2, 3] + +# Tests for Utf8View +query ?T +select arrow_cast('MyAwesomeString', 'Utf8View'), arrow_typeof(arrow_cast('MyAwesomeString', 'Utf8View')) +---- +MyAwesomeString Utf8View + +# Fails until we update arrow-rs with support for https://github.com/apache/arrow-rs/pull/5894 +query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: arrow_cast"\) +arrow_cast('MyAwesomeString', 'BinaryView'), arrow_typeof(arrow_cast('MyAwesomeString', 'BinaryView')) From 3309c7291bbb43b636deb036c5ca86f3bcbaea07 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 17 Jun 2024 14:45:30 -0400 Subject: [PATCH 3/4] comment out failing test --- datafusion/common/src/scalar/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 96bf4216d9a1..3daf347ae4ff 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -5446,12 +5446,16 @@ mod tests { DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), ); + // needs https://github.com/apache/arrow-rs/issues/5893 + /* check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View); check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View); check_scalar_cast( ScalarValue::from("larger than 12 bytes string"), DataType::Utf8View, ); + + */ } // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` From 6da8f0e3e82e641cddd447c3e4162eebfa317657 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 17 Jun 2024 14:48:28 -0400 Subject: [PATCH 4/4] update vendored code --- .../proto/src/generated/datafusion_proto_common.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index ff17a40738b5..b306f3212a2f 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -326,7 +326,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -345,6 +345,8 @@ pub mod scalar_value { Utf8Value(::prost::alloc::string::String), #[prost(string, tag = "3")] LargeUtf8Value(::prost::alloc::string::String), + #[prost(string, tag = "23")] + Utf8ViewValue(::prost::alloc::string::String), #[prost(int32, tag = "4")] Int8Value(i32), #[prost(int32, tag = "5")] @@ -402,6 +404,8 @@ pub mod scalar_value { BinaryValue(::prost::alloc::vec::Vec), #[prost(bytes, tag = "29")] LargeBinaryValue(::prost::alloc::vec::Vec), + #[prost(bytes, tag = "22")] + BinaryViewValue(::prost::alloc::vec::Vec), #[prost(message, tag = "30")] Time64Value(super::ScalarTime64Value), #[prost(message, tag = "25")] @@ -440,7 +444,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 32, 15, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -482,10 +486,14 @@ pub mod arrow_type { Float64(super::EmptyMessage), #[prost(message, tag = "14")] Utf8(super::EmptyMessage), + #[prost(message, tag = "35")] + Utf8View(super::EmptyMessage), #[prost(message, tag = "32")] LargeUtf8(super::EmptyMessage), #[prost(message, tag = "15")] Binary(super::EmptyMessage), + #[prost(message, tag = "34")] + BinaryView(super::EmptyMessage), #[prost(int32, tag = "16")] FixedSizeBinary(i32), #[prost(message, tag = "31")]