Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sql): Adds url_download and url_upload to daft-sql #3690

Merged
merged 3 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,7 @@ def upload(
multi_thread = ExpressionUrlNamespace._should_use_multithreading_tokio_runtime()
# If the user specifies a single location via a string, we should upload to a single folder. Otherwise,
# if the user gave an expression, we assume that each row has a specific url to upload to.
# Consider moving the check for is_single_folder to a lower IR.
is_single_folder = isinstance(location, str)
io_config = ExpressionUrlNamespace._override_io_config_max_connections(max_connections, io_config)
return Expression._from_pyexpr(
Expand Down
2 changes: 0 additions & 2 deletions src/daft-functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ arrow2 = {workspace = true}
base64 = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-io-config = {path = "../common/io-config", default-features = false}
common-runtime = {path = "../common/runtime", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-dsl = {path = "../daft-dsl", default-features = false}
Expand All @@ -25,7 +24,6 @@ snafu.workspace = true
[features]
python = [
"common-error/python",
"common-io-config/python",
"daft-core/python",
"daft-dsl/python",
"daft-image/python",
Expand Down
18 changes: 8 additions & 10 deletions src/daft-functions/src/python/uri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use daft_dsl::python::PyExpr;
use daft_io::python::IOConfig;
use pyo3::{exceptions::PyValueError, pyfunction, PyResult};

use crate::uri::{self, download::UrlDownloadArgs, upload::UrlUploadArgs};

#[pyfunction]
pub fn url_download(
expr: PyExpr,
Expand All @@ -15,15 +17,13 @@ pub fn url_download(
"max_connections must be positive and non_zero: {max_connections}"
)));
}

Ok(crate::uri::download(
expr.into(),
let args = UrlDownloadArgs::new(
max_connections as usize,
raise_error_on_failure,
multi_thread,
Some(config.config),
)
.into())
);
Ok(uri::download(expr.into(), Some(args)).into())
}

#[pyfunction(signature = (
Expand All @@ -49,14 +49,12 @@ pub fn url_upload(
"max_connections must be positive and non_zero: {max_connections}"
)));
}
Ok(crate::uri::upload(
expr.into(),
folder_location.into(),
let args = UrlUploadArgs::new(
max_connections as usize,
raise_error_on_failure,
multi_thread,
is_single_folder,
io_config.map(|io_config| io_config.config),
)
.into())
);
Ok(uri::upload(expr.into(), folder_location.into(), Some(args)).into())
}
52 changes: 44 additions & 8 deletions src/daft-functions/src/uri/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,52 @@ use snafu::prelude::*;

use crate::InvalidArgumentSnafu;

