feat: Support murmur3_hash and sha2 family hash functions (#226)
* feat: Support murmur3_hash and sha2 family hash functions

* address comments

* apply scalafix

* ensure crypto_expressions feature is enabled
advancedxy authored Apr 26, 2024
1 parent 49bf503 commit 8485558
Showing 4 changed files with 205 additions and 74 deletions.
4 changes: 2 additions & 2 deletions core/Cargo.toml
@@ -67,8 +67,8 @@ chrono = { version = "0.4", default-features = false, features = ["clock"] }
chrono-tz = { version = "0.8" }
paste = "1.0.14"
datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" }
datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions"] }
datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" }
datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions", "crypto_expressions"] }
datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["crypto_expressions"]}
datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", default-features = false, features = ["unicode_expressions"] }
unicode-segmentation = "^1.10.1"
once_cell = "1.18.0"
206 changes: 139 additions & 67 deletions core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -15,16 +15,23 @@
// specific language governing permissions and limitations
// under the License.

use std::{any::Any, cmp::min, fmt::Debug, str::FromStr, sync::Arc};
use std::{
any::Any,
cmp::min,
fmt::{Debug, Write},
str::FromStr,
sync::Arc,
};

use crate::execution::datafusion::spark_hash::create_hashes;
use arrow::{
array::{
ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
},
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array};
use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array, StringArray};
use arrow_schema::DataType;
use datafusion::{
execution::FunctionRegistry,
@@ -35,8 +42,8 @@ use datafusion::{
physical_plan::ColumnarValue,
};
use datafusion_common::{
cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
Result as DataFusionResult, ScalarValue,
cast::{as_binary_array, as_generic_string_array},
exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
};
use datafusion_physical_expr::{math_expressions, udf::ScalarUDF};
use num::{
@@ -45,89 +52,75 @@
};
use unicode_segmentation::UnicodeSegmentation;

macro_rules! make_comet_scalar_udf {
($name:expr, $func:ident, $data_type:ident) => {{
let scalar_func = CometScalarFunction::new(
$name.to_string(),
Signature::variadic_any(Volatility::Immutable),
$data_type.clone(),
Arc::new(move |args| $func(args, &$data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
}};
($name:expr, $func:expr, without $data_type:ident) => {{
let scalar_func = CometScalarFunction::new(
$name.to_string(),
Signature::variadic_any(Volatility::Immutable),
$data_type,
$func,
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
}};
}

/// Create a physical scalar function.
pub fn create_comet_physical_fun(
fun_name: &str,
data_type: DataType,
registry: &dyn FunctionRegistry,
) -> Result<ScalarFunctionDefinition, DataFusionError> {
let sha2_functions = ["sha224", "sha256", "sha384", "sha512"];
match fun_name {
"ceil" => {
let scalar_func = CometScalarFunction::new(
"ceil".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_ceil(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
make_comet_scalar_udf!("ceil", spark_ceil, data_type)
}
"floor" => {
let scalar_func = CometScalarFunction::new(
"floor".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_floor(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
make_comet_scalar_udf!("floor", spark_floor, data_type)
}
"rpad" => {
let scalar_func = CometScalarFunction::new(
"rpad".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(spark_rpad),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
let func = Arc::new(spark_rpad);
make_comet_scalar_udf!("rpad", func, without data_type)
}
"round" => {
let scalar_func = CometScalarFunction::new(
"round".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_round(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
make_comet_scalar_udf!("round", spark_round, data_type)
}
"unscaled_value" => {
let scalar_func = CometScalarFunction::new(
"unscaled_value".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(spark_unscaled_value),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
let func = Arc::new(spark_unscaled_value);
make_comet_scalar_udf!("unscaled_value", func, without data_type)
}
"make_decimal" => {
let scalar_func = CometScalarFunction::new(
"make_decimal".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_make_decimal(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
}
"decimal_div" => {
let scalar_func = CometScalarFunction::new(
"decimal_div".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_decimal_div(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
}
"murmur3_hash" => {
let func = Arc::new(spark_murmur3_hash);
make_comet_scalar_udf!("murmur3_hash", func, without data_type)
}
sha if sha2_functions.contains(&sha) => {
            // Spark requires a hex string as the result of sha2 functions, so we have to wrap
            // the result of the digest functions as a hex string
let func = registry.udf(sha)?;
let wrapped_func = Arc::new(move |args: &[ColumnarValue]| {
wrap_digest_result_as_hex_string(args, func.fun())
});
let spark_func_name = "spark".to_owned() + sha;
make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type)
}
_ => {
let fun = BuiltinScalarFunction::from_str(fun_name);
@@ -629,3 +622,82 @@ fn spark_decimal_div(
let result = result.with_data_type(DataType::Decimal128(p3, s3));
Ok(ColumnarValue::Array(Arc::new(result)))
}

fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
let length = args.len();
let seed = &args[length - 1];
match seed {
ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => {
// iterate over the arguments to find out the length of the array
let num_rows = args[0..args.len() - 1]
.iter()
.find_map(|arg| match arg {
ColumnarValue::Array(array) => Some(array.len()),
ColumnarValue::Scalar(_) => None,
})
.unwrap_or(1);
let mut hashes: Vec<u32> = vec![0_u32; num_rows];
hashes.fill(*seed as u32);
let arrays = args[0..args.len() - 1]
.iter()
.map(|arg| match arg {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => {
scalar.clone().to_array_of_size(num_rows).unwrap()
}
})
.collect::<Vec<ArrayRef>>();
create_hashes(&arrays, &mut hashes)?;
if num_rows == 1 {
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(
hashes[0] as i32,
))))
} else {
let hashes: Vec<i32> = hashes.into_iter().map(|x| x as i32).collect();
Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes))))
}
}
_ => {
internal_err!(
"The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.",
seed
)
}
}
}

#[inline]
fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
let mut s = String::with_capacity(data.as_ref().len() * 2);
for b in data.as_ref() {
// Writing to a string never errors, so we can unwrap here.
write!(&mut s, "{b:02x}").unwrap();
}
s
}

fn wrap_digest_result_as_hex_string(
args: &[ColumnarValue],
digest: ScalarFunctionImplementation,
) -> Result<ColumnarValue, DataFusionError> {
let value = digest(args)?;
match value {
ColumnarValue::Array(array) => {
let binary_array = as_binary_array(&array)?;
let string_array: StringArray = binary_array
.iter()
.map(|opt| opt.map(hex_encode::<_>))
.collect();
Ok(ColumnarValue::Array(Arc::new(string_array)))
}
ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar(
ScalarValue::Utf8(opt.map(hex_encode::<_>)),
)),
_ => {
exec_err!(
"digest function should return binary value, but got: {:?}",
value.data_type()
)
}
}
}
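
For quick reference, below is a minimal sketch (not part of the commit) of how the new pieces are meant to be exercised: the Int32 seed travels as the last argument to spark_murmur3_hash, matching what the Scala serde appends, and digest bytes are rendered as lowercase hex. The helper name is illustrative, and it assumes access to the module-private spark_murmur3_hash and hex_encode functions added above.

use std::sync::Arc;

use arrow_array::{Array, Int32Array, StringArray};
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};

// Hypothetical helper; it would have to live next to the (module-private)
// spark_murmur3_hash and hex_encode functions defined above.
fn murmur3_hash_example() -> Result<(), DataFusionError> {
    // One string column plus the Int32 seed appended as the last argument,
    // which is the calling convention the Scala serde produces.
    let col: Arc<dyn Array> = Arc::new(StringArray::from(vec![Some("Spark SQL "), None]));
    let args = [
        ColumnarValue::Array(col),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(42))),
    ];
    if let ColumnarValue::Array(hashes) = spark_murmur3_hash(&args)? {
        // One Int32 hash per input row, as `hash(col, 42)` would return in Spark.
        let hashes = hashes.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(hashes.len(), 2);
    }

    // Digest output is rendered as lowercase hex, the format Spark's sha2/md5 return.
    assert_eq!(hex_encode([0xab_u8, 0x01]), "ab01");
    Ok(())
}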
43 changes: 40 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1613,10 +1613,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
optExprWithInfo(optExpr, expr, castExpr)

case Md5(child) =>
val castExpr = Cast(child, StringType)
val childExpr = exprToProtoInternal(castExpr, inputs)
val childExpr = exprToProtoInternal(child, inputs)
val optExpr = scalarExprToProto("md5", childExpr)
optExprWithInfo(optExpr, expr, castExpr)
optExprWithInfo(optExpr, expr, child)

case OctetLength(child) =>
val castExpr = Cast(child, StringType)
@@ -1954,6 +1953,44 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
None
}

case Murmur3Hash(children, seed) =>
val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
if (firstUnSupportedInput.isDefined) {
withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
return None
}
val exprs = children.map(exprToProtoInternal(_, inputs))
val seedBuilder = ExprOuterClass.Literal
.newBuilder()
.setDatatype(serializeDataType(IntegerType).get)
.setIntVal(seed)
val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
// the seed is put at the end of the arguments
scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)

case Sha2(left, numBits) =>
if (!numBits.foldable) {
withInfo(expr, "non literal numBits is not supported")
return None
}
        // It's possible for Spark to dynamically compute the number of bits from the input
        // expression; however, DataFusion does not support that yet.
val childExpr = exprToProtoInternal(left, inputs)
val bits = numBits.eval().asInstanceOf[Int]
val algorithm = bits match {
case 224 => "sha224"
case 256 | 0 => "sha256"
case 384 => "sha384"
case 512 => "sha512"
case _ =>
null
}
if (algorithm == null) {
exprToProtoInternal(Literal(null, StringType), inputs)
} else {
scalarExprToProtoWithReturnType(algorithm, StringType, childExpr)
}

case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
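
Restated outside the diff, the bit-width dispatch above boils down to the following sketch (illustrative Rust, not part of the commit); any unsupported width is serialized as a NULL string literal, matching Spark's own sha2 behaviour of returning NULL for unknown bit lengths.

// Illustrative mapping from Spark's sha2(expr, numBits) to the DataFusion
// digest function the serde dispatches to; names are not part of the commit.
fn sha2_function_name(num_bits: i32) -> Option<&'static str> {
    match num_bits {
        224 => Some("sha224"),
        // Spark treats 0 as an alias for 256.
        0 | 256 => Some("sha256"),
        384 => Some("sha384"),
        512 => Some("sha512"),
        // Anything else becomes Literal(null, StringType) on the Scala side,
        // so the query evaluates to NULL, just as it does in Spark.
        _ => None,
    }
}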
26 changes: 24 additions & 2 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -981,8 +981,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

// TODO: enable this when we add md5 function to Comet
ignore("md5") {
test("md5") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "test"
@@ -1405,4 +1404,27 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("hash functions") {
Seq(true, false).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "test"
withTable(table) {
sql(s"create table $table(col string, a int, b float) using parquet")
sql(s"""
|insert into $table values
|('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
|, ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999)
|""".stripMargin)
checkSparkAnswerAndOperator("""
|select
|md5(col), md5(cast(a as string)), md5(cast(b as string)),
|hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col),
|sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128)
|from test
|""".stripMargin)
}
}
}
}

}
