feat: add functions to daft-connect (#3780)
Adds the following:
- all of Spark's [math functions](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#math-functions) that we currently have implemented
- all of Spark's [string functions](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#string-functions) that we currently have implemented
- all of Spark's [normal functions](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#normal-functions) that we currently have implemented
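
For context, Spark Connect transmits functions by name; the server resolves each name to a builder for the equivalent Daft expression. The following is a minimal, self-contained sketch of that dispatch idea, not code from this PR: every type and name in it is an illustrative stand-in (the real registry is `CONNECT_FUNCTIONS` in `src/daft-connect/src/functions.rs`, shown below).

```rust
use std::collections::HashMap;

fn main() {
    // Name -> expression-builder table, mirroring the registry this PR fills in.
    let mut registry: HashMap<&str, fn(&str) -> String> = HashMap::new();
    registry.insert("sqrt", |c| format!("sqrt({c})"));
    registry.insert("lower", |c| format!("lower({c})"));

    // e.g. PySpark's F.sqrt("x") arrives as the name "sqrt" plus its argument;
    // the server looks the name up and builds the Daft-side expression.
    let daft_expr = registry["sqrt"]("col(x)");
    assert_eq!(daft_expr, "sqrt(col(x))");
}
```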
universalmind303 authored and jessie-young committed Feb 14, 2025
1 parent b0face1 commit 79f642a
Showing 13 changed files with 529 additions and 53 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
@@ -205,6 +205,8 @@ common-runtime = {path = "src/common/runtime", default-features = false}
daft-context = {path = "src/daft-context"}
daft-core = {path = "src/daft-core"}
daft-dsl = {path = "src/daft-dsl"}
daft-functions = {path = "src/daft-functions"}
daft-functions-json = {path = "src/daft-functions-json"}
daft-hash = {path = "src/daft-hash"}
daft-local-execution = {path = "src/daft-local-execution"}
daft-logical-plan = {path = "src/daft-logical-plan"}
4 changes: 3 additions & 1 deletion src/daft-connect/Cargo.toml
@@ -10,6 +10,7 @@ daft-catalog = {path = "../daft-catalog", optional = true, features = ["python"]
daft-context = {workspace = true, optional = true, features = ["python"]}
daft-core = {workspace = true, optional = true, features = ["python"]}
daft-dsl = {workspace = true, optional = true, features = ["python"]}
daft-functions = {workspace = true, optional = true, features = ["python"]}
daft-logical-plan = {workspace = true, optional = true, features = ["python"]}
daft-micropartition = {workspace = true, optional = true, features = [
"python"
@@ -47,7 +48,8 @@ python = [
"dep:daft-sql",
"dep:daft-recordbatch",
"dep:daft-context",
"dep:daft-catalog"
"dep:daft-catalog",
"dep:daft-functions"
]

[lints]
63 changes: 62 additions & 1 deletion src/daft-connect/src/functions.rs
@@ -1,14 +1,27 @@
use std::{collections::HashMap, sync::Arc};

use daft_dsl::{
functions::{ScalarFunction, ScalarUDF},
ExprRef,
};
use once_cell::sync::Lazy;
use partition_transform::PartitionTransformFunctions;
use spark_connect::Expression;

use crate::{error::ConnectResult, spark_analyzer::SparkAnalyzer};
use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer};
mod aggregate;
mod core;
mod math;
mod partition_transform;
mod string;

pub(crate) static CONNECT_FUNCTIONS: Lazy<SparkFunctions> = Lazy::new(|| {
let mut functions = SparkFunctions::new();
functions.register::<aggregate::AggregateFunctions>();
functions.register::<core::CoreFunctions>();
functions.register::<math::MathFunctions>();
functions.register::<PartitionTransformFunctions>();
functions.register::<string::StringFunctions>();
functions
});

@@ -53,3 +66,51 @@ pub trait FunctionModule {
/// Register this module to the given [SparkFunctions] table.
fn register(_parent: &mut SparkFunctions);
}

struct UnaryFunction(fn(ExprRef) -> ExprRef);

impl<T> SparkFunction for T
where
T: ScalarUDF + 'static + Clone,
{
fn to_expr(
&self,
args: &[Expression],
analyzer: &SparkAnalyzer,
) -> ConnectResult<daft_dsl::ExprRef> {
let sf = ScalarFunction::new(
self.clone(),
args.iter()
.map(|arg| analyzer.to_daft_expr(arg))
.collect::<ConnectResult<Vec<_>>>()?,
);
Ok(sf.into())
}
}

impl SparkFunction for UnaryFunction {
fn to_expr(
&self,
args: &[Expression],
analyzer: &SparkAnalyzer,
) -> ConnectResult<daft_dsl::ExprRef> {
match args {
[arg] => {
let arg = analyzer.to_daft_expr(arg)?;
Ok(self.0(arg))
}
_ => invalid_argument_err!("requires exactly one argument"),
}
}
}

struct Todo;
impl SparkFunction for Todo {
fn to_expr(
&self,
_args: &[Expression],
_analyzer: &SparkAnalyzer,
) -> ConnectResult<daft_dsl::ExprRef> {
invalid_argument_err!("Function not implemented")
}
}
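
The blanket `impl<T> SparkFunction for T where T: ScalarUDF` above is what allows existing Daft scalar UDFs (such as `IsNan` and `Coalesce`, registered in `core.rs` below) to be passed straight to `add_fn` without a per-function adapter. Here is a self-contained sketch of that pattern using stand-in traits, not daft's real ones:

```rust
// Stand-ins for daft_dsl::functions::ScalarUDF and the SparkFunction trait.
trait ScalarUdfLike {
    fn name(&self) -> &'static str;
}

trait SparkFnLike {
    fn to_expr(&self, args: &[String]) -> String;
}

// Blanket impl: every UDF type gets the Spark-facing trait for free.
impl<T: ScalarUdfLike> SparkFnLike for T {
    fn to_expr(&self, args: &[String]) -> String {
        format!("{}({})", self.name(), args.join(", "))
    }
}

struct IsNanUdf; // stand-in for daft_functions::float::IsNan
impl ScalarUdfLike for IsNanUdf {
    fn name(&self) -> &'static str {
        "is_nan"
    }
}

fn main() {
    // IsNanUdf never implemented SparkFnLike itself; the blanket impl covers it.
    assert_eq!(IsNanUdf.to_expr(&["col(x)".into()]), "is_nan(col(x))");
}
```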
47 changes: 47 additions & 0 deletions src/daft-connect/src/functions/aggregate.rs
@@ -0,0 +1,47 @@
use daft_core::count_mode::CountMode;
use daft_dsl::col;
use daft_schema::dtype::DataType;
use spark_connect::Expression;

use super::{FunctionModule, SparkFunction, UnaryFunction};
use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer};

pub struct AggregateFunctions;

impl FunctionModule for AggregateFunctions {
fn register(parent: &mut super::SparkFunctions) {
parent.add_fn("count", CountFunction);
parent.add_fn("mean", UnaryFunction(|arg| arg.mean()));
parent.add_fn("stddev", UnaryFunction(|arg| arg.stddev()));
parent.add_fn("min", UnaryFunction(|arg| arg.min()));
parent.add_fn("max", UnaryFunction(|arg| arg.max()));
parent.add_fn("sum", UnaryFunction(|arg| arg.sum()));
}
}

struct CountFunction;

impl SparkFunction for CountFunction {
fn to_expr(
&self,
args: &[Expression],
analyzer: &SparkAnalyzer,
) -> ConnectResult<daft_dsl::ExprRef> {
match args {
[arg] => {
let arg = analyzer.to_daft_expr(arg)?;

let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) {
col("*")
} else {
arg
};

let count = arg.count(CountMode::All).cast(&DataType::Int64);

Ok(count)
}
_ => invalid_argument_err!("requires exactly one argument"),
}
}
}
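
The literal check in `CountFunction` handles Spark's convention that `count(1)` means a row count: a literal `1` argument is replaced with `col("*")` before counting and casting to `Int64`. A runnable sketch of just that rewrite, with stand-in types:

```rust
// Stand-in for an argument expression; the real code checks
// `arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1)`.
#[derive(Debug, PartialEq)]
enum Arg {
    LitI32(i32),
    Col(String),
}

fn normalize_count_arg(arg: Arg) -> Arg {
    // Spark treats count(1) as a row count, so map the literal 1 to "*".
    if arg == Arg::LitI32(1) {
        Arg::Col("*".into())
    } else {
        arg
    }
}

fn main() {
    assert_eq!(normalize_count_arg(Arg::LitI32(1)), Arg::Col("*".into()));
    let passthrough = normalize_count_arg(Arg::Col("x".into()));
    assert_eq!(passthrough, Arg::Col("x".into()));
}
```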
86 changes: 43 additions & 43 deletions src/daft-connect/src/functions/core.rs
@@ -1,9 +1,9 @@
use daft_core::count_mode::CountMode;
use daft_dsl::{binary_op, col, ExprRef, Operator};
use daft_schema::dtype::DataType;
use daft_dsl::{binary_op, Operator};
use daft_functions::{coalesce::Coalesce, float::IsNan};
use daft_sql::sql_expr;
use spark_connect::Expression;

use super::{FunctionModule, SparkFunction};
use super::{FunctionModule, SparkFunction, Todo, UnaryFunction};
use crate::{
error::{ConnectError, ConnectResult},
invalid_argument_err,
@@ -32,21 +32,37 @@ impl FunctionModule for CoreFunctions {
parent.add_fn("^", BinaryOpFunction(Operator::Xor));
parent.add_fn("<<", BinaryOpFunction(Operator::ShiftLeft));
parent.add_fn(">>", BinaryOpFunction(Operator::ShiftRight));

// Normal Functions
// https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#normal-functions

parent.add_fn("coalesce", Coalesce {});
parent.add_fn("input_file_name", Todo);
parent.add_fn("isnan", IsNan {});
parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null()));

parent.add_fn("monotically_increasing_id", Todo);
parent.add_fn("named_struct", Todo);
parent.add_fn("nanvl", Todo);
parent.add_fn("rand", Todo);
parent.add_fn("randn", Todo);
parent.add_fn("spark_partition_id", Todo);
parent.add_fn("when", Todo);
parent.add_fn("bitwise_not", Todo);
parent.add_fn("bitwiseNOT", Todo);
parent.add_fn("expr", SqlExpr);
parent.add_fn("greatest", Todo);
parent.add_fn("least", Todo);

// parent.add_fn("isnan", UnaryFunction(|arg| arg.is_nan()));

parent.add_fn("isnotnull", UnaryFunction(|arg| arg.not_null()));
parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null()));
parent.add_fn("not", UnaryFunction(|arg| arg.not()));
parent.add_fn("sum", UnaryFunction(|arg| arg.sum()));
parent.add_fn("mean", UnaryFunction(|arg| arg.mean()));
parent.add_fn("stddev", UnaryFunction(|arg| arg.stddev()));
parent.add_fn("min", UnaryFunction(|arg| arg.min()));
parent.add_fn("max", UnaryFunction(|arg| arg.max()));
parent.add_fn("count", CountFunction);
}
}

pub struct BinaryOpFunction(Operator);
pub struct UnaryFunction(fn(ExprRef) -> ExprRef);
pub struct CountFunction;

impl SparkFunction for BinaryOpFunction {
fn to_expr(
@@ -70,43 +86,27 @@ impl SparkFunction for BinaryOpFunction {
}
}

impl SparkFunction for UnaryFunction {
struct SqlExpr;
impl SparkFunction for SqlExpr {
fn to_expr(
&self,
args: &[Expression],
analyzer: &SparkAnalyzer,
) -> ConnectResult<daft_dsl::ExprRef> {
match args {
[arg] => {
let arg = analyzer.to_daft_expr(arg)?;
Ok(self.0(arg))
}
_ => invalid_argument_err!("requires exactly one argument"),
}
}
}

impl SparkFunction for CountFunction {
fn to_expr(
&self,
args: &[Expression],
analyzer: &SparkAnalyzer,
) -> ConnectResult<daft_dsl::ExprRef> {
match args {
[arg] => {
let arg = analyzer.to_daft_expr(arg)?;

let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) {
col("*")
} else {
arg
};
let args = args
.iter()
.map(|arg| analyzer.to_daft_expr(arg))
.collect::<ConnectResult<Vec<_>>>()?;

let count = arg.count(CountMode::All).cast(&DataType::Int64);
let [sql] = args.as_slice() else {
invalid_argument_err!("expr requires exactly 1 argument");
};

Ok(count)
}
_ => invalid_argument_err!("requires exactly one argument"),
}
let sql = sql
.as_ref()
.as_literal()
.and_then(|lit| lit.as_str())
.ok_or_else(|| ConnectError::invalid_argument("expr argument must be a string"))?;
Ok(sql_expr(sql)?)
}
}
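
`SqlExpr` backs Spark's `expr()` function: it accepts exactly one argument, requires it to be a string literal, and hands the string to `daft_sql::sql_expr` for parsing. A stand-in sketch of that argument check (the `parsed(...)` output is a placeholder for what the real parser returns):

```rust
// Stand-in for SqlExpr::to_expr: exactly one argument, and it must be a
// string literal, which then goes through the SQL parser.
fn eval_expr_fn(args: &[Option<&str>]) -> Result<String, String> {
    let [Some(sql)] = args else {
        return Err("expr requires exactly 1 string argument".into());
    };
    Ok(format!("parsed({sql})")) // placeholder for sql_expr(sql)?
}

fn main() {
    // e.g. PySpark's F.expr("a + 1") arrives as a string literal:
    assert_eq!(eval_expr_fn(&[Some("a + 1")]).unwrap(), "parsed(a + 1)");
    assert!(eval_expr_fn(&[None]).is_err());
}
```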
