diff --git a/core/Cargo.lock b/core/Cargo.lock
index 52f105591d..105bcaf7c9 100644
--- a/core/Cargo.lock
+++ b/core/Cargo.lock
@@ -637,6 +637,7 @@ dependencies = [
  "thrift 0.17.0",
  "tokio",
  "tokio-stream",
+ "twox-hash",
  "unicode-segmentation",
  "zstd",
 ]
@@ -2823,6 +2824,7 @@
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
 dependencies = [
  "cfg-if",
+ "rand",
  "static_assertions",
 ]
diff --git a/core/Cargo.toml b/core/Cargo.toml
index ac565680ab..c1fc624b0f 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -75,6 +75,7 @@ once_cell = "1.18.0"
 regex = "1.9.6"
 crc32fast = "1.3.2"
 simd-adler32 = "0.3.7"
+twox-hash = "1.6.3"

 [build-dependencies]
 prost-build = "0.9.0"
diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs
index 8c5e1f3916..7c28721c9e 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -23,7 +23,7 @@ use std::{
     sync::Arc,
 };

-use crate::execution::datafusion::spark_hash::create_hashes;
+use crate::execution::datafusion::spark_hash::{create_hashes, create_xxhash64_hashes};
 use arrow::{
     array::{
         ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
@@ -119,6 +119,10 @@ pub fn create_comet_physical_fun(
             let func = Arc::new(spark_murmur3_hash);
             make_comet_scalar_udf!("murmur3_hash", func, without data_type)
         }
+        "xxhash64" => {
+            let func = Arc::new(spark_xxhash64);
+            make_comet_scalar_udf!("xxhash64", func, without data_type)
+        }
         sha if sha2_functions.contains(&sha) => {
             // Spark requires hex string as the result of sha2 functions, we have to wrap the
             // result of digest functions as hex string
@@ -672,6 +676,49 @@ fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result<ColumnarValue> {
     }
 }

+fn spark_xxhash64(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let length = args.len();
+    let seed = &args[length - 1];
+    match seed {
+        ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => {
+            // iterate over the arguments to find out the length of the array
+            let num_rows = args[0..args.len() - 1]
+                .iter()
+                .find_map(|arg| match arg {
+                    ColumnarValue::Array(array) => Some(array.len()),
+                    ColumnarValue::Scalar(_) => None,
+                })
+                .unwrap_or(1);
+            let mut hashes: Vec<u64> = vec![0_u64; num_rows];
+            hashes.fill(*seed as u64);
+            let arrays = args[0..args.len() - 1]
+                .iter()
+                .map(|arg| match arg {
+                    ColumnarValue::Array(array) => array.clone(),
+                    ColumnarValue::Scalar(scalar) => {
+                        scalar.clone().to_array_of_size(num_rows).unwrap()
+                    }
+                })
+                .collect::<Vec<ArrayRef>>();
+            create_xxhash64_hashes(&arrays, &mut hashes)?;
+            if num_rows == 1 {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(
+                    hashes[0] as i64,
+                ))))
+            } else {
+                let hashes: Vec<i64> = hashes.into_iter().map(|x| x as i64).collect();
+                Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes))))
+            }
+        }
+        _ => {
+            internal_err!(
+                "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.",
+                seed
+            )
+        }
+    }
+}
+
 #[inline]
 fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
     let mut s = String::with_capacity(data.as_ref().len() * 2);
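A reviewer-side sketch (not part of the patch) of what `spark_xxhash64` computes per row: the last argument is the seed (Spark's default is 42, appended by the Scala serde further down), and each column's hash becomes the seed for the next column, so argument order matters. Using twox-hash directly:

```rust
use std::hash::Hasher;
use twox_hash::XxHash64;

// Same shape as the spark_compatible_xxhash64 helper added in spark_hash.rs.
fn xxhash64(data: &[u8], seed: u64) -> u64 {
    let mut hasher = XxHash64::with_seed(seed);
    hasher.write(data);
    hasher.finish()
}

fn main() {
    let seed = 42u64; // Spark's default seed; QueryPlanSerde appends it last
    // Row-wise hash of columns a: Int32 = 1 and b: Int64 = 2: the hash of
    // column a seeds column b, so xxhash64(a, b) != xxhash64(b, a).
    let ab = xxhash64(&2i64.to_le_bytes(), xxhash64(&1i32.to_le_bytes(), seed));
    let ba = xxhash64(&1i32.to_le_bytes(), xxhash64(&2i64.to_le_bytes(), seed));
    assert_ne!(ab, ba);
    println!("{ab:#018x} {ba:#018x}");
}
```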
diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs
index aa4269dd01..48b9e65e08 100644
--- a/core/src/execution/datafusion/spark_hash.rs
+++ b/core/src/execution/datafusion/spark_hash.rs
@@ -18,7 +18,8 @@
 //! This includes utilities for hashing and murmur3 hashing.

 use arrow::datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type};
-use std::sync::Arc;
+use std::{hash::Hasher, sync::Arc};
+use twox_hash::XxHash64;

 use datafusion::{
     arrow::{
@@ -104,19 +105,45 @@ fn test_murmur3() {
     let _expected = vec![
         142593372, 1485273170, -97053317, 1322437556, -396302900, 814637928,
     ];
+    assert_eq!(_hashes, _expected)
 }
+
+#[inline]
+pub(crate) fn spark_compatible_xxhash64<T: AsRef<[u8]>>(data: T, seed: u64) -> u64 {
+    // TODO: Rewrite with a stateless hasher to reduce stack allocation?
+    let mut hasher = XxHash64::with_seed(seed);
+    hasher.write(data.as_ref());
+    hasher.finish()
+}
+
+#[test]
+fn test_xxhash64() {
+    let _hashes = ["", "a", "ab", "abc", "abcd", "abcde"]
+        .into_iter()
+        .map(|s| spark_compatible_xxhash64(s.as_bytes(), 42) as i64)
+        .collect::<Vec<i64>>();
+    let _expected = vec![
+        -7444071767201028348,
+        -8582455328737087284,
+        2710560539726725091,
+        1423657621850124518,
+        -6810745876291105281,
+        -990457398947679591,
+    ];
+    assert_eq!(_hashes, _expected);
+}

 macro_rules! hash_array {
-    ($array_type:ident, $column: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $hashes: ident, $hash_method: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
         if array.null_count() == 0 {
             for (i, hash) in $hashes.iter_mut().enumerate() {
-                *hash = spark_compatible_murmur3_hash(&array.value(i), *hash);
+                *hash = $hash_method(&array.value(i), *hash);
             }
         } else {
             for (i, hash) in $hashes.iter_mut().enumerate() {
                 if !array.is_null(i) {
-                    *hash = spark_compatible_murmur3_hash(&array.value(i), *hash);
+                    *hash = $hash_method(&array.value(i), *hash);
                 }
             }
         }
@@ -124,18 +151,36 @@ macro_rules! hash_array {
 }

 macro_rules! hash_array_primitive {
-    ($array_type:ident, $column: ident, $ty: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $ty: ident, $hashes: ident, $hash_method: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
         let values = array.values();

         if array.null_count() == 0 {
             for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
-                *hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *hash);
+                *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
             }
         } else {
             for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() {
                 if !array.is_null(i) {
-                    *hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *hash);
+                    *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
                 }
             }
         }
     };
 }
+
+macro_rules! hash_array_primitive_boolean {
+    ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident, $hash_method: ident) => {
+        let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
+        if array.null_count() == 0 {
+            for (i, hash) in $hashes.iter_mut().enumerate() {
+                *hash = $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash);
+            }
+        } else {
+            for (i, hash) in $hashes.iter_mut().enumerate() {
+                if !array.is_null(i) {
+                    *hash =
+                        $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash);
+                }
+            }
+        }
+    };
+}
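For readers unfamiliar with the macro indirection: passing the hash function as an identifier keeps a single code path for murmur3 and xxhash64. Roughly, `hash_array_primitive!(Int32Array, col, i32, hashes_buffer, spark_compatible_xxhash64)` expands to the following (a sketch wrapped in a function so it stands alone; it assumes the `spark_compatible_xxhash64` helper above is in scope):

```rust
use arrow::array::{Array, ArrayRef, Int32Array};

// Approximate expansion at a call site inside create_xxhash64_hashes.
fn hash_int32_column(col: &ArrayRef, hashes_buffer: &mut [u64]) {
    let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
    let values = array.values();
    if array.null_count() == 0 {
        for (hash, value) in hashes_buffer.iter_mut().zip(values.iter()) {
            // the running per-row hash doubles as the seed for this column
            *hash = spark_compatible_xxhash64((*value as i32).to_le_bytes(), *hash);
        }
    } else {
        for (i, (hash, value)) in hashes_buffer.iter_mut().zip(values.iter()).enumerate() {
            if !array.is_null(i) {
                // null entries leave the row's hash untouched, as Spark does
                *hash = spark_compatible_xxhash64((*value as i32).to_le_bytes(), *hash);
            }
        }
    }
}
```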
@@ -143,7 +188,7 @@
 }

 macro_rules! hash_array_primitive_float {
-    ($array_type:ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident) => {
+    ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident, $hash_method: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
         let values = array.values();

@@ -151,9 +196,9 @@ macro_rules! hash_array_primitive_float {
             for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
                 // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression.
                 if *value == 0.0 && value.is_sign_negative() {
-                    *hash = spark_compatible_murmur3_hash((0 as $ty2).to_le_bytes(), *hash);
+                    *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash);
                 } else {
-                    *hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *hash);
+                    *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
                 }
             }
         } else {
@@ -161,9 +206,9 @@ macro_rules! hash_array_primitive_float {
                 if !array.is_null(i) {
                     // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression.
                     if *value == 0.0 && value.is_sign_negative() {
-                        *hash = spark_compatible_murmur3_hash((0 as $ty2).to_le_bytes(), *hash);
+                        *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash);
                     } else {
-                        *hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *hash);
+                        *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
                     }
                 }
             }
@@ -172,17 +217,17 @@ macro_rules! hash_array_primitive_float {
 }

 macro_rules! hash_array_decimal {
-    ($array_type:ident, $column: ident, $hashes: ident) => {
+    ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => {
         let array = $column.as_any().downcast_ref::<$array_type>().unwrap();

         if array.null_count() == 0 {
             for (i, hash) in $hashes.iter_mut().enumerate() {
-                *hash = spark_compatible_murmur3_hash(array.value(i).to_le_bytes(), *hash);
+                *hash = $hash_method(array.value(i).to_le_bytes(), *hash);
             }
         } else {
             for (i, hash) in $hashes.iter_mut().enumerate() {
                 if !array.is_null(i) {
-                    *hash = spark_compatible_murmur3_hash(array.value(i).to_le_bytes(), *hash);
+                    *hash = $hash_method(array.value(i).to_le_bytes(), *hash);
                 }
             }
         }
@@ -218,6 +263,34 @@ fn create_hashes_dictionary<K: ArrowDictionaryKeyType>(
     Ok(())
 }

+fn create_xxhash64_hashes_dictionary<K: ArrowDictionaryKeyType>(
+    array: &ArrayRef,
+    hashes_buffer: &mut [u64],
+) -> Result<()> {
+    let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+
+    // Hash each dictionary value once, and then use that computed
+    // hash for each key value to avoid a potentially expensive
+    // redundant hashing for large dictionary elements (e.g. strings)
+    let dict_values = Arc::clone(dict_array.values());
+    let mut dict_hashes = vec![0u64; dict_values.len()];
+    create_xxhash64_hashes(&[dict_values], &mut dict_hashes)?;
+
+    for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) {
+        if let Some(key) = key {
+            let idx = key.to_usize().ok_or_else(|| {
+                DataFusionError::Internal(format!(
+                    "Can not convert key value {:?} to usize in dictionary of type {:?}",
+                    key,
+                    dict_array.data_type()
+                ))
+            })?;
+            *hash = dict_hashes[idx]
+        } // no update for Null, consistent with other hashes
+    }
+    Ok(())
+}
+
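A test-style sketch (not part of the patch) of the dictionary fast path just added: a dictionary-encoded string column dispatches to `create_xxhash64_hashes_dictionary` via `create_xxhash64_hashes`, so each distinct value is hashed once. It assumes the `spark_hash.rs` scope where `create_xxhash64_hashes` is visible:

```rust
use std::sync::Arc;
use datafusion::arrow::array::{ArrayRef, DictionaryArray};
use datafusion::arrow::datatypes::Int32Type;

fn dictionary_demo() -> datafusion::error::Result<()> {
    let dict: DictionaryArray<Int32Type> =
        vec!["hello", "bar", "hello"].into_iter().collect();
    let col = Arc::new(dict) as ArrayRef;
    let mut hashes = vec![42u64; 3]; // Spark's default seed per row
    create_xxhash64_hashes(&[col], &mut hashes)?;
    // Rows 0 and 2 share a dictionary key, so they share a hash.
    assert_eq!(hashes[0], hashes[2]);
    Ok(())
}
```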
 /// Creates hash values for every row, based on the values in the
 /// columns.
 ///
@@ -230,78 +303,171 @@ pub fn create_hashes<'a>(
     for col in arrays {
         match col.data_type() {
             DataType::Boolean => {
-                let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
-                if array.null_count() == 0 {
-                    for (i, hash) in hashes_buffer.iter_mut().enumerate() {
-                        *hash = spark_compatible_murmur3_hash(
-                            i32::from(array.value(i)).to_le_bytes(),
-                            *hash,
-                        );
-                    }
-                } else {
-                    for (i, hash) in hashes_buffer.iter_mut().enumerate() {
-                        if !array.is_null(i) {
-                            *hash = spark_compatible_murmur3_hash(
-                                i32::from(array.value(i)).to_le_bytes(),
-                                *hash,
-                            );
-                        }
-                    }
-                }
+                hash_array_primitive_boolean!(
+                    BooleanArray,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Int8 => {
-                hash_array_primitive!(Int8Array, col, i32, hashes_buffer);
+                hash_array_primitive!(
+                    Int8Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Int16 => {
-                hash_array_primitive!(Int16Array, col, i32, hashes_buffer);
+                hash_array_primitive!(
+                    Int16Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Int32 => {
-                hash_array_primitive!(Int32Array, col, i32, hashes_buffer);
+                hash_array_primitive!(
+                    Int32Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Int64 => {
-                hash_array_primitive!(Int64Array, col, i64, hashes_buffer);
+                hash_array_primitive!(
+                    Int64Array,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Float32 => {
-                hash_array_primitive_float!(Float32Array, col, f32, i32, hashes_buffer);
+                hash_array_primitive_float!(
+                    Float32Array,
+                    col,
+                    f32,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Float64 => {
-                hash_array_primitive_float!(Float64Array, col, f64, i64, hashes_buffer);
+                hash_array_primitive_float!(
+                    Float64Array,
+                    col,
+                    f64,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Timestamp(TimeUnit::Second, _) => {
-                hash_array_primitive!(TimestampSecondArray, col, i64, hashes_buffer);
+                hash_array_primitive!(
+                    TimestampSecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Timestamp(TimeUnit::Millisecond, _) => {
-                hash_array_primitive!(TimestampMillisecondArray, col, i64, hashes_buffer);
+                hash_array_primitive!(
+                    TimestampMillisecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Timestamp(TimeUnit::Microsecond, _) => {
-                hash_array_primitive!(TimestampMicrosecondArray, col, i64, hashes_buffer);
+                hash_array_primitive!(
+                    TimestampMicrosecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Timestamp(TimeUnit::Nanosecond, _) => {
-                hash_array_primitive!(TimestampNanosecondArray, col, i64, hashes_buffer);
+                hash_array_primitive!(
+                    TimestampNanosecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Date32 => {
-                hash_array_primitive!(Date32Array, col, i32, hashes_buffer);
+                hash_array_primitive!(
+                    Date32Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Date64 => {
-                hash_array_primitive!(Date64Array, col, i64, hashes_buffer);
+                hash_array_primitive!(
+                    Date64Array,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Utf8 => {
-                hash_array!(StringArray, col, hashes_buffer);
+                hash_array!(
+                    StringArray,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::LargeUtf8 => {
-                hash_array!(LargeStringArray, col, hashes_buffer);
+                hash_array!(
+                    LargeStringArray,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Binary => {
-                hash_array!(BinaryArray, col, hashes_buffer);
+                hash_array!(
+                    BinaryArray,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::LargeBinary => {
-                hash_array!(LargeBinaryArray, col, hashes_buffer);
+                hash_array!(
+                    LargeBinaryArray,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::FixedSizeBinary(_) => {
-                hash_array!(FixedSizeBinaryArray, col, hashes_buffer);
+                hash_array!(
+                    FixedSizeBinaryArray,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Decimal128(_, _) => {
-                hash_array_decimal!(Decimal128Array, col, hashes_buffer);
+                hash_array_decimal!(
+                    Decimal128Array,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_murmur3_hash
+                );
             }
             DataType::Dictionary(index_type, _) => match **index_type {
                 DataType::Int8 => {
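Before the next hunk introduces the u64 counterpart, a small test-style sketch (not from the patch) of the existing entry point's contract, which the new function mirrors: the caller seeds the buffer, and nulls leave a row's entry untouched. It assumes `create_hashes` is in scope:

```rust
use std::sync::Arc;
use datafusion::arrow::array::{ArrayRef, Int32Array};

fn murmur3_contract() -> datafusion::error::Result<()> {
    let col = Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef;
    let mut hashes = vec![42u32; 2];
    create_hashes(&[col], &mut hashes)?;
    // Per the test vectors below: 1 hashes to 0xdea578e3 under seed 42,
    // while the null row keeps its seed.
    assert_eq!(hashes, vec![0xdea578e3, 42]);
    Ok(())
}
```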
@@ -347,6 +513,213 @@ pub fn create_hashes<'a>(
     Ok(hashes_buffer)
 }

+pub fn create_xxhash64_hashes<'a>(
+    arrays: &[ArrayRef],
+    hashes_buffer: &'a mut [u64],
+) -> Result<&'a mut [u64]> {
+    for col in arrays {
+        match col.data_type() {
+            DataType::Boolean => {
+                hash_array_primitive_boolean!(
+                    BooleanArray,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Int8 => {
+                hash_array_primitive!(
+                    Int8Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Int16 => {
+                hash_array_primitive!(
+                    Int16Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Int32 => {
+                hash_array_primitive!(
+                    Int32Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Int64 => {
+                hash_array_primitive!(
+                    Int64Array,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Float32 => {
+                hash_array_primitive_float!(
+                    Float32Array,
+                    col,
+                    f32,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Float64 => {
+                hash_array_primitive_float!(
+                    Float64Array,
+                    col,
+                    f64,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Timestamp(TimeUnit::Second, _) => {
+                hash_array_primitive!(
+                    TimestampSecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                hash_array_primitive!(
+                    TimestampMillisecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                hash_array_primitive!(
+                    TimestampMicrosecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                hash_array_primitive!(
+                    TimestampNanosecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Date32 => {
+                hash_array_primitive!(
+                    Date32Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Date64 => {
+                hash_array_primitive!(
+                    Date64Array,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Utf8 => {
+                hash_array!(StringArray, col, hashes_buffer, spark_compatible_xxhash64);
+            }
+            DataType::LargeUtf8 => {
+                hash_array!(
+                    LargeStringArray,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Binary => {
+                hash_array!(BinaryArray, col, hashes_buffer, spark_compatible_xxhash64);
+            }
+            DataType::LargeBinary => {
+                hash_array!(
+                    LargeBinaryArray,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::FixedSizeBinary(_) => {
+                hash_array!(
+                    FixedSizeBinaryArray,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Decimal128(_, _) => {
+                hash_array_decimal!(
+                    Decimal128Array,
+                    col,
+                    hashes_buffer,
+                    spark_compatible_xxhash64
+                );
+            }
+            DataType::Dictionary(index_type, _) => match **index_type {
+                DataType::Int8 => {
+                    create_xxhash64_hashes_dictionary::<Int8Type>(col, hashes_buffer)?;
+                }
+                DataType::Int16 => {
+                    create_xxhash64_hashes_dictionary::<Int16Type>(col, hashes_buffer)?;
+                }
+                DataType::Int32 => {
+                    create_xxhash64_hashes_dictionary::<Int32Type>(col, hashes_buffer)?;
+                }
+                DataType::Int64 => {
+                    create_xxhash64_hashes_dictionary::<Int64Type>(col, hashes_buffer)?;
+                }
+                DataType::UInt8 => {
+                    create_xxhash64_hashes_dictionary::<UInt8Type>(col, hashes_buffer)?;
+                }
+                DataType::UInt16 => {
+                    create_xxhash64_hashes_dictionary::<UInt16Type>(col, hashes_buffer)?;
+                }
+                DataType::UInt32 => {
+                    create_xxhash64_hashes_dictionary::<UInt32Type>(col, hashes_buffer)?;
+                }
+                DataType::UInt64 => {
+                    create_xxhash64_hashes_dictionary::<UInt64Type>(col, hashes_buffer)?;
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "Unsupported dictionary type in hasher hashing: {}",
+                        col.data_type(),
+                    )))
+                }
+            },
+            _ => {
+                // This is internal because we should have caught this before.
+                return Err(DataFusionError::Internal(format!(
+                    "Unsupported data type in hasher: {}",
+                    col.data_type()
+                )));
+            }
+        }
+    }
+    Ok(hashes_buffer)
+}
+
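To make the contract concrete, a test-style sketch (not part of the patch) of calling the function just added with two columns; it assumes the `spark_hash.rs` scope where `create_xxhash64_hashes` is visible:

```rust
use std::sync::Arc;
use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array};

fn xxhash64_two_columns() -> datafusion::error::Result<()> {
    let a = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
    let b = Arc::new(Int64Array::from(vec![10, 20, 30])) as ArrayRef;
    // Every row starts from Spark's default seed 42 ...
    let mut hashes = vec![42u64; 3];
    // ... column a folds into each row's hash, then column b uses that
    // result as its per-row seed, mirroring Spark's xxhash64(a, b).
    create_xxhash64_hashes(&[a, b], &mut hashes)?;
    println!("{hashes:?}");
    Ok(())
}
```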
 pub(crate) fn pmod(hash: u32, n: usize) -> usize {
     let hash = hash as i32;
     let n = n as i32;
@@ -360,7 +733,7 @@ mod tests {
     use arrow::array::{Float32Array, Float64Array};
     use std::sync::Arc;

-    use crate::execution::datafusion::spark_hash::{create_hashes, pmod};
+    use crate::execution::datafusion::spark_hash::{create_hashes, create_xxhash64_hashes, pmod};
     use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array, Int8Array, StringArray};

     macro_rules! test_hashes {
@@ -370,6 +743,12 @@ mod tests {
             create_hashes(&[i], &mut hashes).unwrap();
             assert_eq!(hashes, $expected);
         };
+        (XXHash64, $ty:ty, $values:expr, $expected:expr) => {
+            let i = Arc::new(<$ty>::from($values)) as ArrayRef;
+            let mut hashes = vec![42u64; $values.len()];
+            create_xxhash64_hashes(&[i], &mut hashes).unwrap();
+            assert_eq!(hashes, $expected);
+        };
     }

     #[test]
@@ -385,6 +764,32 @@
             vec![Some(1), None, Some(-1), Some(i8::MAX), Some(i8::MIN)],
             vec![0xdea578e3, 42, 0xa0590e3d, 0x43b4d8ed, 0x422a1365]
         );
+        // xxhash64
+        test_hashes!(
+            XXHash64,
+            Int8Array,
+            vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)],
+            vec![
+                0xa309b38455455929,
+                0x3229fbc4681e48f3,
+                0x1bfdda8861c06e45,
+                0x77cc15d9f9f2cdc2,
+                0x39bc22b9e94d81d0
+            ]
+        );
+        // xxhash64 with null input
+        test_hashes!(
+            XXHash64,
+            Int8Array,
+            vec![Some(1), None, Some(-1), Some(i8::MAX), Some(i8::MIN)],
+            vec![
+                0xa309b38455455929,
+                42,
+                0x1bfdda8861c06e45,
+                0x77cc15d9f9f2cdc2,
+                0x39bc22b9e94d81d0
+            ]
+        );
     }

     #[test]
@@ -407,6 +812,41 @@
             ],
             vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 42, 0x07fb67e7, 0x2b1f0fc6]
         );
+
+        // xxhash64
+        test_hashes!(
+            XXHash64,
+            Int32Array,
+            vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)],
+            vec![
+                0xa309b38455455929,
+                0x3229fbc4681e48f3,
+                0x1bfdda8861c06e45,
+                0x14f0ac009c21721c,
+                0x1cc7cb8d034769cd
+            ]
+        );
+        // xxhash64 with null input
+        test_hashes!(
+            XXHash64,
+            Int32Array,
+            vec![
+                Some(1),
+                Some(0),
+                Some(-1),
+                None,
+                Some(i32::MAX),
+                Some(i32::MIN)
+            ],
+            vec![
+                0xa309b38455455929,
+                0x3229fbc4681e48f3,
+                0x1bfdda8861c06e45,
+                42,
+                0x14f0ac009c21721c,
+                0x1cc7cb8d034769cd
+            ]
+        );
     }

     #[test]
@@ -429,6 +869,42 @@
             ],
             vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 42, 0xa05b5d7b, 0xcd1e64fb]
         );
+
+        // xxhash64
+        test_hashes!(
+            XXHash64,
+            Int64Array,
+            vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)],
+            vec![
+                0x9ed50fd59358d232,
+                0xb71b47ebda15746c,
+                0x358ae035bfb46fd2,
+                0xd2f1c616ae7eb306,
+                0x88608019c494c1f4
+            ]
+        );
+
+        // xxhash64 with null input
+        test_hashes!(
+            XXHash64,
+            Int64Array,
+            vec![
+                Some(1),
+                Some(0),
+                Some(-1),
+                None,
+                Some(i64::MAX),
+                Some(i64::MIN)
+            ],
+            vec![
+                0x9ed50fd59358d232,
+                0xb71b47ebda15746c,
+                0x358ae035bfb46fd2,
+                42,
+                0xd2f1c616ae7eb306,
+                0x88608019c494c1f4
+            ]
+        );
     }

     #[test]
@@ -459,6 +935,52 @@
             ],
             vec![0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 42, 0xcbdc340f, 0xc0361c86]
         );
+
+        // xxhash64
+        test_hashes!(
+            XXHash64,
+            Float32Array,
+            vec![
+                Some(1.0),
+                Some(0.0),
+                Some(-0.0),
+                Some(-1.0),
+                Some(99999999999.99999999999),
+                Some(-99999999999.99999999999)
+            ],
+            vec![
+                0x9b92689757fcdbd,
+                0x3229fbc4681e48f3,
+                0x3229fbc4681e48f3,
+                0xa2becc0e61bb3823,
+                0x8f20ab82d4f3687f,
+                0xdce4982d97f7ac4
+            ]
+        );
+
+        // xxhash64 with null input
+        test_hashes!(
+            XXHash64,
+            Float32Array,
+            vec![
+                Some(1.0),
+                Some(0.0),
+                Some(-0.0),
+                Some(-1.0),
+                None,
+                Some(99999999999.99999999999),
+                Some(-99999999999.99999999999)
+            ],
+            vec![
+                0x9b92689757fcdbd,
+                0x3229fbc4681e48f3,
+                0x3229fbc4681e48f3,
+                0xa2becc0e61bb3823,
+                42,
+                0x8f20ab82d4f3687f,
+                0xdce4982d97f7ac4
+            ]
+        );
     }
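The 0.0 and -0.0 rows above share the vector 0x3229fbc4681e48f3 because `hash_array_primitive_float` replaces -0.0 with integer 0 before hashing, matching Spark. A standalone sketch of that rule (illustrative, not from the patch):

```rust
use std::hash::Hasher;
use twox_hash::XxHash64;

fn main() {
    // -0.0 is normalized to the integer 0 of the same width, so it hashes
    // identically to 0.0.
    let v = -0.0f32;
    let bytes = if v == 0.0 && v.is_sign_negative() {
        0i32.to_le_bytes()
    } else {
        v.to_le_bytes()
    };
    let mut h = XxHash64::with_seed(42);
    h.write(&bytes);
    // 0x3229fbc4681e48f3 is the Some(0.0) / Some(-0.0) vector above
    assert_eq!(h.finish(), 0x3229fbc4681e48f3);
}
```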

     #[test]
@@ -489,6 +1011,52 @@
             ],
             vec![0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 42, 0xb87e1595, 0xa0eef9f9]
         );
+
+        // xxhash64
+        test_hashes!(
+            XXHash64,
+            Float64Array,
+            vec![
+                Some(1.0),
+                Some(0.0),
+                Some(-0.0),
+                Some(-1.0),
+                Some(99999999999.99999999999),
+                Some(-99999999999.99999999999)
+            ],
+            vec![
+                0xe1fd6e07fee8ad53,
+                0xb71b47ebda15746c,
+                0xb71b47ebda15746c,
+                0x8cdde022746f8f1f,
+                0x793c5c88d313eac7,
+                0xc5e60e7b75d9b232
+            ]
+        );
+
+        // xxhash64 with null input
+        test_hashes!(
+            XXHash64,
+            Float64Array,
+            vec![
+                Some(1.0),
+                Some(0.0),
+                Some(-0.0),
+                Some(-1.0),
+                None,
+                Some(99999999999.99999999999),
+                Some(-99999999999.99999999999)
+            ],
+            vec![
+                0xe1fd6e07fee8ad53,
+                0xb71b47ebda15746c,
+                0xb71b47ebda15746c,
+                0x8cdde022746f8f1f,
+                42,
+                0x793c5c88d313eac7,
+                0xc5e60e7b75d9b232
+            ]
+        );
     }

     #[test]
@@ -511,6 +1079,41 @@
             ],
             vec![3286402344, 2486176763, 42, 142593372, 885025535, 2395000894]
         );
+
+        // xxhash64
+        test_hashes!(
+            XXHash64,
+            StringArray,
+            vec!["hello", "bar", "", "😁", "天地"],
+            vec![
+                0xc3629e6318d53932,
+                0xe7097b6a54378d8a,
+                0x98b1582b0977e704,
+                0xa80d9d5a6a523bd5,
+                0xfcba5f61ac666c61
+            ]
+        );
+        // xxhash64 with null input
+        test_hashes!(
+            XXHash64,
+            StringArray,
+            vec![
+                Some("hello"),
+                Some("bar"),
+                None,
+                Some(""),
+                Some("😁"),
+                Some("天地")
+            ],
+            vec![
+                0xc3629e6318d53932,
+                0xe7097b6a54378d8a,
+                42,
+                0x98b1582b0977e704,
+                0xa80d9d5a6a523bd5,
+                0xfcba5f61ac666c61
+            ]
+        );
     }

     #[test]
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7238990ad9..4ca8cab77e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2050,6 +2050,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         // the seed is put at the end of the arguments
         scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)

+      case XxHash64(children, seed) =>
+        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
+        if (firstUnSupportedInput.isDefined) {
+          withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
+          return None
+        }
+        val exprs = children.map(exprToProtoInternal(_, inputs))
+        val seedBuilder = ExprOuterClass.Literal
+          .newBuilder()
+          .setDatatype(serializeDataType(LongType).get)
+          .setLongVal(seed)
+        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+        // the seed is put at the end of the arguments
+        scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*)
+
       case Sha2(left, numBits) =>
         if (!numBits.foldable) {
           withInfo(expr, "non literal numBits is not supported")
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 28027c5cb5..6f8ec2d786 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1459,6 +1459,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           |select
           |md5(col), md5(cast(a as string)), md5(cast(b as string)),
           |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col),
+          |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col),
           |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128)
           |from test
           |""".stripMargin)
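A final spot-check (illustrative, not part of the patch): the Scala suite above verifies Comet against Spark end-to-end, but the string vectors can also be reproduced directly with twox-hash. Assuming the vectors are Spark-parity, this is what Spark's `xxhash64('hello')` (default seed 42) yields as an unsigned 64-bit value:

```rust
use std::hash::Hasher;
use twox_hash::XxHash64;

fn main() {
    let mut h = XxHash64::with_seed(42);
    h.write("hello".as_bytes());
    // Matches the first vector of the StringArray tests above.
    assert_eq!(h.finish(), 0xc3629e6318d53932);
}
```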