From a29816246fa0ed49f6be2efdce68756afcdbfb12 Mon Sep 17 00:00:00 2001 From: Xianjin Date: Thu, 20 Jun 2024 23:04:53 +0800 Subject: [PATCH 1/2] chore: Move some utility methods to submodules of scalar_funcs --- core/benches/hash.rs | 12 +- .../datafusion/expressions/scalar_funcs.rs | 142 ++---------------- .../scalar_funcs/hash_expressions.rs | 108 +++++++++++++ .../expressions/scalar_funcs/hex.rs | 63 ++++++-- 4 files changed, 178 insertions(+), 147 deletions(-) create mode 100644 core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs diff --git a/core/benches/hash.rs b/core/benches/hash.rs index b878ebea5..d66d58925 100644 --- a/core/benches/hash.rs +++ b/core/benches/hash.rs @@ -19,8 +19,7 @@ mod common; use arrow_array::ArrayRef; -use comet::execution::datafusion::expressions::scalar_funcs::spark_murmur3_hash; -use comet::execution::datafusion::spark_hash::create_xxhash64_hashes; +use comet::execution::datafusion::expressions::scalar_funcs::{spark_murmur3_hash, spark_xxhash64}; use comet::execution::kernels::hash; use common::*; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; @@ -100,12 +99,15 @@ fn criterion_benchmark(c: &mut Criterion) { }, ); group.bench_function(BenchmarkId::new("xxhash64", BATCH_SIZE), |b| { - let input = vec![a3.clone(), a4.clone()]; - let mut dst = vec![0; BATCH_SIZE]; + let inputs = &[ + ColumnarValue::Array(a3.clone()), + ColumnarValue::Array(a4.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(42i64))), + ]; b.iter(|| { for _ in 0..NUM_ITER { - create_xxhash64_hashes(&input, &mut dst).unwrap(); + spark_xxhash64(inputs).unwrap(); } }); }); diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 3c7af8676..7b0d40e83 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -15,14 +15,8 @@ // specific language governing permissions and limitations // under the License. -use std::{ - any::Any, - cmp::min, - fmt::{Debug, Write}, - sync::Arc, -}; +use std::{any::Any, cmp::min, fmt::Debug, sync::Arc}; -use crate::execution::datafusion::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes}; use arrow::{ array::{ ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray, @@ -30,7 +24,7 @@ use arrow::{ }, datatypes::{validate_decimal_precision, Decimal128Type, Int64Type}, }; -use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array, StringArray}; +use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array}; use arrow_schema::DataType; use datafusion::{ execution::FunctionRegistry, @@ -39,8 +33,8 @@ use datafusion::{ physical_plan::ColumnarValue, }; use datafusion_common::{ - cast::{as_binary_array, as_generic_string_array}, - exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, + cast::as_generic_string_array, exec_err, internal_err, DataFusionError, + Result as DataFusionResult, ScalarValue, }; use datafusion_expr::ScalarUDF; use num::{ @@ -53,11 +47,15 @@ mod unhex; use unhex::spark_unhex; mod hex; -use hex::spark_hex; +use hex::{spark_hex, wrap_digest_result_as_hex_string}; mod chr; use chr::spark_chr; +pub mod hash_expressions; +// exposed for benchmark only +pub use hash_expressions::{spark_murmur3_hash, spark_xxhash64}; + macro_rules! make_comet_scalar_udf { ($name:expr, $func:ident, $data_type:ident) => {{ let scalar_func = CometScalarFunction::new( @@ -635,125 +633,3 @@ fn spark_decimal_div( let result = result.with_data_type(DataType::Decimal128(p3, s3)); Ok(ColumnarValue::Array(Arc::new(result))) } - -pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { - let length = args.len(); - let seed = &args[length - 1]; - match seed { - ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => { - // iterate over the arguments to find out the length of the array - let num_rows = args[0..args.len() - 1] - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - ColumnarValue::Scalar(_) => None, - }) - .unwrap_or(1); - let mut hashes: Vec = vec![0_u32; num_rows]; - hashes.fill(*seed as u32); - let arrays = args[0..args.len() - 1] - .iter() - .map(|arg| match arg { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => { - scalar.clone().to_array_of_size(num_rows).unwrap() - } - }) - .collect::>(); - create_murmur3_hashes(&arrays, &mut hashes)?; - if num_rows == 1 { - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some( - hashes[0] as i32, - )))) - } else { - let hashes: Vec = hashes.into_iter().map(|x| x as i32).collect(); - Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes)))) - } - } - _ => { - internal_err!( - "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.", - seed - ) - } - } -} - -fn spark_xxhash64(args: &[ColumnarValue]) -> Result { - let length = args.len(); - let seed = &args[length - 1]; - match seed { - ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => { - // iterate over the arguments to find out the length of the array - let num_rows = args[0..args.len() - 1] - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - ColumnarValue::Scalar(_) => None, - }) - .unwrap_or(1); - let mut hashes: Vec = vec![0_u64; num_rows]; - hashes.fill(*seed as u64); - let arrays = args[0..args.len() - 1] - .iter() - .map(|arg| match arg { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => { - scalar.clone().to_array_of_size(num_rows).unwrap() - } - }) - .collect::>(); - create_xxhash64_hashes(&arrays, &mut hashes)?; - if num_rows == 1 { - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( - hashes[0] as i64, - )))) - } else { - let hashes: Vec = hashes.into_iter().map(|x| x as i64).collect(); - Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes)))) - } - } - _ => { - internal_err!( - "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.", - seed - ) - } - } -} - -#[inline] -fn hex_encode>(data: T) -> String { - let mut s = String::with_capacity(data.as_ref().len() * 2); - for b in data.as_ref() { - // Writing to a string never errors, so we can unwrap here. - write!(&mut s, "{b:02x}").unwrap(); - } - s -} - -fn wrap_digest_result_as_hex_string( - args: &[ColumnarValue], - digest: ScalarFunctionImplementation, -) -> Result { - let value = digest(args)?; - match value { - ColumnarValue::Array(array) => { - let binary_array = as_binary_array(&array)?; - let string_array: StringArray = binary_array - .iter() - .map(|opt| opt.map(hex_encode::<_>)) - .collect(); - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( - ScalarValue::Utf8(opt.map(hex_encode::<_>)), - )), - _ => { - exec_err!( - "digest function should return binary value, but got: {:?}", - value.data_type() - ) - } - } -} diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs new file mode 100644 index 000000000..707eba57e --- /dev/null +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::execution::datafusion::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes}; +use arrow_array::{ArrayRef, Int32Array, Int64Array}; +use datafusion_common::{internal_err, DataFusionError, ScalarValue}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u32; num_rows]; + hashes.fill(*seed as u32); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_murmur3_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some( + hashes[0] as i32, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i32).collect(); + Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.", + seed + ) + } + } +} + +pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u64; num_rows]; + hashes.fill(*seed as u64); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_xxhash64_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( + hashes[0] as i64, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i64).collect(); + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.", + seed + ) + } + } +} diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index ea572574a..3091eb323 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -26,24 +26,69 @@ use arrow_schema::DataType; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{ cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, - exec_err, DataFusionError, + exec_err, DataFusionError, ScalarValue, }; +use datafusion_expr::ScalarFunctionImplementation; use std::fmt::Write; fn hex_int64(num: i64) -> String { format!("{:X}", num) } -fn hex_bytes>(bytes: T) -> Result { - let bytes = bytes.as_ref(); - let length = bytes.len(); - let mut hex_string = String::with_capacity(length * 2); - for &byte in bytes { - write!(&mut hex_string, "{:02X}", byte)?; +#[inline(always)] +fn hex_encode>(data: T, lower_case: bool) -> String { + let mut s = String::with_capacity(data.as_ref().len() * 2); + if lower_case { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02x}").unwrap(); + } + } else { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02X}").unwrap(); + } } + s +} + +#[inline(always)] +fn hex_strings>(data: T) -> String { + hex_encode(data, true) +} + +#[inline(always)] +fn hex_bytes>(bytes: T) -> Result { + let hex_string = hex_encode(bytes, false); Ok(hex_string) } +pub(super) fn wrap_digest_result_as_hex_string( + args: &[ColumnarValue], + digest: ScalarFunctionImplementation, +) -> Result { + let value = digest(args)?; + match value { + ColumnarValue::Array(array) => { + let binary_array = as_binary_array(&array)?; + let string_array: StringArray = binary_array + .iter() + .map(|opt| opt.map(hex_strings::<_>)) + .collect(); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(opt.map(hex_strings::<_>)), + )), + _ => { + exec_err!( + "digest function should return binary value, but got: {:?}", + value.data_type() + ) + } + } +} + pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { if args.len() != 1 { return Err(DataFusionError::Internal( @@ -246,14 +291,14 @@ mod test { fn test_dictionary_hex_binary() { let mut input_builder = BinaryDictionaryBuilder::::new(); input_builder.append_value("1"); - input_builder.append_value("1"); + input_builder.append_value("j"); input_builder.append_null(); input_builder.append_value("3"); let input = input_builder.finish(); let mut expected_builder = StringBuilder::new(); expected_builder.append_value("31"); - expected_builder.append_value("31"); + expected_builder.append_value("6A"); expected_builder.append_null(); expected_builder.append_value("33"); let expected = expected_builder.finish(); From d25c37c96fb68f2261b31bf38cb7a7fccdf6e72a Mon Sep 17 00:00:00 2001 From: Xianjin Date: Fri, 21 Jun 2024 10:23:04 +0800 Subject: [PATCH 2/2] Address comments --- .../datafusion/expressions/scalar_funcs.rs | 3 +- .../scalar_funcs/hash_expressions.rs | 36 +++++++++++++++++-- .../expressions/scalar_funcs/hex.rs | 31 ++-------------- 3 files changed, 37 insertions(+), 33 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 7b0d40e83..e55425642 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -47,13 +47,14 @@ mod unhex; use unhex::spark_unhex; mod hex; -use hex::{spark_hex, wrap_digest_result_as_hex_string}; +use hex::spark_hex; mod chr; use chr::spark_chr; pub mod hash_expressions; // exposed for benchmark only +use hash_expressions::wrap_digest_result_as_hex_string; pub use hash_expressions::{spark_murmur3_hash, spark_xxhash64}; macro_rules! make_comet_scalar_udf { diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs index 707eba57e..67d728162 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs @@ -15,12 +15,15 @@ // specific language governing permissions and limitations // under the License. +use crate::execution::datafusion::expressions::scalar_funcs::hex::hex_strings; use crate::execution::datafusion::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes}; -use arrow_array::{ArrayRef, Int32Array, Int64Array}; -use datafusion_common::{internal_err, DataFusionError, ScalarValue}; -use datafusion_expr::ColumnarValue; +use arrow_array::{ArrayRef, Int32Array, Int64Array, StringArray}; +use datafusion_common::cast::as_binary_array; +use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; use std::sync::Arc; +/// Spark compatible murmur3 hash in vectorized execution fashion pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { let length = args.len(); let seed = &args[length - 1]; @@ -64,6 +67,7 @@ pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result Result { let length = args.len(); let seed = &args[length - 1]; @@ -106,3 +110,29 @@ pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result Result { + let value = digest(args)?; + match value { + ColumnarValue::Array(array) => { + let binary_array = as_binary_array(&array)?; + let string_array: StringArray = binary_array + .iter() + .map(|opt| opt.map(hex_strings::<_>)) + .collect(); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(opt.map(hex_strings::<_>)), + )), + _ => { + exec_err!( + "digest function should return binary value, but got: {:?}", + value.data_type() + ) + } + } +} diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index 3091eb323..5191e53fa 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -26,9 +26,8 @@ use arrow_schema::DataType; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{ cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, - exec_err, DataFusionError, ScalarValue, + exec_err, DataFusionError, }; -use datafusion_expr::ScalarFunctionImplementation; use std::fmt::Write; fn hex_int64(num: i64) -> String { @@ -53,7 +52,7 @@ fn hex_encode>(data: T, lower_case: bool) -> String { } #[inline(always)] -fn hex_strings>(data: T) -> String { +pub(super) fn hex_strings>(data: T) -> String { hex_encode(data, true) } @@ -63,32 +62,6 @@ fn hex_bytes>(bytes: T) -> Result { Ok(hex_string) } -pub(super) fn wrap_digest_result_as_hex_string( - args: &[ColumnarValue], - digest: ScalarFunctionImplementation, -) -> Result { - let value = digest(args)?; - match value { - ColumnarValue::Array(array) => { - let binary_array = as_binary_array(&array)?; - let string_array: StringArray = binary_array - .iter() - .map(|opt| opt.map(hex_strings::<_>)) - .collect(); - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( - ScalarValue::Utf8(opt.map(hex_strings::<_>)), - )), - _ => { - exec_err!( - "digest function should return binary value, but got: {:?}", - value.data_type() - ) - } - } -} - pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { if args.len() != 1 { return Err(DataFusionError::Internal(