/// Container for the keyword arguments of `url_download`
/// ex:
/// ```text
/// url_decode(input)
/// url_decode(input, max_connections=32)
/// url_decode(input, on_error='raise')
/// url_decode(input, on_error='null')
/// url_decode(input, max_connections=32, on_error='raise')
/// ```
#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
pub(super) struct DownloadFunction {
pub(super) max_connections: usize,
pub(super) raise_error_on_failure: bool,
pub(super) multi_thread: bool,
pub(super) config: Arc<IOConfig>,
pub struct UrlDownloadArgs {
pub max_connections: usize,
pub raise_error_on_failure: bool,
pub multi_thread: bool,
pub io_config: Arc<IOConfig>,
}

impl UrlDownloadArgs {
pub fn new(
max_connections: usize,
raise_error_on_failure: bool,
multi_thread: bool,
io_config: Option<IOConfig>,
) -> Self {
Self {
max_connections,
raise_error_on_failure,
multi_thread,
io_config: io_config.unwrap_or_default().into(),
}
}
}

impl Default for UrlDownloadArgs {
fn default() -> Self {
Self {
max_connections: 32,
raise_error_on_failure: true,
multi_thread: true,
io_config: IOConfig::default().into(),
}
}
}

#[typetag::serde]
impl ScalarUDF for DownloadFunction {
impl ScalarUDF for UrlDownloadArgs {
fn as_any(&self) -> &dyn std::any::Any {
self
}
Expand All @@ -34,7 +70,7 @@ impl ScalarUDF for DownloadFunction {
max_connections,
raise_error_on_failure,
multi_thread,
config,
io_config,
} = self;

match inputs {
Expand All @@ -47,7 +83,7 @@ impl ScalarUDF for DownloadFunction {
*max_connections,
*raise_error_on_failure,
*multi_thread,
config.clone(),
io_config.clone(),
Some(io_stats),
)?;
Ok(result.into_series())
Expand Down
52 changes: 10 additions & 42 deletions src/daft-functions/src/uri/mod.rs
Original file line number Diff line number Diff line change
@@ -1,50 +1,18 @@
mod download;
mod upload;
pub mod download;
pub mod upload;

use common_io_config::IOConfig;
use daft_dsl::{functions::ScalarFunction, ExprRef};
use download::DownloadFunction;
use upload::UploadFunction;
use download::UrlDownloadArgs;
use upload::UrlUploadArgs;

/// Creates a `url_download` ExprRef from the positional and optional named arguments.
#[must_use]
pub fn download(
input: ExprRef,
max_connections: usize,
raise_error_on_failure: bool,
multi_thread: bool,
config: Option<IOConfig>,
) -> ExprRef {
ScalarFunction::new(
DownloadFunction {
max_connections,
raise_error_on_failure,
multi_thread,
config: config.unwrap_or_default().into(),
},
vec![input],
)
.into()
pub fn download(input: ExprRef, args: Option<UrlDownloadArgs>) -> ExprRef {
ScalarFunction::new(args.unwrap_or_default(), vec![input]).into()
}

/// Creates a `url_upload` ExprRef from the positional and optional named arguments.
#[must_use]
pub fn upload(
input: ExprRef,
location: ExprRef,
max_connections: usize,
raise_error_on_failure: bool,
multi_thread: bool,
is_single_folder: bool,
config: Option<IOConfig>,
) -> ExprRef {
ScalarFunction::new(
UploadFunction {
max_connections,
raise_error_on_failure,
multi_thread,
is_single_folder,
config: config.unwrap_or_default().into(),
},
vec![input, location],
)
.into()
pub fn upload(input: ExprRef, location: ExprRef, args: Option<UrlUploadArgs>) -> ExprRef {
ScalarFunction::new(args.unwrap_or_default(), vec![input, location]).into()
}
48 changes: 39 additions & 9 deletions src/daft-functions/src/uri/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,46 @@ use futures::{StreamExt, TryStreamExt};
use serde::Serialize;

#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
pub(super) struct UploadFunction {
pub(super) max_connections: usize,
pub(super) raise_error_on_failure: bool,
pub(super) multi_thread: bool,
pub(super) is_single_folder: bool,
pub(super) config: Arc<IOConfig>,
pub struct UrlUploadArgs {
pub max_connections: usize,
pub raise_error_on_failure: bool,
pub multi_thread: bool,
pub is_single_folder: bool,
pub io_config: Arc<IOConfig>,
}

impl UrlUploadArgs {
pub fn new(
max_connections: usize,
raise_error_on_failure: bool,
multi_thread: bool,
is_single_folder: bool,
io_config: Option<IOConfig>,
) -> Self {
Self {
max_connections,
raise_error_on_failure,
multi_thread,
is_single_folder,
io_config: io_config.unwrap_or_default().into(),
}
}
}

impl Default for UrlUploadArgs {
fn default() -> Self {
Self {
max_connections: 32,
raise_error_on_failure: true,
multi_thread: true,
is_single_folder: false,
io_config: IOConfig::default().into(),
}
}
}

#[typetag::serde]
impl ScalarUDF for UploadFunction {
impl ScalarUDF for UrlUploadArgs {
fn as_any(&self) -> &dyn std::any::Any {
self
}
Expand All @@ -29,11 +59,11 @@ impl ScalarUDF for UploadFunction {

fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
let Self {
config,
max_connections,
raise_error_on_failure,
multi_thread,
is_single_folder,
io_config,
} = self;

match inputs {
Expand All @@ -44,7 +74,7 @@ impl ScalarUDF for UploadFunction {
*raise_error_on_failure,
*multi_thread,
*is_single_folder,
config.clone(),
io_config.clone(),
None,
),
_ => Err(DaftError::ValueError(format!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ mod tests {
use common_scan_info::Pushdowns;
use daft_core::prelude::*;
use daft_dsl::{col, lit};
use daft_functions::uri::download::UrlDownloadArgs;
use rstest::rstest;

use crate::{
Expand Down Expand Up @@ -435,7 +436,10 @@ mod tests {
/// Tests that we can't pushdown a filter into a ScanOperator if it has an udf-ish expression.
#[test]
fn filter_with_udf_not_pushed_down_into_scan() -> DaftResult<()> {
let pred = daft_functions::uri::download(col("a"), 1, true, true, None);
let pred = daft_functions::uri::download(
col("a"),
Some(UrlDownloadArgs::new(1, true, true, None)),
);
let plan = dummy_scan_node(dummy_scan_operator(vec![
Field::new("a", DataType::Int64),
Field::new("b", DataType::Utf8),
Expand Down
31 changes: 30 additions & 1 deletion src/daft-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
coalesce::SQLCoalesce, hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat,
SQLModuleImage, SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric,
SQLModulePartitioning, SQLModulePython, SQLModuleSketch, SQLModuleStructs,
SQLModuleTemporal, SQLModuleUtf8,
SQLModuleTemporal, SQLModuleUri, SQLModuleUtf8,
},
planner::SQLPlanner,
unsupported_sql_err,
Expand All @@ -36,6 +36,7 @@
functions.register::<SQLModuleSketch>();
functions.register::<SQLModuleStructs>();
functions.register::<SQLModuleTemporal>();
functions.register::<SQLModuleUri>();
functions.register::<SQLModuleUtf8>();
functions.register::<SQLModuleConfig>();
functions.add_fn("coalesce", SQLCoalesce {});
Expand Down Expand Up @@ -375,3 +376,31 @@
}
}
}

/// A namespace for function argument parsing helpers.
pub(crate) mod args {
use common_io_config::IOConfig;

use super::SQLFunctionArguments;
use crate::{error::PlannerError, modules::config::expr_to_iocfg, unsupported_sql_err};

/// Parses on_error => Literal['raise', 'null'] = 'raise' or err.
pub(crate) fn parse_on_error(args: &SQLFunctionArguments) -> Result<bool, PlannerError> {
match args.try_get_named::<String>("on_error")?.as_deref() {
None => Ok(true),
Some("raise") => Ok(true),
Some("null") => Ok(false),
Some(other) => {
unsupported_sql_err!("Expected on_error to be 'raise' or 'null', found '{other}'")

Check warning on line 394 in src/daft-sql/src/functions.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/functions.rs#L393-L394

Added lines #L393 - L394 were not covered by tests
}
}
}

/// Parses io_config which is used in several SQL functions.
pub(crate) fn parse_io_config(args: &SQLFunctionArguments) -> Result<IOConfig, PlannerError> {
args.get_named("io_config")
.map(expr_to_iocfg)
.transpose()
.map(|op| op.unwrap_or_default())
}
}
17 changes: 2 additions & 15 deletions src/daft-sql/src/modules/image/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use crate::{
error::{PlannerError, SQLPlannerResult},
functions::{SQLFunction, SQLFunctionArguments},
functions::{self, SQLFunction, SQLFunctionArguments},
unsupported_sql_err,
};

Expand All @@ -21,20 +21,7 @@
_ => unsupported_sql_err!("Expected mode to be a string"),
})
.transpose()?;

let raise_on_error = args
.get_named("on_error")
.map(|arg| match arg.as_ref() {
Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() {
"raise" => Ok(true),
"null" => Ok(false),
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),
},
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),
})
.transpose()?
.unwrap_or(true);

let raise_on_error = functions::args::parse_on_error(&args)?;

Check warning on line 24 in src/daft-sql/src/modules/image/decode.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/image/decode.rs#L24

Added line #L24 was not covered by tests
Ok(Self {
mode,
raise_on_error,
Expand Down
Loading
Loading