Skip to content

Commit

Permalink
Preserve projection with async UDF, consolidate code-gen.
Browse files Browse the repository at this point in the history
  • Loading branch information
jacksonrnewhouse committed Jan 15, 2024
1 parent 9085f50 commit 6c17e26
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 216 deletions.
99 changes: 9 additions & 90 deletions arroyo-datastream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use arroyo_types::{Data, GlobalKey, JoinType, Key};
use bincode::{Decode, Encode};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use proc_macro2::{Ident, TokenStream};
use proc_macro2::Ident;
use quote::format_ident;
use quote::quote;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -378,14 +378,7 @@ pub enum Operator {
AsyncMapOperator {
name: String,
ordered: bool,
fn_name: String,
defs: String,
args: String,
null_output_assignments: String,
output_assignments: String,
null_handlers: String,
return_nullable: bool,
timeout_seconds: u64,
function_def: String,
max_concurrency: u64,
},
}
Expand Down Expand Up @@ -1640,61 +1633,15 @@ impl Program {
Operator::AsyncMapOperator {
name,
ordered,
fn_name,
defs, args,
null_output_assignments,
output_assignments,
null_handlers,
return_nullable,
timeout_seconds,
max_concurrency
function_def,
max_concurrency,
} => {
let in_k = parse_type(&input.unwrap().weight().key);
let in_t = parse_type(&input.unwrap().weight().value);
let out_t = parse_type(&output.unwrap().weight().value);

let defs: TokenStream = parse_str(defs).expect("defs");
let args: TokenStream = parse_str(args).expect("args");
let fn_name: Ident = parse_str(fn_name).expect("fn_name");
let null_output_assignments: TokenStream = parse_str(null_output_assignments).expect("null_output_assignments");
let output_assignments: TokenStream = parse_str(output_assignments).expect("output_assignments");
let null_handlers: TokenStream = parse_str(null_handlers).expect("null_handlers");

let null_output = if *return_nullable {
quote! {
let null_output = (
index.clone(),
Ok(#out_t {
#null_output_assignments
})
);
}
} else {
quote! ()
};

let udf_wrapper = quote!({
use tokio::time::error::Elapsed;
use tokio::time::{timeout, Duration};
async fn wrapper(index: usize, in_data: #in_t) -> (usize, Result<#out_t, Elapsed>) {
#defs

#null_output

#null_handlers

let udf_result = timeout(Duration::from_secs(#timeout_seconds), udfs::#fn_name(#args)).await;

let out = udf_result.map(
|udf_result| #out_t {
#output_assignments
}
);

(index, out)
}
wrapper
});
let udf_wrapper : syn::Expr = parse_str(function_def).unwrap();

quote! {
Box::new(AsyncMapOperator::<#in_k, #in_t, #out_t, _, _>::
Expand Down Expand Up @@ -2182,26 +2129,12 @@ impl From<Operator> for GrpcApi::operator::Operator {
Operator::AsyncMapOperator {
name,
ordered,
fn_name,
defs,
args,
null_output_assignments,
output_assignments,
null_handlers,
return_nullable,
timeout_seconds,
function_def,
max_concurrency,
} => GrpcOperator::AsyncMapOperator(GrpcApi::AsyncMapOperator {
name,
ordered,
fn_name,
defs,
args,
null_output_assignments,
output_assignments,
null_handlers,
return_nullable,
timeout_seconds,
function_def,
max_concurrency,
}),
Operator::ArrayMapOperator {
Expand Down Expand Up @@ -2515,26 +2448,12 @@ impl TryFrom<arroyo_rpc::grpc::api::Operator> for Operator {
GrpcOperator::AsyncMapOperator(GrpcApi::AsyncMapOperator {
name,
ordered,
fn_name,
defs,
args,
null_output_assignments,
output_assignments,
null_handlers,
return_nullable,
timeout_seconds,
function_def,
max_concurrency,
}) => Operator::AsyncMapOperator {
name,
ordered,
fn_name,
defs,
args,
null_output_assignments,
output_assignments,
null_handlers,
return_nullable,
timeout_seconds,
function_def,
max_concurrency,
},
GrpcOperator::FlattenExpressionOperator(flatten_expression) => {
Expand Down
11 changes: 2 additions & 9 deletions arroyo-rpc/proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,8 @@ message FlatMapOperator {
message AsyncMapOperator {
string name = 1;
bool ordered = 2;
string fn_name = 3;
string defs = 4;
string args = 5;
string output_assignments = 6;
string null_output_assignments = 7;
string null_handlers = 8;
bool return_nullable = 9;
uint64 timeout_seconds = 10;
uint64 max_concurrency = 11;
string function_def = 3;
uint64 max_concurrency = 4;
}

message SlidingWindowAggregator {
Expand Down
4 changes: 2 additions & 2 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for Expression {
if !taken {
panic!("async udf appeared in a non-projection context");
} else {
let ident = input_context.variable_ident();
parse_quote!(#ident.clone())
// TODO: the name should be hooked into the code-gen infrastructure instead of hard-coded.
parse_quote!(async_result.clone())
}
}
}
Expand Down
89 changes: 77 additions & 12 deletions arroyo-sql/src/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use arrow_schema::DataType;
use arroyo_rpc::formats::Format;
use datafusion_expr::type_coercion::aggregates::{avg_return_type, sum_return_type};
use proc_macro2::TokenStream;
use quote::quote;
use quote::{format_ident, quote};
use syn::{parse_quote, parse_str};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -162,9 +162,9 @@ impl CodeGenerator<ValuePointerContext, StructDef, syn::Expr> for UnnestProjecti

#[derive(Debug, Clone)]
pub struct AsyncUdfProjection {
pub column: Column,
pub expression: Expression,
pub input_struct: StructDef,
pub async_udf: RustUdfExpression,
pub projection: Projection,
}

impl AsyncUdfProjection {
Expand All @@ -174,18 +174,83 @@ impl AsyncUdfProjection {
}

impl CodeGenerator<ValuePointerContext, StructDef, syn::Expr> for AsyncUdfProjection {
fn generate(&self, _input_context: &ValuePointerContext) -> syn::Expr {
unreachable!()
/// Generates the async wrapper function that invokes this projection's async UDF.
///
/// The returned expression is a block that declares
/// `async fn wrapper(index, <input>) -> (usize, Result<<output>, Elapsed>)` and then
/// evaluates to `wrapper`. Inside the wrapper, every UDF argument is bound to a local
/// named `__0`, `__1`, ..., the UDF `udfs::<name>` is awaited under a tokio `timeout`,
/// and a successful result is mapped through the expression generated by
/// `self.projection`.
///
/// `input_context` supplies the identifier used for the wrapper's input parameter and is
/// forwarded to every nested `generate` / `expression_type` call.
fn generate(&self, input_context: &ValuePointerContext) -> syn::Expr {
// Type of the wrapper's input parameter.
let input_struct = self.input_struct.get_type();
// Type carried in the `Ok` side of the wrapper's returned `Result`.
let output_type = self.expression_type(input_context).get_type();
// Expression evaluated inside the final `map` closure to build the output value.
let output_struct = self.projection.generate(input_context);
// Name of the wrapper's input parameter.
let input_name = input_context.variable_ident();
// Becomes true when at least one argument expression is optional while the matching
// `def` is not; the UDF call is then guarded by a `match` (see `invocation` below).
let mut may_not_invoke = false;
// Per-argument `let` bindings, plus the (match pattern, identifier) pair for each argument.
// NOTE(review): `def` is presumably the UDF parameter's declared type — confirm against
// `RustUdfExpression::args`.
let (initial_assignment, match_term_ids): (Vec<_>, Vec<_>) = self
.async_udf
.args
.iter()
.enumerate()
.map(|(i, (def, expr))| {
let t = expr.generate(&input_context);
// Arguments are bound to positional locals: `__0`, `__1`, ...
let id = format_ident!("__{}", i);
// Tuple matched on: (argument expression is optional, `def` is optional).
let (initial_assigment, match_term, id) = match (
expr.expression_type(&input_context).is_optional(),
def.is_optional(),
) {
// Optional value for an optional `def`: bound and matched as-is.
(true, true) => (quote!(let #id = #t), quote!(#id), id),
(true, false) => {
// Optional value for a non-optional `def`: the pattern `Some(__i)` unwraps it,
// rebinding the same identifier to the inner value inside the match arm.
may_not_invoke = true;
(quote!(let #id = #t), quote!(Some(#id)), id)
}
// Non-optional value for an optional `def`: wrapped in `Some` at binding time.
(false, true) => (quote!(let #id = Some(#t)), quote!(#id), id),
// Neither side is optional: bound and matched as-is.
(false, false) => (quote!(let #id = #t), quote!(#id), id),
};
(initial_assigment, (match_term, id))
})
.unzip();
let (match_terms, ids): (Vec<_>, Vec<_>) = match_term_ids.into_iter().unzip();

let function_name = format_ident!("{}", self.async_udf.name);
// Parenthesized, comma-separated identifiers; used both as the `match` scrutinee and as
// the argument list of the UDF call.
let args_pattern = quote!((#(#ids),*));
let timeout_seconds = self.async_udf.opts.async_timeout_seconds;
let invocation = if may_not_invoke {
// Parenthesized, comma-separated patterns lined up positionally with `args_pattern`.
let match_terms = quote!((#(#match_terms),*));
// The fallback arm below yields `Ok(None)`. When the UDF's return type is not already
// optional, the successful value is wrapped in `Some` so both arms have the same type.
let suffix = if self.async_udf.ret_type.is_optional() {
None
} else {
Some(quote!(.map(|result| Some(result))))
};
// Call the UDF only when every guarded argument matched `Some(..)`; otherwise skip the
// call and produce `Ok(None)`.
quote!(
match #args_pattern {
#match_terms => {
timeout(Duration::from_secs(#timeout_seconds), udfs:: #function_name #args_pattern).await #suffix
}
_ => {
Ok(None)
}
}
)
} else {
// No argument needs a guard, so the UDF is always called.
quote!(timeout(Duration::from_secs(#timeout_seconds), udfs:: #function_name #args_pattern).await)
};
// Emit the wrapper. `index` is returned unchanged next to the result. The closure
// parameter is named `async_result`, and `output_struct` is evaluated with that name in
// scope.
// NOTE(review): the projection's async-UDF expression presumably refers to
// `async_result` by this exact name — keep the two in sync.
parse_quote! {{
use tokio::time::error::Elapsed;
use tokio::time::{timeout, Duration};
async fn wrapper(
index: usize,
#input_name: #input_struct,
) -> (
usize,
Result<#output_type, Elapsed>,
) {
#(#initial_assignment;)*
let udf_result = #invocation;
(index, udf_result.map(|async_result| #output_struct))
};
wrapper
}
}
}

fn expression_type(&self, input_context: &ValuePointerContext) -> StructDef {
let field = StructField::new(
self.column.name.clone(),
self.column.relation.clone(),
self.expression.expression_type(&input_context),
);

StructDef::new(None, true, vec![field], None)
self.projection.expression_type(input_context)
}
}

Expand Down
Loading

0 comments on commit 6c17e26

Please sign in to comment.