From c208c4ceced45a3120567bcb636acb4a960b391e Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 7 Feb 2025 11:38:25 -0600 Subject: [PATCH 1/4] add math functions to spark --- Cargo.lock | 1 + Cargo.toml | 2 + src/daft-connect/Cargo.toml | 4 +- src/daft-connect/src/functions.rs | 43 +++++- src/daft-connect/src/functions/core.rs | 22 +-- src/daft-connect/src/functions/math.rs | 162 +++++++++++++++++++++++ src/daft-functions/src/numeric/round.rs | 10 +- src/daft-functions/src/python/numeric.rs | 2 +- src/daft-sql/src/modules/numeric.rs | 2 +- 9 files changed, 223 insertions(+), 25 deletions(-) create mode 100644 src/daft-connect/src/functions/math.rs diff --git a/Cargo.lock b/Cargo.lock index fd5cba331a..e33b8894f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2018,6 +2018,7 @@ dependencies = [ "daft-context", "daft-core", "daft-dsl", + "daft-functions", "daft-logical-plan", "daft-micropartition", "daft-scan", diff --git a/Cargo.toml b/Cargo.toml index 228ac4f1b3..4288d3bd54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -202,6 +202,8 @@ common-runtime = {path = "src/common/runtime", default-features = false} daft-context = {path = "src/daft-context"} daft-core = {path = "src/daft-core"} daft-dsl = {path = "src/daft-dsl"} +daft-functions = {path = "src/daft-functions"} +daft-functions-json = {path = "src/daft-functions-json"} daft-hash = {path = "src/daft-hash"} daft-local-execution = {path = "src/daft-local-execution"} daft-logical-plan = {path = "src/daft-logical-plan"} diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index 46daa95f18..71c08e2513 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -10,6 +10,7 @@ daft-catalog = {path = "../daft-catalog", optional = true, features = ["python"] daft-context = {workspace = true, optional = true, features = ["python"]} daft-core = {workspace = true, optional = true, features = ["python"]} daft-dsl = {workspace = true, optional = true, features = ["python"]} +daft-functions = {workspace = true, optional = true, features = ["python"]} daft-logical-plan = {workspace = true, optional = true, features = ["python"]} daft-micropartition = {workspace = true, optional = true, features = [ "python" @@ -45,7 +46,8 @@ python = [ "dep:daft-sql", "dep:daft-table", "dep:daft-context", - "dep:daft-catalog" + "dep:daft-catalog", + "dep:daft-functions" ] [lints] diff --git a/src/daft-connect/src/functions.rs b/src/daft-connect/src/functions.rs index aa31d82cac..0cd809d8a6 100644 --- a/src/daft-connect/src/functions.rs +++ b/src/daft-connect/src/functions.rs @@ -1,14 +1,20 @@ use std::{collections::HashMap, sync::Arc}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; use once_cell::sync::Lazy; use spark_connect::Expression; -use crate::{error::ConnectResult, spark_analyzer::SparkAnalyzer}; +use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer}; mod core; +mod math; pub(crate) static CONNECT_FUNCTIONS: Lazy = Lazy::new(|| { let mut functions = SparkFunctions::new(); functions.register::(); + functions.register::(); functions }); @@ -53,3 +59,38 @@ pub trait FunctionModule { /// Register this module to the given [SparkFunctions] table. fn register(_parent: &mut SparkFunctions); } + +pub struct UnaryFunction(fn(ExprRef) -> ExprRef); +impl SparkFunction for T +where + T: ScalarUDF + 'static + Clone, +{ + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> ConnectResult { + let sf = ScalarFunction::new( + self.clone(), + args.iter() + .map(|arg| analyzer.to_daft_expr(arg)) + .collect::>>()?, + ); + Ok(sf.into()) + } +} +impl SparkFunction for UnaryFunction { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> ConnectResult { + match args { + [arg] => { + let arg = analyzer.to_daft_expr(arg)?; + Ok(self.0(arg)) + } + _ => invalid_argument_err!("requires exactly one argument"), + } + } +} diff --git a/src/daft-connect/src/functions/core.rs b/src/daft-connect/src/functions/core.rs index b4fdae6424..60501ff395 100644 --- a/src/daft-connect/src/functions/core.rs +++ b/src/daft-connect/src/functions/core.rs @@ -1,9 +1,9 @@ use daft_core::count_mode::CountMode; -use daft_dsl::{binary_op, col, ExprRef, Operator}; +use daft_dsl::{binary_op, col, Operator}; use daft_schema::dtype::DataType; use spark_connect::Expression; -use super::{FunctionModule, SparkFunction}; +use super::{FunctionModule, SparkFunction, UnaryFunction}; use crate::{ error::{ConnectError, ConnectResult}, invalid_argument_err, @@ -45,7 +45,7 @@ impl FunctionModule for CoreFunctions { } pub struct BinaryOpFunction(Operator); -pub struct UnaryFunction(fn(ExprRef) -> ExprRef); + pub struct CountFunction; impl SparkFunction for BinaryOpFunction { @@ -70,22 +70,6 @@ impl SparkFunction for BinaryOpFunction { } } -impl SparkFunction for UnaryFunction { - fn to_expr( - &self, - args: &[Expression], - analyzer: &SparkAnalyzer, - ) -> ConnectResult { - match args { - [arg] => { - let arg = analyzer.to_daft_expr(arg)?; - Ok(self.0(arg)) - } - _ => invalid_argument_err!("requires exactly one argument"), - } - } -} - impl SparkFunction for CountFunction { fn to_expr( &self, diff --git a/src/daft-connect/src/functions/math.rs b/src/daft-connect/src/functions/math.rs new file mode 100644 index 0000000000..be79c5f12f --- /dev/null +++ b/src/daft-connect/src/functions/math.rs @@ -0,0 +1,162 @@ +use daft_dsl::LiteralValue; +use daft_functions::numeric::{ + abs::Abs, + cbrt::Cbrt, + ceil::Ceil, + exp::Exp, + floor::Floor, + log::{log, Ln, Log10, Log2}, + round::round, + sqrt::Sqrt, + trigonometry::{ + ArcCos, ArcCosh, ArcSin, ArcSinh, ArcTan, ArcTanh, Atan2, Cos, Cot, Degrees, Radians, Sin, + Tan, + }, +}; +use spark_connect::Expression; + +use super::{FunctionModule, SparkFunction}; +use crate::{ + error::{ConnectError, ConnectResult}, + invalid_argument_err, + spark_analyzer::SparkAnalyzer, +}; + +// see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#math-functions +pub struct MathFunctions; + +impl FunctionModule for MathFunctions { + fn register(parent: &mut super::SparkFunctions) { + parent.add_fn("sqrt", Sqrt {}); + parent.add_fn("abs", Abs {}); + parent.add_fn("acos", ArcCos); + parent.add_fn("acosh", ArcCosh); + parent.add_fn("asin", ArcSin); + parent.add_fn("asinh", ArcSinh); + parent.add_fn("atan", ArcTan); + parent.add_fn("atanh", ArcTanh); + parent.add_fn("atan2", Atan2 {}); + // parent.add_fn("bin", todo!()); + parent.add_fn("cbrt", Cbrt {}); + parent.add_fn("ceil", Ceil {}); + parent.add_fn("ceiling", Ceil {}); + // parent.add_fn("conv", todo!()); + parent.add_fn("cos", Cos {}); + // parent.add_fn("cosh", Cosh{}); + parent.add_fn("cot", Cot {}); + // parent.add_fn("csc", Csc{}); + // parent.add_fn("e", E{}); + parent.add_fn("exp", Exp {}); + // parent.add_fn("expm1", Expm1{}); + // parent.add_fn("factorial", Factorial{}); + parent.add_fn("floor", Floor {}); + // parent.add_fn("hex", Hex{}); + // parent.add_fn("unhex", UnHex{}); + // parent.add_fn("hypot", Hypot{}); + parent.add_fn("ln", Ln {}); + parent.add_fn("log", LogFunction); + parent.add_fn("log10", Log10 {}); + // parent.add_fn("log1p", Log{}); + parent.add_fn("log2", Log2 {}); + // parent.add_fn("negate", Negate{}) + // parent.add_fn("negative", Negative{}) + // parent.add_fn("pi", Pi{}) + // parent.add_fn("pmod", Pmod{}) + // parent.add_fn("positive", Positive{}) + // parent.add_fn("pow", Pow{}) + // parent.add_fn("power", Pow{}) + // parent.add_fn("rint", Rint{}) + parent.add_fn("round", RoundFunction); + // parent.add_fn("bround", BRound{}) + // parent.add_fn("sec", Sec{}) + // parent.add_fn("shiftleft", ShiftLeft{}) + // parent.add_fn("shiftright", ShiftRight{}) + // parent.add_fn("sign", Sign{}) + // parent.add_fn("signum", Signum{}) + parent.add_fn("sin", Sin {}); + // parent.add_fn("sinh", Sinh{}) + parent.add_fn("tan", Tan {}); + // parent.add_fn("tanh", Tanh{}) + // parent.add_fn("toDegrees", ToDegrees{}) + // parent.add_fn("try_add", TryAdd{}) + // parent.add_fn("try_avg", TryAvg{}) + // parent.add_fn("try_divide", TryDivide{}) + // parent.add_fn("try_multiply", TryMultiply{}) + // parent.add_fn("try_subtract", TrySubtract{}) + // parent.add_fn("try_sum", TrySum{}) + // parent.add_fn("try_to_binary", TryToBinary{}) + // parent.add_fn("try_to_number", TryToNumber{}) + parent.add_fn("degrees", Degrees {}); + // parent.add_fn("toRadians", ToRadians{}) + parent.add_fn("radians", Radians {}); + // parent.add_fn("width_bucket", WidthBucket{}) + // + } +} + +struct LogFunction; +impl SparkFunction for LogFunction { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> ConnectResult { + let args = args + .iter() + .map(|arg| analyzer.to_daft_expr(arg)) + .collect::>>()?; + + let [input, base] = args.as_slice() else { + invalid_argument_err!("log requires exactly 2 arguments"); + }; + + let base = match base.as_ref().as_literal() { + Some(LiteralValue::Int8(i)) => *i as f64, + Some(LiteralValue::UInt8(u)) => *u as f64, + Some(LiteralValue::Int16(i)) => *i as f64, + Some(LiteralValue::UInt16(u)) => *u as f64, + Some(LiteralValue::Int32(i)) => *i as f64, + Some(LiteralValue::UInt32(u)) => *u as f64, + Some(LiteralValue::Int64(i)) => *i as f64, + Some(LiteralValue::UInt64(u)) => *u as f64, + Some(LiteralValue::Float64(f)) => *f, + _ => invalid_argument_err!("log base must be a number"), + }; + Ok(log(input.clone(), base)) + } +} + +struct RoundFunction; + +impl SparkFunction for RoundFunction { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> ConnectResult { + let mut args = args + .iter() + .map(|arg| analyzer.to_daft_expr(arg)) + .collect::>>()? + .into_iter(); + + let input = args + .next() + .ok_or_else(|| ConnectError::invalid_argument("Expected 1 input arg, got 0"))?; + + let scale = match args.next().as_ref().and_then(|e| e.as_literal()) { + Some(LiteralValue::Int8(i)) => Some(*i as i32), + Some(LiteralValue::UInt8(u)) => Some(*u as i32), + Some(LiteralValue::Int16(i)) => Some(*i as i32), + Some(LiteralValue::UInt16(u)) => Some(*u as i32), + Some(LiteralValue::Int32(i)) => Some(*i), + Some(LiteralValue::UInt32(u)) => Some(*u as i32), + Some(LiteralValue::Int64(i)) => Some(*i as i32), + Some(LiteralValue::UInt64(u)) => Some(*u as i32), + None => None, + _ => invalid_argument_err!("round precision must be an integer"), + }; + + Ok(round(input, scale)) + } +} diff --git a/src/daft-functions/src/numeric/round.rs b/src/daft-functions/src/numeric/round.rs index 73ac6def0a..1c59448fc6 100644 --- a/src/daft-functions/src/numeric/round.rs +++ b/src/daft-functions/src/numeric/round.rs @@ -36,6 +36,12 @@ impl ScalarUDF for Round { } #[must_use] -pub fn round(input: ExprRef, decimal: i32) -> ExprRef { - ScalarFunction::new(Round { decimal }, vec![input]).into() +pub fn round(input: ExprRef, decimal: Option) -> ExprRef { + ScalarFunction::new( + Round { + decimal: decimal.unwrap_or_default(), + }, + vec![input], + ) + .into() } diff --git a/src/daft-functions/src/python/numeric.rs b/src/daft-functions/src/python/numeric.rs index f3cb7e940a..d8b5ba19ad 100644 --- a/src/daft-functions/src/python/numeric.rs +++ b/src/daft-functions/src/python/numeric.rs @@ -34,5 +34,5 @@ pub fn round(expr: PyExpr, decimal: i32) -> PyResult { "decimal can not be negative: {decimal}" ))); } - Ok(crate::numeric::round::round(expr.into(), decimal).into()) + Ok(crate::numeric::round::round(expr.into(), Some(decimal)).into()) } diff --git a/src/daft-sql/src/modules/numeric.rs b/src/daft-sql/src/modules/numeric.rs index 9a4914e6c6..ab8b41a529 100644 --- a/src/daft-sql/src/modules/numeric.rs +++ b/src/daft-sql/src/modules/numeric.rs @@ -187,7 +187,7 @@ fn to_expr(expr: &SQLNumericExpr, args: &[ExprRef]) -> SQLPlannerResult Some(LiteralValue::UInt64(u)) => *u as i32, _ => invalid_operation_err!("round precision must be an integer"), }; - Ok(round(args[0].clone(), precision)) + Ok(round(args[0].clone(), Some(precision))) } SQLNumericExpr::Clip => { ensure!(args.len() == 3, "clip takes exactly three arguments"); From 97afea8b6cbc7ec31799c97fe95efba265854632 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 7 Feb 2025 12:15:17 -0600 Subject: [PATCH 2/4] add string functions to connect --- src/daft-connect/src/functions.rs | 13 ++ src/daft-connect/src/functions/math.rs | 78 ++++++------ src/daft-connect/src/functions/string.rs | 147 +++++++++++++++++++++++ 3 files changed, 199 insertions(+), 39 deletions(-) create mode 100644 src/daft-connect/src/functions/string.rs diff --git a/src/daft-connect/src/functions.rs b/src/daft-connect/src/functions.rs index 0cd809d8a6..24dbb67b9a 100644 --- a/src/daft-connect/src/functions.rs +++ b/src/daft-connect/src/functions.rs @@ -10,11 +10,13 @@ use spark_connect::Expression; use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer}; mod core; mod math; +mod string; pub(crate) static CONNECT_FUNCTIONS: Lazy = Lazy::new(|| { let mut functions = SparkFunctions::new(); functions.register::(); functions.register::(); + functions.register::(); functions }); @@ -94,3 +96,14 @@ impl SparkFunction for UnaryFunction { } } } + +struct Todo; +impl SparkFunction for Todo { + fn to_expr( + &self, + _args: &[Expression], + _analyzer: &SparkAnalyzer, + ) -> ConnectResult { + invalid_argument_err!("Function not implemented") + } +} diff --git a/src/daft-connect/src/functions/math.rs b/src/daft-connect/src/functions/math.rs index be79c5f12f..503391def0 100644 --- a/src/daft-connect/src/functions/math.rs +++ b/src/daft-connect/src/functions/math.rs @@ -15,7 +15,7 @@ use daft_functions::numeric::{ }; use spark_connect::Expression; -use super::{FunctionModule, SparkFunction}; +use super::{FunctionModule, SparkFunction, Todo}; use crate::{ error::{ConnectError, ConnectResult}, invalid_argument_err, @@ -36,60 +36,60 @@ impl FunctionModule for MathFunctions { parent.add_fn("atan", ArcTan); parent.add_fn("atanh", ArcTanh); parent.add_fn("atan2", Atan2 {}); - // parent.add_fn("bin", todo!()); + parent.add_fn("bin", Todo); parent.add_fn("cbrt", Cbrt {}); parent.add_fn("ceil", Ceil {}); parent.add_fn("ceiling", Ceil {}); - // parent.add_fn("conv", todo!()); + parent.add_fn("conv", Todo); parent.add_fn("cos", Cos {}); - // parent.add_fn("cosh", Cosh{}); + parent.add_fn("cosh", Todo); parent.add_fn("cot", Cot {}); - // parent.add_fn("csc", Csc{}); - // parent.add_fn("e", E{}); + parent.add_fn("csc", Todo); + parent.add_fn("e", Todo); parent.add_fn("exp", Exp {}); - // parent.add_fn("expm1", Expm1{}); - // parent.add_fn("factorial", Factorial{}); + parent.add_fn("expm1", Todo); + parent.add_fn("factorial", Todo); parent.add_fn("floor", Floor {}); - // parent.add_fn("hex", Hex{}); - // parent.add_fn("unhex", UnHex{}); - // parent.add_fn("hypot", Hypot{}); + parent.add_fn("hex", Todo); + parent.add_fn("unhex", Todo); + parent.add_fn("hypot", Todo); parent.add_fn("ln", Ln {}); parent.add_fn("log", LogFunction); parent.add_fn("log10", Log10 {}); - // parent.add_fn("log1p", Log{}); + parent.add_fn("log1p", Todo); parent.add_fn("log2", Log2 {}); - // parent.add_fn("negate", Negate{}) - // parent.add_fn("negative", Negative{}) - // parent.add_fn("pi", Pi{}) - // parent.add_fn("pmod", Pmod{}) - // parent.add_fn("positive", Positive{}) - // parent.add_fn("pow", Pow{}) - // parent.add_fn("power", Pow{}) - // parent.add_fn("rint", Rint{}) + parent.add_fn("negate", Todo); + parent.add_fn("negative", Todo); + parent.add_fn("pi", Todo); + parent.add_fn("pmod", Todo); + parent.add_fn("positive", Todo); + parent.add_fn("pow", Todo); + parent.add_fn("power", Todo); + parent.add_fn("rint", Todo); parent.add_fn("round", RoundFunction); - // parent.add_fn("bround", BRound{}) - // parent.add_fn("sec", Sec{}) - // parent.add_fn("shiftleft", ShiftLeft{}) - // parent.add_fn("shiftright", ShiftRight{}) - // parent.add_fn("sign", Sign{}) - // parent.add_fn("signum", Signum{}) + parent.add_fn("bround", Todo); + parent.add_fn("sec", Todo); + parent.add_fn("shiftleft", Todo); + parent.add_fn("shiftright", Todo); + parent.add_fn("sign", Todo); + parent.add_fn("signum", Todo); parent.add_fn("sin", Sin {}); - // parent.add_fn("sinh", Sinh{}) + parent.add_fn("sinh", Todo); parent.add_fn("tan", Tan {}); - // parent.add_fn("tanh", Tanh{}) - // parent.add_fn("toDegrees", ToDegrees{}) - // parent.add_fn("try_add", TryAdd{}) - // parent.add_fn("try_avg", TryAvg{}) - // parent.add_fn("try_divide", TryDivide{}) - // parent.add_fn("try_multiply", TryMultiply{}) - // parent.add_fn("try_subtract", TrySubtract{}) - // parent.add_fn("try_sum", TrySum{}) - // parent.add_fn("try_to_binary", TryToBinary{}) - // parent.add_fn("try_to_number", TryToNumber{}) + parent.add_fn("tanh", Todo); + parent.add_fn("toDegrees", Todo); + parent.add_fn("try_add", Todo); + parent.add_fn("try_avg", Todo); + parent.add_fn("try_divide", Todo); + parent.add_fn("try_multiply", Todo); + parent.add_fn("try_subtract", Todo); + parent.add_fn("try_sum", Todo); + parent.add_fn("try_to_binary", Todo); + parent.add_fn("try_to_number", Todo); parent.add_fn("degrees", Degrees {}); - // parent.add_fn("toRadians", ToRadians{}) + parent.add_fn("toRadians", Todo); parent.add_fn("radians", Radians {}); - // parent.add_fn("width_bucket", WidthBucket{}) + parent.add_fn("width_bucket", Todo); // } } diff --git a/src/daft-connect/src/functions/string.rs b/src/daft-connect/src/functions/string.rs new file mode 100644 index 0000000000..064b70fb96 --- /dev/null +++ b/src/daft-connect/src/functions/string.rs @@ -0,0 +1,147 @@ +use daft_dsl::LiteralValue; +use daft_functions::utf8::{ + extract, extract_all, Utf8Endswith, Utf8Ilike, Utf8Left, Utf8Length, Utf8LengthBytes, Utf8Like, + Utf8Lower, Utf8Lpad, Utf8Replace, Utf8Right, Utf8Rpad, Utf8Split, Utf8Startswith, Utf8Substr, + Utf8Upper, +}; +use spark_connect::Expression; + +use super::{FunctionModule, SparkFunction, Todo}; +use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer}; + +// see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#string-functions +pub struct StringFunctions; + +impl FunctionModule for StringFunctions { + fn register(parent: &mut super::SparkFunctions) { + parent.add_fn("ascii", Todo); + parent.add_fn("base64", Todo); + parent.add_fn("bit_length", Todo); + parent.add_fn("btrim", Todo); + parent.add_fn("char", Todo); + parent.add_fn("character_length", Utf8Length {}); + parent.add_fn("char_length", Utf8Length {}); + parent.add_fn("concat_ws", Todo); + parent.add_fn("contains", daft_functions::utf8::Utf8Contains {}); + parent.add_fn("decode", Todo); + parent.add_fn("elt", Todo); + parent.add_fn("encode", Utf8Endswith {}); + parent.add_fn("endswith", Todo); + parent.add_fn("find_in_set", Todo); + parent.add_fn("format_number", Todo); + parent.add_fn("format_string", Todo); + parent.add_fn("ilike", Utf8Ilike {}); + parent.add_fn("initcap", Todo); + parent.add_fn("instr", Todo); + parent.add_fn("lcase", Todo); + parent.add_fn("length", Utf8LengthBytes {}); + parent.add_fn("like", Utf8Like {}); + parent.add_fn("lower", Utf8Lower {}); + parent.add_fn("left", Utf8Left {}); + parent.add_fn("levenshtein", Todo); + parent.add_fn("locate", Todo); + parent.add_fn("lpad", Utf8Lpad {}); + parent.add_fn("ltrim", Todo); + parent.add_fn("mask", Todo); + parent.add_fn("octet_length", Todo); + parent.add_fn("parse_url", Todo); + parent.add_fn("position", Todo); + parent.add_fn("printf", Todo); + parent.add_fn("rlike", Todo); + parent.add_fn("regexp", Todo); + parent.add_fn("regexp_like", Todo); + parent.add_fn("regexp_count", Todo); + parent.add_fn("regexp_extract", RegexpExtract); + parent.add_fn("regexp_extract_all", RegexpExtractAll); + parent.add_fn("regexp_replace", Utf8Replace { regex: true }); + parent.add_fn("regexp_substr", Todo); + parent.add_fn("regexp_instr", Todo); + parent.add_fn("replace", Utf8Replace { regex: false }); + parent.add_fn("right", Utf8Right {}); + parent.add_fn("ucase", Todo); + parent.add_fn("unbase64", Todo); + parent.add_fn("rpad", Utf8Rpad {}); + parent.add_fn("repeat", Todo); + parent.add_fn("rtrim", Todo); + parent.add_fn("soundex", Todo); + parent.add_fn("split", Utf8Split { regex: false }); + parent.add_fn("split_part", Todo); + parent.add_fn("startswith", Utf8Startswith {}); + parent.add_fn("substr", Utf8Substr {}); + parent.add_fn("substring", Utf8Substr {}); + parent.add_fn("substring_index", Todo); + parent.add_fn("overlay", Todo); + parent.add_fn("sentences", Todo); + parent.add_fn("to_binary", Todo); + parent.add_fn("to_char", Todo); + parent.add_fn("to_number", Todo); + parent.add_fn("to_varchar", Todo); + parent.add_fn("translate", Todo); + parent.add_fn("trim", Todo); + parent.add_fn("upper", Utf8Upper {}); + parent.add_fn("url_decode", Todo); + parent.add_fn("url_encode", Todo); + } +} + +struct RegexpExtract; +impl SparkFunction for RegexpExtract { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> ConnectResult { + let args = args + .iter() + .map(|arg| analyzer.to_daft_expr(arg)) + .collect::>>()?; + + let [input, pattern, idx] = args.as_slice() else { + invalid_argument_err!("regexp_extract requires exactly 3 arguments"); + }; + + let idx = match idx.as_ref().as_literal() { + Some(LiteralValue::Int8(i)) => *i as usize, + Some(LiteralValue::UInt8(u)) => *u as usize, + Some(LiteralValue::Int16(i)) => *i as usize, + Some(LiteralValue::UInt16(u)) => *u as usize, + Some(LiteralValue::Int32(i)) => *i as usize, + Some(LiteralValue::UInt32(u)) => *u as usize, + Some(LiteralValue::Int64(i)) => *i as usize, + Some(LiteralValue::UInt64(u)) => *u as usize, + _ => invalid_argument_err!("regexp_extract index must be a number"), + }; + Ok(extract(input.clone(), pattern.clone(), idx)) + } +} + +struct RegexpExtractAll; +impl SparkFunction for RegexpExtractAll { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> ConnectResult { + let args = args + .iter() + .map(|arg| analyzer.to_daft_expr(arg)) + .collect::>>()?; + + let [input, pattern, idx] = args.as_slice() else { + invalid_argument_err!("regexp_extract requires exactly 3 arguments"); + }; + + let idx = match idx.as_ref().as_literal() { + Some(LiteralValue::Int8(i)) => *i as usize, + Some(LiteralValue::UInt8(u)) => *u as usize, + Some(LiteralValue::Int16(i)) => *i as usize, + Some(LiteralValue::UInt16(u)) => *u as usize, + Some(LiteralValue::Int32(i)) => *i as usize, + Some(LiteralValue::UInt32(u)) => *u as usize, + Some(LiteralValue::Int64(i)) => *i as usize, + Some(LiteralValue::UInt64(u)) => *u as usize, + _ => invalid_argument_err!("regexp_extract index must be a number"), + }; + Ok(extract_all(input.clone(), pattern.clone(), idx)) + } +} From 5f5845957016c31631ce1d6d6860f006cbe7b105 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 7 Feb 2025 12:32:28 -0600 Subject: [PATCH 3/4] add "normal" functions --- src/daft-connect/src/functions.rs | 3 + src/daft-connect/src/functions/aggregate.rs | 46 ++++++++++++++ src/daft-connect/src/functions/core.rs | 70 +++++++++++++-------- src/daft-functions/src/float/mod.rs | 8 +-- 4 files changed, 96 insertions(+), 31 deletions(-) create mode 100644 src/daft-connect/src/functions/aggregate.rs diff --git a/src/daft-connect/src/functions.rs b/src/daft-connect/src/functions.rs index 24dbb67b9a..3738769c5c 100644 --- a/src/daft-connect/src/functions.rs +++ b/src/daft-connect/src/functions.rs @@ -8,12 +8,14 @@ use once_cell::sync::Lazy; use spark_connect::Expression; use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer}; +mod aggregate; mod core; mod math; mod string; pub(crate) static CONNECT_FUNCTIONS: Lazy = Lazy::new(|| { let mut functions = SparkFunctions::new(); + functions.register::(); functions.register::(); functions.register::(); functions.register::(); @@ -81,6 +83,7 @@ where Ok(sf.into()) } } + impl SparkFunction for UnaryFunction { fn to_expr( &self, diff --git a/src/daft-connect/src/functions/aggregate.rs b/src/daft-connect/src/functions/aggregate.rs new file mode 100644 index 0000000000..6c45f270c8 --- /dev/null +++ b/src/daft-connect/src/functions/aggregate.rs @@ -0,0 +1,46 @@ +use daft_core::count_mode::CountMode; +use daft_dsl::col; +use daft_schema::dtype::DataType; +use spark_connect::Expression; + +use super::{FunctionModule, SparkFunction, UnaryFunction}; +use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer}; + +pub struct AggregateFunctions; + +impl FunctionModule for AggregateFunctions { + fn register(parent: &mut super::SparkFunctions) { + parent.add_fn("count", CountFunction); + parent.add_fn("mean", UnaryFunction(|arg| arg.mean())); + parent.add_fn("stddev", UnaryFunction(|arg| arg.stddev())); + parent.add_fn("min", UnaryFunction(|arg| arg.min())); + parent.add_fn("max", UnaryFunction(|arg| arg.max())); + parent.add_fn("sum", UnaryFunction(|arg| arg.sum())); + } +} + +struct CountFunction; +impl SparkFunction for CountFunction { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> ConnectResult { + match args { + [arg] => { + let arg = analyzer.to_daft_expr(arg)?; + + let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) { + col("*") + } else { + arg + }; + + let count = arg.count(CountMode::All).cast(&DataType::Int64); + + Ok(count) + } + _ => invalid_argument_err!("requires exactly one argument"), + } + } +} diff --git a/src/daft-connect/src/functions/core.rs b/src/daft-connect/src/functions/core.rs index 60501ff395..6140f54cea 100644 --- a/src/daft-connect/src/functions/core.rs +++ b/src/daft-connect/src/functions/core.rs @@ -1,9 +1,9 @@ -use daft_core::count_mode::CountMode; -use daft_dsl::{binary_op, col, Operator}; -use daft_schema::dtype::DataType; +use daft_dsl::{binary_op, Operator}; +use daft_functions::{coalesce::Coalesce, float::IsNan}; +use daft_sql::sql_expr; use spark_connect::Expression; -use super::{FunctionModule, SparkFunction, UnaryFunction}; +use super::{FunctionModule, SparkFunction, Todo, UnaryFunction}; use crate::{ error::{ConnectError, ConnectResult}, invalid_argument_err, @@ -32,22 +32,38 @@ impl FunctionModule for CoreFunctions { parent.add_fn("^", BinaryOpFunction(Operator::Xor)); parent.add_fn("<<", BinaryOpFunction(Operator::ShiftLeft)); parent.add_fn(">>", BinaryOpFunction(Operator::ShiftRight)); + + // Normal Functions + // https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#normal-functions + + parent.add_fn("coalesce", Coalesce {}); + parent.add_fn("input_file_name", Todo); + parent.add_fn("isnan", IsNan {}); + parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null())); + + parent.add_fn("monotically_increasing_id", Todo); + parent.add_fn("named_struct", Todo); + parent.add_fn("nanvl", Todo); + parent.add_fn("rand", Todo); + parent.add_fn("randn", Todo); + parent.add_fn("spark_partition_id", Todo); + parent.add_fn("when", Todo); + parent.add_fn("bitwise_not", Todo); + parent.add_fn("bitwiseNOT", Todo); + parent.add_fn("expr", SqlExpr); + parent.add_fn("greatest", Todo); + parent.add_fn("least", Todo); + + // parent.add_fn("isnan", UnaryFunction(|arg| arg.is_nan())); + parent.add_fn("isnotnull", UnaryFunction(|arg| arg.not_null())); parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null())); parent.add_fn("not", UnaryFunction(|arg| arg.not())); - parent.add_fn("sum", UnaryFunction(|arg| arg.sum())); - parent.add_fn("mean", UnaryFunction(|arg| arg.mean())); - parent.add_fn("stddev", UnaryFunction(|arg| arg.stddev())); - parent.add_fn("min", UnaryFunction(|arg| arg.min())); - parent.add_fn("max", UnaryFunction(|arg| arg.max())); - parent.add_fn("count", CountFunction); } } pub struct BinaryOpFunction(Operator); -pub struct CountFunction; - impl SparkFunction for BinaryOpFunction { fn to_expr( &self, @@ -70,27 +86,27 @@ impl SparkFunction for BinaryOpFunction { } } -impl SparkFunction for CountFunction { +struct SqlExpr; +impl SparkFunction for SqlExpr { fn to_expr( &self, args: &[Expression], analyzer: &SparkAnalyzer, ) -> ConnectResult { - match args { - [arg] => { - let arg = analyzer.to_daft_expr(arg)?; - - let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) { - col("*") - } else { - arg - }; + let args = args + .iter() + .map(|arg| analyzer.to_daft_expr(arg)) + .collect::>>()?; - let count = arg.count(CountMode::All).cast(&DataType::Int64); + let [sql] = args.as_slice() else { + invalid_argument_err!("expr requires exactly 1 argument"); + }; - Ok(count) - } - _ => invalid_argument_err!("requires exactly one argument"), - } + let sql = sql + .as_ref() + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| ConnectError::invalid_argument("expr argument must be a string"))?; + Ok(sql_expr(sql)?) } } diff --git a/src/daft-functions/src/float/mod.rs b/src/daft-functions/src/float/mod.rs index fb29e58db5..ac1e2ad167 100644 --- a/src/daft-functions/src/float/mod.rs +++ b/src/daft-functions/src/float/mod.rs @@ -3,7 +3,7 @@ mod is_inf; mod is_nan; mod not_nan; -pub use fill_nan::fill_nan; -pub use is_inf::is_inf; -pub use is_nan::is_nan; -pub use not_nan::not_nan; +pub use fill_nan::{fill_nan, FillNan}; +pub use is_inf::{is_inf, IsInf}; +pub use is_nan::{is_nan, IsNan}; +pub use not_nan::{not_nan, NotNan}; From cc826083f0492032c88d5d534fc1cc5de6a60fc1 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Mon, 10 Feb 2025 11:25:46 -0600 Subject: [PATCH 4/4] add partition transform functions --- src/daft-connect/src/functions.rs | 6 ++- src/daft-connect/src/functions/aggregate.rs | 1 + .../src/functions/partition_transform.rs | 48 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 src/daft-connect/src/functions/partition_transform.rs diff --git a/src/daft-connect/src/functions.rs b/src/daft-connect/src/functions.rs index 3738769c5c..444d67faae 100644 --- a/src/daft-connect/src/functions.rs +++ b/src/daft-connect/src/functions.rs @@ -5,12 +5,14 @@ use daft_dsl::{ ExprRef, }; use once_cell::sync::Lazy; +use partition_transform::PartitionTransformFunctions; use spark_connect::Expression; use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer}; mod aggregate; mod core; mod math; +mod partition_transform; mod string; pub(crate) static CONNECT_FUNCTIONS: Lazy = Lazy::new(|| { @@ -18,6 +20,7 @@ pub(crate) static CONNECT_FUNCTIONS: Lazy = Lazy::new(|| { functions.register::(); functions.register::(); functions.register::(); + functions.register::(); functions.register::(); functions }); @@ -64,7 +67,8 @@ pub trait FunctionModule { fn register(_parent: &mut SparkFunctions); } -pub struct UnaryFunction(fn(ExprRef) -> ExprRef); +struct UnaryFunction(fn(ExprRef) -> ExprRef); + impl SparkFunction for T where T: ScalarUDF + 'static + Clone, diff --git a/src/daft-connect/src/functions/aggregate.rs b/src/daft-connect/src/functions/aggregate.rs index 6c45f270c8..e29e6262bd 100644 --- a/src/daft-connect/src/functions/aggregate.rs +++ b/src/daft-connect/src/functions/aggregate.rs @@ -20,6 +20,7 @@ impl FunctionModule for AggregateFunctions { } struct CountFunction; + impl SparkFunction for CountFunction { fn to_expr( &self, diff --git a/src/daft-connect/src/functions/partition_transform.rs b/src/daft-connect/src/functions/partition_transform.rs new file mode 100644 index 0000000000..d78e421794 --- /dev/null +++ b/src/daft-connect/src/functions/partition_transform.rs @@ -0,0 +1,48 @@ +use daft_dsl::functions::partitioning; +use spark_connect::Expression; + +use super::{FunctionModule, SparkFunction, UnaryFunction}; +use crate::{ + error::{ConnectError, ConnectResult}, + invalid_argument_err, + spark_analyzer::SparkAnalyzer, +}; + +pub struct PartitionTransformFunctions; + +impl FunctionModule for PartitionTransformFunctions { + fn register(parent: &mut super::SparkFunctions) { + parent.add_fn("years", UnaryFunction(partitioning::years)); + parent.add_fn("months", UnaryFunction(partitioning::months)); + parent.add_fn("days", UnaryFunction(partitioning::days)); + parent.add_fn("hours", UnaryFunction(partitioning::hours)); + parent.add_fn("bucket", BucketFunction); + } +} + +struct BucketFunction; + +impl SparkFunction for BucketFunction { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> ConnectResult { + match args { + [n_buckets, arg] => { + let n_buckets = analyzer.to_daft_expr(n_buckets)?; + let arg = analyzer.to_daft_expr(arg)?; + + let n_buckets = n_buckets + .as_literal() + .and_then(|lit| lit.as_i32()) + .ok_or_else(|| { + ConnectError::invalid_argument("first argument must be an integer") + })?; + + Ok(partitioning::iceberg_bucket(arg, n_buckets)) + } + _ => invalid_argument_err!("requires exactly two arguments"), + } + } +}