diff --git a/Cargo.toml b/Cargo.toml index beaa22d91fa4..65dd7224b501 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ members = [ "datafusion", "datafusion-common", "datafusion-expr", + "datafusion-jit", "datafusion-physical-expr", "datafusion-cli", "datafusion-examples", diff --git a/datafusion-common/Cargo.toml b/datafusion-common/Cargo.toml index a5573138f4f4..30831b892b71 100644 --- a/datafusion-common/Cargo.toml +++ b/datafusion-common/Cargo.toml @@ -35,6 +35,7 @@ path = "src/lib.rs" [features] avro = ["avro-rs"] pyarrow = ["pyo3"] +jit = ["cranelift-module"] [dependencies] arrow = { version = "9.0.0", features = ["prettyprint"] } @@ -43,3 +44,4 @@ avro-rs = { version = "0.13", features = ["snappy"], optional = true } pyo3 = { version = "0.15", optional = true } sqlparser = "0.14" ordered-float = "2.10" +cranelift-module = { version = "0.81.1", optional = true } diff --git a/datafusion-common/src/error.rs b/datafusion-common/src/error.rs index 93978db1a1e3..ec59a8ac1db5 100644 --- a/datafusion-common/src/error.rs +++ b/datafusion-common/src/error.rs @@ -25,6 +25,8 @@ use std::result; use arrow::error::ArrowError; #[cfg(feature = "avro")] use avro_rs::Error as AvroError; +#[cfg(feature = "jit")] +use cranelift_module::ModuleError; use parquet::errors::ParquetError; use sqlparser::parser::ParserError; @@ -69,6 +71,9 @@ pub enum DataFusionError { /// Errors originating from outside DataFusion's core codebase. 
/// For example, a custom S3Error from the crate datafusion-objectstore-s3 External(GenericError), + #[cfg(feature = "jit")] + /// Error occurs during code generation + JITError(ModuleError), } impl From for DataFusionError { @@ -112,6 +117,13 @@ impl From for DataFusionError { } } +#[cfg(feature = "jit")] +impl From for DataFusionError { + fn from(e: ModuleError) -> Self { + DataFusionError::JITError(e) + } +} + impl From for DataFusionError { fn from(err: GenericError) -> Self { DataFusionError::External(err) @@ -152,6 +164,10 @@ impl Display for DataFusionError { DataFusionError::External(ref desc) => { write!(f, "External error: {}", desc) } + #[cfg(feature = "jit")] + DataFusionError::JITError(ref desc) => { + write!(f, "JIT error: {}", desc) + } } } } @@ -196,3 +212,10 @@ mod test { Ok(()) } } + +#[macro_export] +macro_rules! internal_err { + ($($arg:tt)*) => { + Err(DataFusionError::Internal(format!($($arg)*))) + }; +} diff --git a/datafusion-jit/Cargo.toml b/datafusion-jit/Cargo.toml new file mode 100644 index 000000000000..aaca90af2c0c --- /dev/null +++ b/datafusion-jit/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "datafusion-jit" +description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" +version = "7.0.0" +homepage = "https://github.com/apache/arrow-datafusion" +repository = "https://github.com/apache/arrow-datafusion" +readme = "../README.md" +authors = ["Apache Arrow "] +license = "Apache-2.0" +keywords = [ "arrow", "query", "sql" ] +edition = "2021" +rust-version = "1.58" + +[lib] +name = "datafusion_jit" +path = "src/lib.rs" + +[features] +jit = [] + +[dependencies] +datafusion-common = { path = "../datafusion-common", version = "7.0.0", features = ["jit"] } +cranelift = "0.81.1" +cranelift-module = "0.81.1" +cranelift-jit = "0.81.1" +cranelift-native = "0.81.1" +parking_lot = "0.12" diff --git a/datafusion-jit/src/api.rs b/datafusion-jit/src/api.rs new file mode 100644 index 000000000000..d95f9ccc7ac5 --- /dev/null +++ b/datafusion-jit/src/api.rs @@ -0,0 +1,630 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Constructing a function AST at runtime. 
+ +use crate::ast::*; +use crate::jit::JIT; +use datafusion_common::internal_err; +use datafusion_common::{DataFusionError, Result}; +use parking_lot::Mutex; +use std::collections::HashMap; +use std::collections::VecDeque; +use std::fmt::{Debug, Display, Formatter}; +use std::sync::Arc; + +/// External Function signature +struct ExternFuncSignature { + name: String, + /// pointer to the function + code: *const u8, + params: Vec, + returns: Option, +} + +#[derive(Clone, Debug)] +/// A function consisting of AST nodes that JIT can compile. +pub struct GeneratedFunction { + pub(crate) name: String, + pub(crate) params: Vec<(String, JITType)>, + pub(crate) body: Vec, + pub(crate) ret: Option<(String, JITType)>, +} + +#[derive(Default)] +/// State of Assembler, keep tracking of generated function names +/// and registered external functions. +pub struct AssemblerState { + name_next_id: HashMap, + extern_funcs: HashMap, +} + +impl AssemblerState { + /// Create a fresh function name with prefix `name`. + pub fn fresh_name(&mut self, name: impl Into) -> String { + let name = name.into(); + if !self.name_next_id.contains_key(&name) { + self.name_next_id.insert(name.clone(), 0); + } + + let id = self.name_next_id.get_mut(&name).unwrap(); + let name = format!("{}_{}", &name, id); + *id += 1; + name + } +} + +/// The very first step for constructing a function at runtime. +pub struct Assembler { + state: Arc>, +} + +impl Default for Assembler { + fn default() -> Self { + Self { + state: Arc::new(Default::default()), + } + } +} + +impl Assembler { + /// Register an external Rust function to make it accessible by runtime generated functions. + /// Parameters and return types are used to impose type safety while constructing an AST. 
+ pub fn register_extern_fn( + &self, + name: impl Into, + ptr: *const u8, + params: Vec, + returns: Option, + ) -> Result<()> { + let extern_funcs = &mut self.state.lock().extern_funcs; + let fn_name = name.into(); + let old = extern_funcs.insert( + fn_name.clone(), + ExternFuncSignature { + name: fn_name, + code: ptr, + params, + returns, + }, + ); + + match old { + None => Ok(()), + Some(old) => internal_err!("Extern function {} already exists", old.name), + } + } + + /// Create a new FunctionBuilder with `name` prefix + pub fn new_func_builder(&self, name: impl Into) -> FunctionBuilder { + let name = self.state.lock().fresh_name(name); + FunctionBuilder::new(name, self.state.clone()) + } + + /// Create JIT env which we could compile the AST of constructed function + /// into runnable code. + pub fn create_jit(&self) -> JIT { + let symbols = self + .state + .lock() + .extern_funcs + .values() + .map(|s| (s.name.clone(), s.code)) + .collect::>(); + JIT::new(symbols) + } +} + +/// Function builder API. Stores the state while +/// we are constructing an AST for a function. +pub struct FunctionBuilder { + name: String, + params: Vec<(String, JITType)>, + ret: Option<(String, JITType)>, + fields: VecDeque>, + assembler_state: Arc>, +} + +impl FunctionBuilder { + fn new(name: impl Into, assembler_state: Arc>) -> Self { + let mut fields = VecDeque::new(); + fields.push_back(HashMap::new()); + Self { + name: name.into(), + params: Vec::new(), + ret: None, + fields, + assembler_state, + } + } + + /// Add one more parameter to the function. + pub fn param(mut self, name: impl Into, ty: JITType) -> Self { + let name = name.into(); + assert!(!self.fields.back().unwrap().contains_key(&name)); + self.params.push((name.clone(), ty)); + self.fields.back_mut().unwrap().insert(name, ty); + self + } + + /// Set return type for the function. Functions are of `void` type by default if + /// you do not set the return type. 
+ pub fn ret(mut self, name: impl Into, ty: JITType) -> Self { + let name = name.into(); + assert!(!self.fields.back().unwrap().contains_key(&name)); + self.ret = Some((name.clone(), ty)); + self.fields.back_mut().unwrap().insert(name, ty); + self + } + + /// Enter the function body at start the building. + pub fn enter_block(&mut self) -> CodeBlock { + self.fields.push_back(HashMap::new()); + CodeBlock { + fields: &mut self.fields, + state: &self.assembler_state, + stmts: vec![], + while_state: None, + if_state: None, + fn_state: Some(GeneratedFunction { + name: self.name.clone(), + params: self.params.clone(), + body: vec![], + ret: self.ret.clone(), + }), + } + } +} + +/// Keep `while` condition expr as we are constructing while loop body. +struct WhileState { + condition: Expr, +} + +/// Keep `if-then-else` state, including condition expr, the already built +/// then statements (if we are during building the else block). +struct IfElseState { + condition: Expr, + then_stmts: Vec, + in_then: bool, +} + +impl IfElseState { + /// Move the all current statements in the `then` block and move to `else` block. + fn enter_else(&mut self, then_stmts: Vec) { + self.then_stmts = then_stmts; + self.in_then = false; + } +} + +/// Code block consists of statements and acts as anonymous namespace scope for items and variable declarations. +pub struct CodeBlock<'a> { + /// A stack that containing all defined variables so far. The variables defined + /// in the current block are at the top stack frame. + /// Fields provides a shadow semantics of the same name in outsider block, and are + /// used to guarantee type safety while constructing AST. + fields: &'a mut VecDeque>, + /// The state of Assembler, used for type checking function calls. + state: &'a Arc>, + /// Holding all statements for the current code block. + stmts: Vec, + while_state: Option, + if_state: Option, + /// Keep track of function params and return types, only valid for function main block. 
+ fn_state: Option, +} + +impl<'a> CodeBlock<'a> { + pub fn build(&mut self) -> GeneratedFunction { + assert!( + self.fn_state.is_some(), + "Can only call build on outermost function block" + ); + let mut gen = self.fn_state.take().unwrap(); + gen.body = self.stmts.drain(..).collect::>(); + gen + } + + /// Leave the current block and returns the statements constructed. + fn leave(&mut self) -> Result { + self.fields.pop_back(); + if let Some(ref mut while_state) = self.while_state { + let WhileState { condition } = while_state; + let stmts = self.stmts.drain(..).collect::>(); + return Ok(Stmt::WhileLoop(Box::new(condition.clone()), stmts)); + } + + if let Some(ref mut if_state) = self.if_state { + let IfElseState { + condition, + then_stmts, + in_then, + } = if_state; + return if *in_then { + assert!(then_stmts.is_empty()); + let stmts = self.stmts.drain(..).collect::>(); + Ok(Stmt::IfElse(Box::new(condition.clone()), stmts, Vec::new())) + } else { + assert!(!then_stmts.is_empty()); + let then_stmts = then_stmts.drain(..).collect::>(); + let else_stmts = self.stmts.drain(..).collect::>(); + Ok(Stmt::IfElse( + Box::new(condition.clone()), + then_stmts, + else_stmts, + )) + }; + } + unreachable!() + } + + /// Enter else block. Try [if_block] first which is much easier to use. + fn enter_else(&mut self) { + self.fields.pop_back(); + self.fields.push_back(HashMap::new()); + assert!(self.if_state.is_some() && self.if_state.as_ref().unwrap().in_then); + let new_then = self.stmts.drain(..).collect::>(); + if let Some(s) = self.if_state.iter_mut().next() { + s.enter_else(new_then) + } + } + + /// Declare variable `name` of a type. 
+ pub fn declare(&mut self, name: impl Into, ty: JITType) -> Result<()> { + let name = name.into(); + let typ = self.fields.back().unwrap().get(&name); + match typ { + Some(typ) => internal_err!( + "Variable {} of {} already exists in the current scope", + name, + typ + ), + None => { + self.fields.back_mut().unwrap().insert(name.clone(), ty); + self.stmts.push(Stmt::Declare(name, ty)); + Ok(()) + } + } + } + + fn find_type(&self, name: impl Into) -> Option { + let name = name.into(); + for scope in self.fields.iter().rev() { + let typ = scope.get(&name); + if let Some(typ) = typ { + return Some(*typ); + } + } + None + } + + /// Assignment statement. Assign a expression value to a variable. + pub fn assign(&mut self, name: impl Into, expr: Expr) -> Result<()> { + let name = name.into(); + let typ = self.find_type(&name); + match typ { + Some(typ) => { + if typ != expr.get_type() { + internal_err!( + "Variable {} of {} cannot be assigned to {}", + name, + typ, + expr.get_type() + ) + } else { + self.stmts.push(Stmt::Assign(name, Box::new(expr))); + Ok(()) + } + } + None => internal_err!("unknown identifier: {}", name), + } + } + + /// Declare variable with initialization. + pub fn declare_as(&mut self, name: impl Into, expr: Expr) -> Result<()> { + let name = name.into(); + let typ = self.fields.back().unwrap().get(&name); + match typ { + Some(typ) => { + internal_err!( + "Variable {} of {} already exists in the current scope", + name, + typ + ) + } + None => { + self.fields + .back_mut() + .unwrap() + .insert(name.clone(), expr.get_type()); + self.stmts + .push(Stmt::Declare(name.clone(), expr.get_type())); + self.stmts.push(Stmt::Assign(name, Box::new(expr))); + Ok(()) + } + } + } + + /// Call external function for side effect only. + pub fn call_stmt(&mut self, name: impl Into, args: Vec) -> Result<()> { + self.stmts.push(Stmt::Call(name.into(), args)); + Ok(()) + } + + /// Enter `while` loop block. Try [while_block] first which is much easier to use. 
+ fn while_loop(&mut self, cond: Expr) -> Result { + if cond.get_type() != BOOL { + internal_err!("while condition must be bool") + } else { + self.fields.push_back(HashMap::new()); + Ok(CodeBlock { + fields: self.fields, + state: self.state, + stmts: vec![], + while_state: Some(WhileState { condition: cond }), + if_state: None, + fn_state: None, + }) + } + } + + /// Enter `if-then-else`'s then block. Try [if_block] first which is much easier to use. + fn if_else(&mut self, cond: Expr) -> Result { + if cond.get_type() != BOOL { + internal_err!("if condition must be bool") + } else { + self.fields.push_back(HashMap::new()); + Ok(CodeBlock { + fields: self.fields, + state: self.state, + stmts: vec![], + while_state: None, + if_state: Some(IfElseState { + condition: cond, + then_stmts: vec![], + in_then: true, + }), + fn_state: None, + }) + } + } + + /// Construct a `if-then-else` block with each part provided. + /// + /// E.g. if n == 0 { r = 0 } else { r = 1} could be write as: + /// x.if_block( + /// |cond| cond.eq(cond.id("n")?, cond.lit_i(0)), + /// |t| { + /// t.assign("r", t.lit_i(0))?; + /// Ok(()) + /// }, + /// |e| t.assign("r", t.lit_i(1))?; + /// Ok(()) + /// }, + /// )?; + pub fn if_block( + &mut self, + mut cond: C, + mut then_blk: T, + mut else_blk: E, + ) -> Result<()> + where + C: FnMut(&mut CodeBlock) -> Result, + T: FnMut(&mut CodeBlock) -> Result<()>, + E: FnMut(&mut CodeBlock) -> Result<()>, + { + let cond = cond(self)?; + let mut body = self.if_else(cond)?; + then_blk(&mut body)?; + body.enter_else(); + else_blk(&mut body)?; + let if_else = body.leave()?; + self.stmts.push(if_else); + Ok(()) + } + + /// Construct a `while` block with each part provided. + /// + /// E.g. 
while n != 0 { n = n - 1;} could be write as: + /// x.while_block( + /// |cond| cond.ne(cond.id("n")?, cond.lit_i(0)), + /// |w| { + /// w.assign("n", w.sub(w.id("n")?, w.lit_i(1))?)?; + /// Ok(()) + /// }, + /// )?; + pub fn while_block(&mut self, mut cond: C, mut body_blk: B) -> Result<()> + where + C: FnMut(&mut CodeBlock) -> Result, + B: FnMut(&mut CodeBlock) -> Result<()>, + { + let cond = cond(self)?; + let mut body = self.while_loop(cond)?; + body_blk(&mut body)?; + let while_stmt = body.leave()?; + self.stmts.push(while_stmt); + Ok(()) + } + + /// Create a literal `val` of `ty` type. + pub fn lit(&self, val: impl Into, ty: JITType) -> Expr { + Expr::Literal(Literal::Parsing(val.into(), ty)) + } + + /// Shorthand to create i64 literal + pub fn lit_i(&self, val: impl Into) -> Expr { + Expr::Literal(Literal::Typed(TypedLit::Int(val.into()))) + } + + /// Shorthand to create f32 literal + pub fn lit_f(&self, val: f32) -> Expr { + Expr::Literal(Literal::Typed(TypedLit::Float(val))) + } + + /// Shorthand to create f64 literal + pub fn lit_d(&self, val: f64) -> Expr { + Expr::Literal(Literal::Typed(TypedLit::Double(val))) + } + + /// Shorthand to create boolean literal + pub fn lit_b(&self, val: bool) -> Expr { + Expr::Literal(Literal::Typed(TypedLit::Bool(val))) + } + + /// Create a reference to an already defined variable. 
+ pub fn id(&self, name: impl Into) -> Result { + let name = name.into(); + match self.find_type(&name) { + None => internal_err!("unknown identifier: {}", name), + Some(typ) => Ok(Expr::Identifier(name, typ)), + } + } + + /// Binary comparison expression: lhs == rhs + pub fn eq(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot compare {} and {}", lhs.get_type(), rhs.get_type()) + } else { + Ok(Expr::Binary(BinaryExpr::Eq(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Binary comparison expression: lhs != rhs + pub fn ne(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot compare {} and {}", lhs.get_type(), rhs.get_type()) + } else { + Ok(Expr::Binary(BinaryExpr::Ne(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Binary comparison expression: lhs < rhs + pub fn lt(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot compare {} and {}", lhs.get_type(), rhs.get_type()) + } else { + Ok(Expr::Binary(BinaryExpr::Lt(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Binary comparison expression: lhs <= rhs + pub fn le(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot compare {} and {}", lhs.get_type(), rhs.get_type()) + } else { + Ok(Expr::Binary(BinaryExpr::Le(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Binary comparison expression: lhs > rhs + pub fn gt(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot compare {} and {}", lhs.get_type(), rhs.get_type()) + } else { + Ok(Expr::Binary(BinaryExpr::Gt(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Binary comparison expression: lhs >= rhs + pub fn ge(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot compare {} and {}", lhs.get_type(), rhs.get_type()) + } else { + 
Ok(Expr::Binary(BinaryExpr::Ge(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Binary arithmetic expression: lhs + rhs + pub fn add(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot add {} and {}", lhs.get_type(), rhs.get_type()) + } else { + Ok(Expr::Binary(BinaryExpr::Add(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Binary arithmetic expression: lhs - rhs + pub fn sub(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot subtract {} and {}", lhs.get_type(), rhs.get_type()) + } else { + Ok(Expr::Binary(BinaryExpr::Sub(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Binary arithmetic expression: lhs * rhs + pub fn mul(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot multiply {} and {}", lhs.get_type(), rhs.get_type()) + } else { + Ok(Expr::Binary(BinaryExpr::Mul(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Binary arithmetic expression: lhs / rhs + pub fn div(&self, lhs: Expr, rhs: Expr) -> Result { + if lhs.get_type() != rhs.get_type() { + internal_err!("cannot divide {} and {}", lhs.get_type(), rhs.get_type()) + } else { + Ok(Expr::Binary(BinaryExpr::Div(Box::new(lhs), Box::new(rhs)))) + } + } + + /// Call external function `name` with parameters + pub fn call(&self, name: impl Into, params: Vec) -> Result { + let fn_name = name.into(); + if let Some(func) = self.state.lock().extern_funcs.get(&fn_name) { + for ((i, t1), t2) in params.iter().enumerate().zip(func.params.iter()) { + if t1.get_type() != *t2 { + return internal_err!( + "Func {} need {} as arg{}, get {}", + &fn_name, + t2, + i, + t1.get_type() + ); + } + } + Ok(Expr::Call(fn_name, params, func.returns.unwrap_or(NIL))) + } else { + internal_err!("No func with the name {} exist", fn_name) + } + } +} + +impl Display for GeneratedFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "fn {}(", self.name)?; + for 
(i, (name, ty)) in self.params.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}: {}", name, ty)?; + } + write!(f, ") -> ")?; + if let Some((name, ty)) = &self.ret { + write!(f, "{}: {}", name, ty)?; + } else { + write!(f, "()")?; + } + writeln!(f, " {{")?; + for stmt in &self.body { + stmt.fmt_ident(4, f)?; + } + write!(f, "}}") + } +} diff --git a/datafusion-jit/src/ast.rs b/datafusion-jit/src/ast.rs new file mode 100644 index 000000000000..5d0e3bc4041e --- /dev/null +++ b/datafusion-jit/src/ast.rs @@ -0,0 +1,359 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use cranelift::codegen::ir; +use std::fmt::{Display, Formatter}; + +#[derive(Clone, Debug)] +/// Statement +pub enum Stmt { + /// if-then-else + IfElse(Box, Vec, Vec), + /// while + WhileLoop(Box, Vec), + /// assignment + Assign(String, Box), + /// call function for side effect + Call(String, Vec), + /// declare a new variable of type + Declare(String, JITType), +} + +#[derive(Copy, Clone, Debug)] +/// Shorthand typed literals +pub enum TypedLit { + Bool(bool), + Int(i64), + Float(f32), + Double(f64), +} + +#[derive(Clone, Debug)] +/// Expression +pub enum Expr { + /// literal + Literal(Literal), + /// variable + Identifier(String, JITType), + /// binary expression + Binary(BinaryExpr), + /// call function expression + Call(String, Vec, JITType), +} + +impl Expr { + pub fn get_type(&self) -> JITType { + match self { + Expr::Literal(lit) => lit.get_type(), + Expr::Identifier(_, ty) => *ty, + Expr::Binary(bin) => bin.get_type(), + Expr::Call(_, _, ty) => *ty, + } + } +} + +impl Literal { + fn get_type(&self) -> JITType { + match self { + Literal::Parsing(_, ty) => *ty, + Literal::Typed(tl) => tl.get_type(), + } + } +} + +impl TypedLit { + fn get_type(&self) -> JITType { + match self { + TypedLit::Bool(_) => BOOL, + TypedLit::Int(_) => I64, + TypedLit::Float(_) => F32, + TypedLit::Double(_) => F64, + } + } +} + +impl BinaryExpr { + fn get_type(&self) -> JITType { + match self { + BinaryExpr::Eq(_, _) => BOOL, + BinaryExpr::Ne(_, _) => BOOL, + BinaryExpr::Lt(_, _) => BOOL, + BinaryExpr::Le(_, _) => BOOL, + BinaryExpr::Gt(_, _) => BOOL, + BinaryExpr::Ge(_, _) => BOOL, + BinaryExpr::Add(lhs, _) => lhs.get_type(), + BinaryExpr::Sub(lhs, _) => lhs.get_type(), + BinaryExpr::Mul(lhs, _) => lhs.get_type(), + BinaryExpr::Div(lhs, _) => lhs.get_type(), + } + } +} + +#[derive(Clone, Debug)] +/// Binary expression +pub enum BinaryExpr { + /// == + Eq(Box, Box), + /// != + Ne(Box, Box), + /// < + Lt(Box, Box), + /// <= + Le(Box, Box), + /// > + Gt(Box, Box), + /// >= + Ge(Box, 
Box), + /// add + Add(Box, Box), + /// subtract + Sub(Box, Box), + /// multiply + Mul(Box, Box), + /// divide + Div(Box, Box), +} + +#[derive(Clone, Debug)] +/// Literal +pub enum Literal { + /// Parsable literal with type + Parsing(String, JITType), + /// Shorthand literals of common types + Typed(TypedLit), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +/// Type to be used in JIT +pub struct JITType { + /// The cranelift type + pub(crate) native: ir::Type, + /// re-expose inner field of `ir::Type` out for easier pattern matching + pub(crate) code: u8, +} + +/// null type as placeholder +pub const NIL: JITType = JITType { + native: ir::types::INVALID, + code: 0, +}; +/// bool +pub const BOOL: JITType = JITType { + native: ir::types::B1, + code: 0x70, +}; +/// integer of 1 byte +pub const I8: JITType = JITType { + native: ir::types::I8, + code: 0x76, +}; +/// integer of 2 bytes +pub const I16: JITType = JITType { + native: ir::types::I16, + code: 0x77, +}; +/// integer of 4 bytes +pub const I32: JITType = JITType { + native: ir::types::I32, + code: 0x78, +}; +/// integer of 8 bytes +pub const I64: JITType = JITType { + native: ir::types::I64, + code: 0x79, +}; +/// Ieee float of 32 bits +pub const F32: JITType = JITType { + native: ir::types::F32, + code: 0x7b, +}; +/// Ieee float of 64 bits +pub const F64: JITType = JITType { + native: ir::types::F64, + code: 0x7c, +}; +/// Pointer type of 32 bits +pub const R32: JITType = JITType { + native: ir::types::R32, + code: 0x7e, +}; +/// Pointer type of 64 bits +pub const R64: JITType = JITType { + native: ir::types::R64, + code: 0x7f, +}; +/// The pointer type to use based on our currently target. 
+pub const PTR: JITType = if std::mem::size_of::() == 8 { + R64 +} else { + R32 +}; + +impl Stmt { + /// print the statement with indentation + pub fn fmt_ident(&self, ident: usize, f: &mut Formatter) -> std::fmt::Result { + let mut ident_str = String::new(); + for _ in 0..ident { + ident_str.push(' '); + } + match self { + Stmt::IfElse(cond, then_stmts, else_stmts) => { + writeln!(f, "{}if {} {{", ident_str, cond)?; + for stmt in then_stmts { + stmt.fmt_ident(ident + 4, f)?; + } + writeln!(f, "{}}} else {{", ident_str)?; + for stmt in else_stmts { + stmt.fmt_ident(ident + 4, f)?; + } + writeln!(f, "{}}}", ident_str) + } + Stmt::WhileLoop(cond, stmts) => { + writeln!(f, "{}while {} {{", ident_str, cond)?; + for stmt in stmts { + stmt.fmt_ident(ident + 4, f)?; + } + writeln!(f, "{}}}", ident_str) + } + Stmt::Assign(name, expr) => { + writeln!(f, "{}{} = {};", ident_str, name, expr) + } + Stmt::Call(name, args) => { + writeln!( + f, + "{}{}({});", + ident_str, + name, + args.iter() + .map(|e| e.to_string()) + .collect::>() + .join(", ") + ) + } + Stmt::Declare(name, ty) => { + writeln!(f, "{}let {}: {};", ident_str, name, ty) + } + } + } +} + +impl Display for Stmt { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.fmt_ident(0, f)?; + Ok(()) + } +} + +impl Display for Expr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Expr::Literal(l) => write!(f, "{}", l), + Expr::Identifier(name, _) => write!(f, "{}", name), + Expr::Binary(be) => write!(f, "{}", be), + Expr::Call(name, exprs, _) => { + write!( + f, + "{}({})", + name, + exprs + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", ") + ) + } + } + } +} + +impl Display for Literal { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Literal::Parsing(str, _) => write!(f, "{}", str), + Literal::Typed(tl) => write!(f, "{}", tl), + } + } +} + +impl Display for TypedLit { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + 
match self { + TypedLit::Bool(b) => write!(f, "{}", b), + TypedLit::Int(i) => write!(f, "{}", i), + TypedLit::Float(fl) => write!(f, "{}", fl), + TypedLit::Double(d) => write!(f, "{}", d), + } + } +} + +impl Display for BinaryExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + BinaryExpr::Eq(lhs, rhs) => write!(f, "{} == {}", lhs, rhs), + BinaryExpr::Ne(lhs, rhs) => write!(f, "{} != {}", lhs, rhs), + BinaryExpr::Lt(lhs, rhs) => write!(f, "{} < {}", lhs, rhs), + BinaryExpr::Le(lhs, rhs) => write!(f, "{} <= {}", lhs, rhs), + BinaryExpr::Gt(lhs, rhs) => write!(f, "{} > {}", lhs, rhs), + BinaryExpr::Ge(lhs, rhs) => write!(f, "{} >= {}", lhs, rhs), + BinaryExpr::Add(lhs, rhs) => write!(f, "{} + {}", lhs, rhs), + BinaryExpr::Sub(lhs, rhs) => write!(f, "{} - {}", lhs, rhs), + BinaryExpr::Mul(lhs, rhs) => write!(f, "{} * {}", lhs, rhs), + BinaryExpr::Div(lhs, rhs) => write!(f, "{} / {}", lhs, rhs), + } + } +} + +impl std::fmt::Display for JITType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::fmt::Debug for JITType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.code { + 0 => write!(f, "nil"), + 0x70 => write!(f, "bool"), + 0x76 => write!(f, "i8"), + 0x77 => write!(f, "i16"), + 0x78 => write!(f, "i32"), + 0x79 => write!(f, "i64"), + 0x7b => write!(f, "f32"), + 0x7c => write!(f, "f64"), + 0x7e => write!(f, "small_ptr"), + 0x7f => write!(f, "ptr"), + _ => write!(f, "unknown"), + } + } +} + +impl From<&str> for JITType { + fn from(x: &str) -> Self { + match x { + "bool" => BOOL, + "i8" => I8, + "i16" => I16, + "i32" => I32, + "i64" => I64, + "f32" => F32, + "f64" => F64, + "small_ptr" => R32, + "ptr" => R64, + _ => panic!("unknown type: {}", x), + } + } +} diff --git a/datafusion-jit/src/jit.rs b/datafusion-jit/src/jit.rs new file mode 100644 index 000000000000..225366b4be7a --- /dev/null +++ b/datafusion-jit/src/jit.rs @@ -0,0 +1,676 
@@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::api::GeneratedFunction;
+use crate::ast::{BinaryExpr, Expr, JITType, Literal, Stmt, TypedLit, BOOL, I64, NIL};
+use cranelift::prelude::*;
+use cranelift_jit::{JITBuilder, JITModule};
+use cranelift_module::{Linkage, Module};
+use datafusion_common::internal_err;
+use datafusion_common::{DataFusionError, Result};
+use std::collections::HashMap;
+
+/// The basic JIT class.
+#[allow(clippy::upper_case_acronyms)]
+pub struct JIT {
+    /// The function builder context, which is reused across multiple
+    /// FunctionBuilder instances.
+    builder_context: FunctionBuilderContext,
+
+    /// The main Cranelift context, which holds the state for codegen. Cranelift
+    /// separates this from `Module` to allow for parallel compilation, with a
+    /// context per thread, though this is not the case now.
+    ctx: codegen::Context,
+
+    /// The module, with the jit backend, which manages the JIT'd
+    /// functions.
+    module: JITModule,
+}
+
+impl Default for JIT {
+    /// Creates a JIT with the default `JITBuilder` settings and no extra
+    /// external symbols registered.
+    fn default() -> Self {
+        let builder = JITBuilder::new(cranelift_module::default_libcall_names());
+        let module = JITModule::new(builder);
+        Self {
+            builder_context: FunctionBuilderContext::new(),
+            ctx: module.make_context(),
+            module,
+        }
+    }
+}
+
+impl JIT {
+    /// New while registering external functions
+    ///
+    /// `symbols` yields `(name, address)` pairs that generated code may call.
+    ///
+    /// # Panics
+    /// Panics when the host machine is not supported by Cranelift.
+    pub fn new<It, K>(symbols: It) -> Self
+    where
+        It: IntoIterator<Item = (K, *const u8)>,
+        K: Into<String>,
+    {
+        // The flag names and values below are fixed literals, so a failure to
+        // set them is a programming error rather than a runtime condition.
+        let mut flag_builder = settings::builder();
+        flag_builder.set("use_colocated_libcalls", "false").unwrap();
+        flag_builder.set("is_pic", "true").unwrap();
+        flag_builder.set("opt_level", "speed").unwrap();
+        flag_builder.set("enable_simd", "true").unwrap();
+        let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
+            panic!("host machine is not supported: {}", msg);
+        });
+        let isa = isa_builder.finish(settings::Flags::new(flag_builder));
+        let mut builder =
+            JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
+        builder.symbols(symbols);
+        let module = JITModule::new(builder);
+        Self {
+            builder_context: FunctionBuilderContext::new(),
+            ctx: module.make_context(),
+            module,
+        }
+    }
+
+    /// Compile the generated function into machine code.
+    ///
+    /// Returns a raw pointer to the finalized machine code of `func`.
+    ///
+    /// # Errors
+    /// Returns an error when the body cannot be translated, or when the
+    /// module rejects the function declaration or definition.
+    pub fn compile(&mut self, func: GeneratedFunction) -> Result<*const u8> {
+        let GeneratedFunction {
+            name,
+            params,
+            body,
+            ret,
+        } = func;
+
+        // Translate the AST nodes into Cranelift IR.
+        self.translate(params, ret, body)?;
+
+        // Next, declare the function to jit. Functions must be declared
+        // before they can be called, or defined.
+        let id = self.module.declare_function(
+            &name,
+            Linkage::Export,
+            &self.ctx.func.signature,
+        )?;
+
+        // Define the function to jit. This finishes compilation, although
+        // there may be outstanding relocations to perform. Currently, jit
+        // cannot finish relocations until all functions to be called are
+        // defined. For now, we'll just finalize the function below.
+        self.module.define_function(id, &mut self.ctx)?;
+
+        // Now that compilation is finished, we can clear out the context state.
+        self.module.clear_context(&mut self.ctx);
+
+        // Finalize the functions which we just defined, which resolves any
+        // outstanding relocations (patching in addresses, now that they're
+        // available).
+        self.module.finalize_definitions();
+
+        // We can now retrieve a pointer to the machine code.
+        let code = self.module.get_finalized_function(id);
+
+        Ok(code)
+    }
+
+    /// Translate the function signature and body into Cranelift IR, leaving
+    /// the result in `self.ctx.func`.
+    fn translate(
+        &mut self,
+        params: Vec<(String, JITType)>,
+        the_return: Option<(String, JITType)>,
+        stmts: Vec<Stmt>,
+    ) -> Result<()> {
+        for (_, ty) in &params {
+            self.ctx.func.signature.params.push(AbiParam::new(ty.native));
+        }
+
+        // We currently only support one return value, though
+        // Cranelift is designed to support more.
+        if let Some((_, ty)) = &the_return {
+            self.ctx.func.signature.returns.push(AbiParam::new(ty.native));
+        }
+
+        // Create the builder to build a function.
+        let mut builder =
+            FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
+
+        // Create the entry block, to start emitting code in.
+        let entry_block = builder.create_block();
+
+        // Since this is the entry block, add block parameters corresponding to
+        // the function's parameters.
+        builder.append_block_params_for_function_params(entry_block);
+
+        // Tell the builder to emit code in this block.
+        builder.switch_to_block(entry_block);
+
+        // And, tell the builder that this block will have no further
+        // predecessors. Since it's the entry block, it won't have any
+        // predecessors.
+        builder.seal_block(entry_block);
+
+        // Walk the AST and declare all variables.
+        let variables =
+            declare_variables(&mut builder, &params, &the_return, &stmts, entry_block);
+
+        // Now translate the statements of the function body.
+        let mut trans = FunctionTranslator {
+            builder,
+            variables,
+            module: &mut self.module,
+        };
+        for stmt in stmts {
+            trans.translate_stmt(stmt)?;
+        }
+
+        match &the_return {
+            Some((name, _)) => {
+                // Set up the return variable of the function. Above, we declared
+                // a variable to hold the return value. Here, we just do a use of
+                // that variable.
+                let return_variable = *trans.variables.get(name).ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "return variable {} is not declared",
+                        name
+                    ))
+                })?;
+                let return_value = trans.builder.use_var(return_variable);
+
+                // Emit the return instruction.
+                trans.builder.ins().return_(&[return_value]);
+            }
+            None => {
+                trans.builder.ins().return_(&[]);
+            }
+        }
+
+        // Tell the builder we're done with this function.
+        trans.builder.finalize();
+        Ok(())
+    }
+}
+
+/// A collection of state used for translating from AST nodes
+/// into Cranelift IR.
+struct FunctionTranslator<'a> {
+    builder: FunctionBuilder<'a>,
+    variables: HashMap<String, Variable>,
+    module: &'a mut JITModule,
+}
+
+impl<'a> FunctionTranslator<'a> {
+    /// Translate one statement, dispatching on its kind.
+    fn translate_stmt(&mut self, stmt: Stmt) -> Result<()> {
+        match stmt {
+            Stmt::IfElse(condition, then_body, else_body) => {
+                self.translate_if_else(*condition, then_body, else_body)
+            }
+            Stmt::WhileLoop(condition, loop_body) => {
+                self.translate_while_loop(*condition, loop_body)
+            }
+            Stmt::Assign(name, expr) => self.translate_assign(name, *expr),
+            Stmt::Call(name, args) => {
+                // A call used as a statement has no result, hence `NIL`.
+                self.translate_call_stmt(name, args, NIL)?;
+                Ok(())
+            }
+            // Declarations emit no code here; `declare_variables` has already
+            // registered them before the body is translated.
+            Stmt::Declare(_, _) => Ok(()),
+        }
+    }
+
+    /// Emit the constant instruction for an already typed literal.
+    fn translate_typed_lit(&mut self, tl: TypedLit) -> Value {
+        match tl {
+            TypedLit::Bool(b) => self.builder.ins().bconst(BOOL.native, b),
+            TypedLit::Int(i) => self.builder.ins().iconst(I64.native, i),
+            TypedLit::Float(f) => self.builder.ins().f32const(f),
+            TypedLit::Double(d) => self.builder.ins().f64const(d),
+        }
+    }
+
+    /// When you write out instructions in Cranelift, you get back `Value`s. You
+    /// can then use these references in other instructions.
+    fn translate_expr(&mut self, expr: Expr) -> Result<Value> {
+        match expr {
+            Expr::Literal(nl) => self.translate_literal(nl),
+            Expr::Identifier(name, _) => {
+                // `use_var` is used to read the value of a variable.
+                let variable = self.variables.get(&name).ok_or_else(|| {
+                    DataFusionError::Internal("variable not defined".to_owned())
+                })?;
+                Ok(self.builder.use_var(*variable))
+            }
+            Expr::Binary(b) => self.translate_binary_expr(b),
+            Expr::Call(name, args, ret) => self.translate_call_expr(name, args, ret),
+        }
+    }
+
+    /// Translate a literal, parsing it first when it is still in string form.
+    fn translate_literal(&mut self, expr: Literal) -> Result<Value> {
+        match expr {
+            Literal::Parsing(literal, ty) => self.translate_string_lit(literal, ty),
+            Literal::Typed(lt) => Ok(self.translate_typed_lit(lt)),
+        }
+    }
+
+    /// Translate a binary expression. The instruction family (integer or
+    /// float) is chosen from the type of the left-hand operand.
+    fn translate_binary_expr(&mut self, expr: BinaryExpr) -> Result<Value> {
+        match expr {
+            BinaryExpr::Eq(lhs, rhs) => {
+                self.translate_cmp(IntCC::Equal, FloatCC::Equal, "equal", *lhs, *rhs)
+            }
+            BinaryExpr::Ne(lhs, rhs) => self.translate_cmp(
+                IntCC::NotEqual,
+                FloatCC::NotEqual,
+                "not equal",
+                *lhs,
+                *rhs,
+            ),
+            BinaryExpr::Lt(lhs, rhs) => self.translate_cmp(
+                IntCC::SignedLessThan,
+                FloatCC::LessThan,
+                "less than",
+                *lhs,
+                *rhs,
+            ),
+            BinaryExpr::Le(lhs, rhs) => self.translate_cmp(
+                IntCC::SignedLessThanOrEqual,
+                FloatCC::LessThanOrEqual,
+                "less than or equal",
+                *lhs,
+                *rhs,
+            ),
+            BinaryExpr::Gt(lhs, rhs) => self.translate_cmp(
+                IntCC::SignedGreaterThan,
+                FloatCC::GreaterThan,
+                "greater than",
+                *lhs,
+                *rhs,
+            ),
+            BinaryExpr::Ge(lhs, rhs) => self.translate_cmp(
+                IntCC::SignedGreaterThanOrEqual,
+                FloatCC::GreaterThanOrEqual,
+                "greater than or equal",
+                *lhs,
+                *rhs,
+            ),
+            BinaryExpr::Add(lhs, rhs) => self.translate_arith(ArithOp::Add, *lhs, *rhs),
+            BinaryExpr::Sub(lhs, rhs) => self.translate_arith(ArithOp::Sub, *lhs, *rhs),
+            BinaryExpr::Mul(lhs, rhs) => self.translate_arith(ArithOp::Mul, *lhs, *rhs),
+            BinaryExpr::Div(lhs, rhs) => self.translate_arith(ArithOp::Div, *lhs, *rhs),
+        }
+    }
+
+    /// Translate a comparison, using `icc` for integer operands and `fcc` for
+    /// float operands. `op` is the operator name used in the error message.
+    fn translate_cmp(
+        &mut self,
+        icc: IntCC,
+        fcc: FloatCC,
+        op: &str,
+        lhs: Expr,
+        rhs: Expr,
+    ) -> Result<Value> {
+        let ty = lhs.get_type();
+        if is_int_type(ty) {
+            self.translate_icmp(icc, lhs, rhs)
+        } else if is_float_type(ty) {
+            self.translate_fcmp(fcc, lhs, rhs)
+        } else {
+            internal_err!("Unsupported type {} for {} comparison", ty, op)
+        }
+    }
+
+    /// Translate an arithmetic expression on integer or float operands.
+    fn translate_arith(&mut self, op: ArithOp, lhs: Expr, rhs: Expr) -> Result<Value> {
+        let ty = lhs.get_type();
+        let lhs = self.translate_expr(lhs)?;
+        let rhs = self.translate_expr(rhs)?;
+        let is_int = is_int_type(ty);
+        if !is_int && !is_float_type(ty) {
+            return internal_err!("Unsupported type {} for {}", ty, op.name());
+        }
+        let ins = self.builder.ins();
+        Ok(match (op, is_int) {
+            (ArithOp::Add, true) => ins.iadd(lhs, rhs),
+            (ArithOp::Add, false) => ins.fadd(lhs, rhs),
+            (ArithOp::Sub, true) => ins.isub(lhs, rhs),
+            (ArithOp::Sub, false) => ins.fsub(lhs, rhs),
+            (ArithOp::Mul, true) => ins.imul(lhs, rhs),
+            (ArithOp::Mul, false) => ins.fmul(lhs, rhs),
+            // Integer comparisons in this translator all use the `Signed*`
+            // condition codes, so integer division has to be signed as well;
+            // an unsigned divide gives wrong results for negative operands.
+            (ArithOp::Div, true) => ins.sdiv(lhs, rhs),
+            (ArithOp::Div, false) => ins.fdiv(lhs, rhs),
+        })
+    }
+
+    /// Parse `lit` according to `ty` and emit the matching constant.
+    ///
+    /// # Errors
+    /// Returns an internal error when `lit` is not a valid literal of type
+    /// `ty`, or when `ty` has no literal form.
+    fn translate_string_lit(&mut self, lit: String, ty: JITType) -> Result<Value> {
+        match ty.code {
+            0x70 => {
+                let b = parse_lit::<bool>(&lit, ty)?;
+                Ok(self.builder.ins().bconst(ty.native, b))
+            }
+            0x76 => {
+                let i = parse_lit::<i8>(&lit, ty)?;
+                Ok(self.builder.ins().iconst(ty.native, i64::from(i)))
+            }
+            0x77 => {
+                let i = parse_lit::<i16>(&lit, ty)?;
+                Ok(self.builder.ins().iconst(ty.native, i64::from(i)))
+            }
+            0x78 => {
+                let i = parse_lit::<i32>(&lit, ty)?;
+                Ok(self.builder.ins().iconst(ty.native, i64::from(i)))
+            }
+            0x79 => {
+                let i = parse_lit::<i64>(&lit, ty)?;
+                Ok(self.builder.ins().iconst(ty.native, i))
+            }
+            0x7b => {
+                let f = parse_lit::<f32>(&lit, ty)?;
+                Ok(self.builder.ins().f32const(f))
+            }
+            0x7c => {
+                let f = parse_lit::<f64>(&lit, ty)?;
+                Ok(self.builder.ins().f64const(f))
+            }
+            _ => internal_err!("Unsupported type {} for string literal", ty),
+        }
+    }
+
+    /// Translate `name = expr`.
+    ///
+    /// # Errors
+    /// Returns an internal error when `name` has not been declared.
+    fn translate_assign(&mut self, name: String, expr: Expr) -> Result<()> {
+        // `def_var` is used to write the value of a variable. Note that
+        // variables can have multiple definitions. Cranelift will
+        // convert them into SSA form for itself automatically.
+        let new_value = self.translate_expr(expr)?;
+        let variable = *self.variables.get(&name).ok_or_else(|| {
+            DataFusionError::Internal(format!("variable {} not defined", name))
+        })?;
+        self.builder.def_var(variable, new_value);
+        Ok(())
+    }
+
+    /// Emit an integer comparison; the boolean result is widened to `I64`.
+    fn translate_icmp(&mut self, cmp: IntCC, lhs: Expr, rhs: Expr) -> Result<Value> {
+        let lhs = self.translate_expr(lhs)?;
+        let rhs = self.translate_expr(rhs)?;
+        let c = self.builder.ins().icmp(cmp, lhs, rhs);
+        Ok(self.builder.ins().bint(I64.native, c))
+    }
+
+    /// Emit a float comparison; the boolean result is widened to `I64`.
+    fn translate_fcmp(&mut self, cmp: FloatCC, lhs: Expr, rhs: Expr) -> Result<Value> {
+        let lhs = self.translate_expr(lhs)?;
+        let rhs = self.translate_expr(rhs)?;
+        let c = self.builder.ins().fcmp(cmp, lhs, rhs);
+        Ok(self.builder.ins().bint(I64.native, c))
+    }
+
+    /// Translate `if condition { then_body } else { else_body }`.
+    fn translate_if_else(
+        &mut self,
+        condition: Expr,
+        then_body: Vec<Stmt>,
+        else_body: Vec<Stmt>,
+    ) -> Result<()> {
+        let condition_value = self.translate_expr(condition)?;
+
+        let then_block = self.builder.create_block();
+        let else_block = self.builder.create_block();
+        let merge_block = self.builder.create_block();
+
+        // Test the if condition and conditionally branch.
+        self.builder.ins().brz(condition_value, else_block, &[]);
+        // Fall through to then block.
+        self.builder.ins().jump(then_block, &[]);
+
+        self.builder.switch_to_block(then_block);
+        self.builder.seal_block(then_block);
+        for stmt in then_body {
+            self.translate_stmt(stmt)?;
+        }
+
+        // Jump to the merge block. No block arguments are passed: values
+        // flow between blocks through variables.
+        self.builder.ins().jump(merge_block, &[]);
+
+        self.builder.switch_to_block(else_block);
+        self.builder.seal_block(else_block);
+        for stmt in else_body {
+            self.translate_stmt(stmt)?;
+        }
+
+        // Same as above: rejoin the merge block without arguments.
+        self.builder.ins().jump(merge_block, &[]);
+
+        // Switch to the merge block for subsequent statements.
+        self.builder.switch_to_block(merge_block);
+
+        // We've now seen all the predecessors of the merge block.
+        self.builder.seal_block(merge_block);
+        Ok(())
+    }
+
+    /// Translate `while condition { loop_body }`.
+    fn translate_while_loop(
+        &mut self,
+        condition: Expr,
+        loop_body: Vec<Stmt>,
+    ) -> Result<()> {
+        let header_block = self.builder.create_block();
+        let body_block = self.builder.create_block();
+        let exit_block = self.builder.create_block();
+
+        self.builder.ins().jump(header_block, &[]);
+        self.builder.switch_to_block(header_block);
+
+        let condition_value = self.translate_expr(condition)?;
+        self.builder.ins().brz(condition_value, exit_block, &[]);
+        self.builder.ins().jump(body_block, &[]);
+
+        self.builder.switch_to_block(body_block);
+        self.builder.seal_block(body_block);
+
+        for stmt in loop_body {
+            self.translate_stmt(stmt)?;
+        }
+        self.builder.ins().jump(header_block, &[]);
+
+        self.builder.switch_to_block(exit_block);
+
+        // We've reached the bottom of the loop, so there will be no more
+        // back edges to the header and no more exits to the bottom.
+        self.builder.seal_block(header_block);
+        self.builder.seal_block(exit_block);
+        Ok(())
+    }
+
+    /// Translate a call whose result is used as a value.
+    ///
+    /// # Errors
+    /// Returns an internal error when `ret` is the void type, since a void
+    /// call cannot produce a value.
+    fn translate_call_expr(
+        &mut self,
+        name: String,
+        args: Vec<Expr>,
+        ret: JITType,
+    ) -> Result<Value> {
+        if ret.code == 0 {
+            return internal_err!(
+                "Call function {}(..) has void type, it can not be an expression",
+                &name
+            );
+        }
+        let call = self.emit_call(&name, args, ret)?;
+        Ok(self.builder.inst_results(call)[0])
+    }
+
+    /// Translate a call used as a statement; any result is discarded.
+    fn translate_call_stmt(
+        &mut self,
+        name: String,
+        args: Vec<Expr>,
+        ret: JITType,
+    ) -> Result<()> {
+        self.emit_call(&name, args, ret)?;
+        Ok(())
+    }
+
+    /// Declare `name` as an imported function whose signature is derived from
+    /// `args` and `ret`, then emit the call instruction.
+    ///
+    /// # Errors
+    /// Returns an error when the module rejects the declaration or when an
+    /// argument expression cannot be translated.
+    fn emit_call(
+        &mut self,
+        name: &str,
+        args: Vec<Expr>,
+        ret: JITType,
+    ) -> Result<codegen::ir::Inst> {
+        let mut sig = self.module.make_signature();
+
+        // Add a parameter for each argument.
+        sig.params
+            .extend(args.iter().map(|arg| AbiParam::new(arg.get_type().native)));
+
+        // Type code 0 is the void type: no return slot.
+        if ret.code != 0 {
+            sig.returns.push(AbiParam::new(ret.native));
+        }
+
+        let callee = self.module.declare_function(name, Linkage::Import, &sig)?;
+        let local_callee = self.module.declare_func_in_func(callee, self.builder.func);
+
+        let arg_values = args
+            .into_iter()
+            .map(|arg| self.translate_expr(arg))
+            .collect::<Result<Vec<_>>>()?;
+        Ok(self.builder.ins().call(local_callee, &arg_values))
+    }
+}
+
+/// Arithmetic operators handled by `FunctionTranslator::translate_arith`.
+#[derive(Clone, Copy)]
+enum ArithOp {
+    Add,
+    Sub,
+    Mul,
+    Div,
+}
+
+impl ArithOp {
+    /// Operator name used in "Unsupported type" error messages.
+    fn name(self) -> &'static str {
+        match self {
+            ArithOp::Add => "add",
+            ArithOp::Sub => "sub",
+            ArithOp::Mul => "mul",
+            ArithOp::Div => "div",
+        }
+    }
+}
+
+/// True for the integer type codes (`i8`, `i16`, `i32`, `i64`).
+fn is_int_type(ty: JITType) -> bool {
+    (0x76..=0x79).contains(&ty.code)
+}
+
+/// True for the float type codes (`f32`, `f64`).
+fn is_float_type(ty: JITType) -> bool {
+    ty.code == 0x7b || ty.code == 0x7c
+}
+
+/// Parse a string literal as `T`, turning a parse failure into an internal
+/// error instead of a panic.
+fn parse_lit<T: std::str::FromStr>(lit: &str, ty: JITType) -> Result<T> {
+    lit.parse::<T>().map_err(|_| {
+        DataFusionError::Internal(format!("Cannot parse literal {:?} as {}", lit, ty))
+    })
+}
+
+/// Emit the zero/false/null constant of type `typ`.
+///
+/// # Panics
+/// Panics when `typ` has no zero value (for example the void type).
+fn typed_zero(typ: JITType, builder: &mut FunctionBuilder) -> Value {
+    match typ.code {
+        0x70 => builder.ins().bconst(typ.native, false),
+        0x76..=0x79 => builder.ins().iconst(typ.native, 0),
+        0x7b => builder.ins().f32const(0.0),
+        0x7c => builder.ins().f64const(0.0),
+        0x7e | 0x7f => builder.ins().null(typ.native),
+        _ => panic!("unsupported type"),
+    }
+}
+
+/// Declare every variable of the function: the parameters (bound to the
+/// entry block parameters), the return variable (initialised to zero) and
+/// all variables declared anywhere in `stmts`.
+fn declare_variables(
+    builder: &mut FunctionBuilder,
+    params: &[(String, JITType)],
+    the_return: &Option<(String, JITType)>,
+    stmts: &[Stmt],
+    entry_block: Block,
+) -> HashMap<String, Variable> {
+    let mut variables = HashMap::new();
+    let mut index = 0;
+
+    for (i, (name, ty)) in params.iter().enumerate() {
+        let val = builder.block_params(entry_block)[i];
+        let var = declare_variable(builder, &mut variables, &mut index, name, *ty);
+        builder.def_var(var, val);
+    }
+
+    if let Some((name, ty)) = the_return {
+        // When the return value is named after a parameter, that parameter's
+        // variable doubles as the return variable and must keep its value,
+        // so only a freshly declared return variable is zero-initialised.
+        if !variables.contains_key(name) {
+            let zero = typed_zero(*ty, builder);
+            let return_variable =
+                declare_variable(builder, &mut variables, &mut index, name, *ty);
+            builder.def_var(return_variable, zero);
+        }
+    }
+
+    for stmt in stmts {
+        declare_variables_in_stmt(builder, &mut variables, &mut index, stmt);
+    }
+
+    variables
+}
+
+/// Recursively descend through the AST, translating all declarations.
+fn declare_variables_in_stmt(
+    builder: &mut FunctionBuilder,
+    variables: &mut HashMap<String, Variable>,
+    index: &mut usize,
+    stmt: &Stmt,
+) {
+    match stmt {
+        Stmt::IfElse(_, then_body, else_body) => {
+            for stmt in then_body.iter().chain(else_body) {
+                declare_variables_in_stmt(builder, variables, index, stmt);
+            }
+        }
+        Stmt::WhileLoop(_, loop_body) => {
+            for stmt in loop_body {
+                declare_variables_in_stmt(builder, variables, index, stmt);
+            }
+        }
+        Stmt::Declare(name, typ) => {
+            declare_variable(builder, variables, index, name, *typ);
+        }
+        // Assignments and calls introduce no new variables.
+        Stmt::Assign(..) | Stmt::Call(..) => {}
+    }
+}
+
+/// Declare a single variable declaration.
+///
+/// Returns the variable bound to `name`. If `name` is already declared, the
+/// existing variable is returned and nothing new is declared, so callers
+/// never receive a handle that the builder does not know about.
+fn declare_variable(
+    builder: &mut FunctionBuilder,
+    variables: &mut HashMap<String, Variable>,
+    index: &mut usize,
+    name: &str,
+    typ: JITType,
+) -> Variable {
+    if let Some(existing) = variables.get(name) {
+        return *existing;
+    }
+    let var = Variable::new(*index);
+    variables.insert(name.into(), var);
+    builder.declare_var(var, typ.native);
+    *index += 1;
+    var
+}
diff --git a/datafusion-jit/src/lib.rs b/datafusion-jit/src/lib.rs new file mode 100644 index 000000000000..5642b5a9c987 --- /dev/null +++ b/datafusion-jit/src/lib.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Just-In-Time compilation to accelerate DataFusion physical plan execution. + +pub mod api; +pub mod ast; +pub mod jit; + +#[cfg(test)] +mod tests { + use crate::api::{Assembler, GeneratedFunction}; + use crate::ast::I64; + use crate::jit::JIT; + use datafusion_common::Result; + + #[test] + fn iterative_fib() -> Result<()> { + let expected = r#"fn iterative_fib_0(n: i64) -> r: i64 { + if n == 0 { + r = 0; + } else { + n = n - 1; + let a: i64; + a = 0; + r = 1; + while n != 0 { + let t: i64; + t = r; + r = r + a; + a = t; + n = n - 1; + } + } +}"#; + let assembler = Assembler::default(); + let mut builder = assembler + .new_func_builder("iterative_fib") + .param("n", I64) + .ret("r", I64); + let mut fn_body = builder.enter_block(); + + fn_body.if_block( + |cond| cond.eq(cond.id("n")?, cond.lit_i(0)), + |t| { + t.assign("r", t.lit_i(0))?; + Ok(()) + }, + |e| { + e.assign("n", e.sub(e.id("n")?, e.lit_i(1))?)?; + e.declare_as("a", e.lit_i(0))?; + e.assign("r", e.lit_i(1))?; + e.while_block( + |cond| cond.ne(cond.id("n")?, cond.lit_i(0)), + |w| { + w.declare_as("t", w.id("r")?)?; + w.assign("r", w.add(w.id("r")?, w.id("a")?)?)?; + w.assign("a", w.id("t")?)?; + w.assign("n", w.sub(w.id("n")?, w.lit_i(1))?)?; + Ok(()) + }, + )?; + Ok(()) + }, + )?; + + let gen_func = fn_body.build(); + assert_eq!(format!("{}", &gen_func), expected); + let mut jit = assembler.create_jit(); + assert_eq!(55, run_iterative_fib_code(&mut jit, gen_func, 10)?); + Ok(()) + } + + unsafe fn run_code( + jit: &mut JIT, + code: GeneratedFunction, + input: I, + ) -> Result { + // 
Pass the string to the JIT, and it returns a raw pointer to machine code. + let code_ptr = jit.compile(code)?; + // Cast the raw pointer to a typed function pointer. This is unsafe, because + // this is the critical point where you have to trust that the generated code + // is safe to be called. + let code_fn = core::mem::transmute::<_, fn(I) -> O>(code_ptr); + // And now we can call it! + Ok(code_fn(input)) + } + + fn run_iterative_fib_code( + jit: &mut JIT, + code: GeneratedFunction, + input: isize, + ) -> Result { + unsafe { run_code(jit, code, input) } + } +} diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index cbba899181e2..23eb7cecd96b 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -50,10 +50,13 @@ force_hash_collisions = [] avro = ["avro-rs", "num-traits", "datafusion-common/avro"] # Used to enable row format experiment row = [] +# Used to enable JIT code generation +jit = ["datafusion-jit"] [dependencies] datafusion-common = { path = "../datafusion-common", version = "7.0.0" } datafusion-expr = { path = "../datafusion-expr", version = "7.0.0" } +datafusion-jit = { path = "../datafusion-jit", version = "7.0.0", optional = true } datafusion-physical-expr = { path = "../datafusion-physical-expr", version = "7.0.0" } ahash = { version = "0.7", default-features = false } hashbrown = { version = "0.12", features = ["raw"] } @@ -121,3 +124,8 @@ harness = false [[bench]] name = "parquet_query_sql" harness = false + +[[bench]] +name = "jit" +harness = false +required-features = ["row", "jit"] diff --git a/datafusion/benches/data_utils/mod.rs b/datafusion/benches/data_utils/mod.rs index 6ebeeb77020e..71952b4c6520 100644 --- a/datafusion/benches/data_utils/mod.rs +++ b/datafusion/benches/data_utils/mod.rs @@ -35,7 +35,8 @@ use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, /// and the result table will be of array_len in total, and then partitioned, and batched. 
-pub(crate) fn create_table_provider( +#[allow(dead_code)] +pub fn create_table_provider( partitions_len: usize, array_len: usize, batch_size: usize, @@ -52,7 +53,8 @@ fn seedable_rng() -> StdRng { StdRng::seed_from_u64(42) } -fn create_schema() -> Schema { +/// Create test data schema +pub fn create_schema() -> Schema { Schema::new(vec![ Field::new("utf8", DataType::Utf8, false), Field::new("f32", DataType::Float32, false), @@ -138,7 +140,9 @@ fn create_record_batch( .unwrap() } -fn create_record_batches( +/// Create record batches of `partitions_len` partitions and `batch_size` for each batch, +/// with a total number of `array_len` records +pub fn create_record_batches( schema: SchemaRef, array_len: usize, partitions_len: usize, diff --git a/datafusion/benches/jit.rs b/datafusion/benches/jit.rs new file mode 100644 index 000000000000..b198b158c319 --- /dev/null +++ b/datafusion/benches/jit.rs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use crate::data_utils::{create_record_batches, create_schema}; +use datafusion::row::writer::{ + bench_write_batch, bench_write_batch_jit, bench_write_batch_jit_dummy, +}; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let partitions_len = 8; + let array_len = 32768 * 1024; // 2^25 + let batch_size = 2048; // 2^11 + + let schema = Arc::new(create_schema()); + let batches = + create_record_batches(schema.clone(), array_len, partitions_len, batch_size); + + c.bench_function("row serializer", |b| { + b.iter(|| { + criterion::black_box(bench_write_batch(&batches, schema.clone()).unwrap()) + }) + }); + + c.bench_function("row serializer jit", |b| { + b.iter(|| { + criterion::black_box(bench_write_batch_jit(&batches, schema.clone()).unwrap()) + }) + }); + + c.bench_function("row serializer jit codegen only", |b| { + b.iter(|| bench_write_batch_jit_dummy(schema.clone()).unwrap()) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index a9630c0f1756..0ce6e91e8ad0 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -227,7 +227,7 @@ pub use parquet; pub(crate) mod field_util; #[cfg(feature = "row")] -pub(crate) mod row; +pub mod row; pub mod from_slice; diff --git a/datafusion/src/row/mod.rs b/datafusion/src/row/mod.rs index 9875b84975e2..5cd9885238a9 100644 --- a/datafusion/src/row/mod.rs +++ b/datafusion/src/row/mod.rs @@ -17,7 +17,7 @@ //! An implementation of Row backed by raw bytes //! -//! Each tuple consists of up to three parts: [null bit set] [values] [var length data] +//! Each tuple consists of up to three parts: "`null bit set`" , "`values`" and "`var length data`" //! //! The null bit set is used for null tracking and is aligned to 1-byte. It stores //! one bit per field. 
@@ -52,8 +52,8 @@ use arrow::util::bit_util::{get_bit_raw, round_upto_power_of_2}; use std::fmt::Write; use std::sync::Arc; -mod reader; -mod writer; +pub mod reader; +pub mod writer; const ALL_VALID_MASK: [u8; 8] = [1, 3, 7, 15, 31, 63, 127, 255]; @@ -189,6 +189,29 @@ fn supported(schema: &Arc) -> bool { .all(|f| supported_type(f.data_type())) } +#[cfg(feature = "jit")] +#[macro_export] +/// register external functions to the assembler +macro_rules! reg_fn { + ($ASS:ident, $FN: path, $PARAM: expr, $RET: expr) => { + $ASS.register_extern_fn(fn_name($FN), $FN as *const u8, $PARAM, $RET)?; + }; +} + +#[cfg(feature = "jit")] +fn fn_name(f: T) -> &'static str { + fn type_name_of(_: T) -> &'static str { + std::any::type_name::() + } + let name = type_name_of(f); + + // Find and cut the rest of the path + match &name.rfind(':') { + Some(pos) => &name[pos + 1..name.len()], + None => name, + } +} + #[cfg(test)] mod tests { use super::*; @@ -203,10 +226,16 @@ mod tests { use crate::physical_plan::file_format::FileScanConfig; use crate::physical_plan::{collect, ExecutionPlan}; use crate::row::reader::read_as_batch; + #[cfg(feature = "jit")] + use crate::row::reader::read_as_batch_jit; use crate::row::writer::write_batch_unchecked; + #[cfg(feature = "jit")] + use crate::row::writer::write_batch_unchecked_jit; use arrow::record_batch::RecordBatch; use arrow::util::bit_util::{ceil, set_bit_raw, unset_bit_raw}; use arrow::{array::*, datatypes::*}; + #[cfg(feature = "jit")] + use datafusion_jit::api::Assembler; use rand::Rng; use DataType::*; @@ -300,7 +329,23 @@ mod tests { let mut vector = vec![0; 1024]; let row_offsets = { write_batch_unchecked(&mut vector, 0, &batch, 0, schema.clone()) }; - let output_batch = { read_as_batch(&mut vector, schema, row_offsets)? }; + let output_batch = { read_as_batch(&vector, schema, row_offsets)? 
}; + assert_eq!(batch, output_batch); + Ok(()) + } + + #[test] + #[allow(non_snake_case)] + #[cfg(feature = "jit")] + fn []() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", $TYPE, true)])); + let a = $ARRAY::from($VEC); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?; + let mut vector = vec![0; 1024]; + let assembler = Assembler::default(); + let row_offsets = + { write_batch_unchecked_jit(&mut vector, 0, &batch, 0, schema.clone(), &assembler)? }; + let output_batch = { read_as_batch_jit(&vector, schema, row_offsets, &assembler)? }; assert_eq!(batch, output_batch); Ok(()) } @@ -402,7 +447,33 @@ mod tests { let mut vector = vec![0; 8192]; let row_offsets = { write_batch_unchecked(&mut vector, 0, &batch, 0, schema.clone()) }; - let output_batch = { read_as_batch(&mut vector, schema, row_offsets)? }; + let output_batch = { read_as_batch(&vector, schema, row_offsets)? }; + assert_eq!(batch, output_batch); + Ok(()) + } + + #[test] + #[cfg(feature = "jit")] + fn test_single_binary_jit() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", Binary, true)])); + let values: Vec> = + vec![Some(b"one"), Some(b"two"), None, Some(b""), Some(b"three")]; + let a = BinaryArray::from_opt_vec(values); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?; + let mut vector = vec![0; 8192]; + let assembler = Assembler::default(); + let row_offsets = { + write_batch_unchecked_jit( + &mut vector, + 0, + &batch, + 0, + schema.clone(), + &assembler, + )? + }; + let output_batch = + { read_as_batch_jit(&vector, schema, row_offsets, &assembler)? }; assert_eq!(batch, output_batch); Ok(()) } @@ -421,7 +492,7 @@ mod tests { let mut vector = vec![0; 20480]; let row_offsets = { write_batch_unchecked(&mut vector, 0, batch, 0, schema.clone()) }; - let output_batch = { read_as_batch(&mut vector, schema, row_offsets)? }; + let output_batch = { read_as_batch(&vector, schema, row_offsets)? 
}; assert_eq!(*batch, output_batch); Ok(()) @@ -445,9 +516,9 @@ mod tests { DataType::Decimal(5, 2), false, )])); - let mut vector = vec![0; 1024]; + let vector = vec![0; 1024]; let row_offsets = vec![0]; - read_as_batch(&mut vector, schema, row_offsets).unwrap(); + read_as_batch(&vector, schema, row_offsets).unwrap(); } async fn get_exec( diff --git a/datafusion/src/row/reader.rs b/datafusion/src/row/reader.rs index 779c09990ffc..213c34b574ad 100644 --- a/datafusion/src/row/reader.rs +++ b/datafusion/src/row/reader.rs @@ -18,17 +18,27 @@ //! Accessing row from raw bytes use crate::error::{DataFusionError, Result}; +#[cfg(feature = "jit")] +use crate::reg_fn; +#[cfg(feature = "jit")] +use crate::row::fn_name; use crate::row::{all_valid, get_offsets, supported, NullBitsFormatter}; use arrow::array::*; use arrow::datatypes::{DataType, Schema}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use arrow::util::bit_util::{ceil, get_bit_raw}; +#[cfg(feature = "jit")] +use datafusion_jit::api::Assembler; +#[cfg(feature = "jit")] +use datafusion_jit::api::GeneratedFunction; +#[cfg(feature = "jit")] +use datafusion_jit::ast::{I64, PTR}; use std::sync::Arc; /// Read `data` of raw-bytes rows starting at `offsets` out to a record batch pub fn read_as_batch( - data: &mut [u8], + data: &[u8], schema: Arc, offsets: Vec, ) -> Result { @@ -44,6 +54,33 @@ pub fn read_as_batch( output.output().map_err(DataFusionError::ArrowError) } +/// Read `data` of raw-bytes rows starting at `offsets` out to a record batch +#[cfg(feature = "jit")] +pub fn read_as_batch_jit( + data: &[u8], + schema: Arc, + offsets: Vec, + assembler: &Assembler, +) -> Result { + let row_num = offsets.len(); + let mut output = MutableRecordBatch::new(row_num, schema.clone()); + let mut row = RowReader::new(&schema, data); + register_read_functions(assembler)?; + let gen_func = gen_read_row(&schema, assembler)?; + let mut jit = assembler.create_jit(); + let code_ptr = 
jit.compile(gen_func)?; + let code_fn = unsafe { + std::mem::transmute::<_, fn(&RowReader, &mut MutableRecordBatch)>(code_ptr) + }; + + for offset in offsets.iter().take(row_num) { + row.point_to(*offset); + code_fn(&row, &mut output); + } + + output.output().map_err(DataFusionError::ArrowError) +} + macro_rules! get_idx { ($NATIVE: ident, $SELF: ident, $IDX: ident, $WIDTH: literal) => {{ $SELF.assert_index_valid($IDX); @@ -260,6 +297,114 @@ fn read_row(row: &RowReader, batch: &mut MutableRecordBatch, schema: &Arc &mut Box { + let arrays: &mut [Box] = batch.arrays.as_mut(); + &mut arrays[col_idx] +} + +#[cfg(feature = "jit")] +fn register_read_functions(asm: &Assembler) -> Result<()> { + let reader_param = vec![PTR, I64, PTR]; + reg_fn!(asm, get_array_mut, vec![PTR, I64], Some(PTR)); + reg_fn!(asm, read_field_bool, reader_param.clone(), None); + reg_fn!(asm, read_field_u8, reader_param.clone(), None); + reg_fn!(asm, read_field_u16, reader_param.clone(), None); + reg_fn!(asm, read_field_u32, reader_param.clone(), None); + reg_fn!(asm, read_field_u64, reader_param.clone(), None); + reg_fn!(asm, read_field_i8, reader_param.clone(), None); + reg_fn!(asm, read_field_i16, reader_param.clone(), None); + reg_fn!(asm, read_field_i32, reader_param.clone(), None); + reg_fn!(asm, read_field_i64, reader_param.clone(), None); + reg_fn!(asm, read_field_f32, reader_param.clone(), None); + reg_fn!(asm, read_field_f64, reader_param.clone(), None); + reg_fn!(asm, read_field_date32, reader_param.clone(), None); + reg_fn!(asm, read_field_date64, reader_param.clone(), None); + reg_fn!(asm, read_field_utf8, reader_param.clone(), None); + reg_fn!(asm, read_field_binary, reader_param.clone(), None); + reg_fn!(asm, read_field_bool_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_u8_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_u16_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_u32_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_u64_nf, 
reader_param.clone(), None); + reg_fn!(asm, read_field_i8_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_i16_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_i32_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_i64_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_f32_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_f64_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_date32_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_date64_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_utf8_nf, reader_param.clone(), None); + reg_fn!(asm, read_field_binary_nf, reader_param, None); + Ok(()) +} + +#[cfg(feature = "jit")] +fn gen_read_row( + schema: &Arc, + assembler: &Assembler, +) -> Result { + use DataType::*; + let mut builder = assembler + .new_func_builder("read_row") + .param("row", PTR) + .param("batch", PTR); + let mut b = builder.enter_block(); + for (i, f) in schema.fields().iter().enumerate() { + let dt = f.data_type(); + let arr = format!("a{}", i); + b.declare_as( + &arr, + b.call("get_array_mut", vec![b.id("batch")?, b.lit_i(i as i64)])?, + )?; + let params = vec![b.id(&arr)?, b.lit_i(i as i64), b.id("row")?]; + if f.is_nullable() { + match dt { + Boolean => b.call_stmt("read_field_bool", params)?, + UInt8 => b.call_stmt("read_field_u8", params)?, + UInt16 => b.call_stmt("read_field_u16", params)?, + UInt32 => b.call_stmt("read_field_u32", params)?, + UInt64 => b.call_stmt("read_field_u64", params)?, + Int8 => b.call_stmt("read_field_i8", params)?, + Int16 => b.call_stmt("read_field_i16", params)?, + Int32 => b.call_stmt("read_field_i32", params)?, + Int64 => b.call_stmt("read_field_i64", params)?, + Float32 => b.call_stmt("read_field_f32", params)?, + Float64 => b.call_stmt("read_field_f64", params)?, + Date32 => b.call_stmt("read_field_date32", params)?, + Date64 => b.call_stmt("read_field_date64", params)?, + Utf8 => b.call_stmt("read_field_utf8", params)?, + Binary => 
b.call_stmt("read_field_binary", params)?, + _ => unimplemented!(), + } + } else { + match dt { + Boolean => b.call_stmt("read_field_bool_nf", params)?, + UInt8 => b.call_stmt("read_field_u8_nf", params)?, + UInt16 => b.call_stmt("read_field_u16_nf", params)?, + UInt32 => b.call_stmt("read_field_u32_nf", params)?, + UInt64 => b.call_stmt("read_field_u64_nf", params)?, + Int8 => b.call_stmt("read_field_i8_nf", params)?, + Int16 => b.call_stmt("read_field_i16_nf", params)?, + Int32 => b.call_stmt("read_field_i32_nf", params)?, + Int64 => b.call_stmt("read_field_i64_nf", params)?, + Float32 => b.call_stmt("read_field_f32_nf", params)?, + Float64 => b.call_stmt("read_field_f64_nf", params)?, + Date32 => b.call_stmt("read_field_date32_nf", params)?, + Date64 => b.call_stmt("read_field_date64_nf", params)?, + Utf8 => b.call_stmt("read_field_utf8_nf", params)?, + Binary => b.call_stmt("read_field_binary_nf", params)?, + _ => unimplemented!(), + } + } + } + Ok(b.build()) +} + macro_rules! fn_read_field { ($NATIVE: ident, $ARRAY: ident) => { paste::item! { diff --git a/datafusion/src/row/writer.rs b/datafusion/src/row/writer.rs index 698f7974c10a..2206e350bcb2 100644 --- a/datafusion/src/row/writer.rs +++ b/datafusion/src/row/writer.rs @@ -17,11 +17,22 @@ //! 
Reusable row writer backed by Vec to stitch attributes together +use crate::error::Result; +#[cfg(feature = "jit")] +use crate::reg_fn; +#[cfg(feature = "jit")] +use crate::row::fn_name; use crate::row::{estimate_row_width, fixed_size, get_offsets, supported}; -use arrow::array::Array; +use arrow::array::*; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util::{ceil, round_upto_power_of_2, set_bit_raw, unset_bit_raw}; +use datafusion_jit::api::CodeBlock; +#[cfg(feature = "jit")] +use datafusion_jit::api::{Assembler, GeneratedFunction}; +use datafusion_jit::ast::Expr; +#[cfg(feature = "jit")] +use datafusion_jit::ast::{BOOL, I64, PTR}; use std::cmp::max; use std::sync::Arc; @@ -50,6 +61,103 @@ pub fn write_batch_unchecked( offsets } +/// Append batch from `row_idx` to `output` buffer start from `offset` +/// # Panics +/// +/// This function will panic if the output buffer doesn't have enough space to hold all the rows +#[cfg(feature = "jit")] +pub fn write_batch_unchecked_jit( + output: &mut [u8], + offset: usize, + batch: &RecordBatch, + row_idx: usize, + schema: Arc, + assembler: &Assembler, +) -> Result> { + let mut writer = RowWriter::new(&schema); + let mut current_offset = offset; + let mut offsets = vec![]; + register_write_functions(assembler)?; + let gen_func = gen_write_row(&schema, assembler)?; + let mut jit = assembler.create_jit(); + let code_ptr = jit.compile(gen_func)?; + + let code_fn = unsafe { + std::mem::transmute::<_, fn(&mut RowWriter, usize, &RecordBatch)>(code_ptr) + }; + + for cur_row in row_idx..batch.num_rows() { + offsets.push(current_offset); + code_fn(&mut writer, cur_row, batch); + writer.end_padding(); + let row_width = writer.row_width; + output[current_offset..current_offset + row_width] + .copy_from_slice(writer.get_row()); + current_offset += row_width; + writer.reset() + } + Ok(offsets) +} + +#[cfg(feature = "jit")] +/// bench interpreted version write +pub fn bench_write_batch( 
+ batches: &[Vec], + schema: Arc, +) -> Result> { + let mut writer = RowWriter::new(&schema); + let mut lengths = vec![]; + + for batch in batches.iter().flatten() { + for cur_row in 0..batch.num_rows() { + let row_width = write_row(&mut writer, cur_row, batch); + lengths.push(row_width); + writer.reset() + } + } + + Ok(lengths) +} + +#[cfg(feature = "jit")] +/// bench jit version write +pub fn bench_write_batch_jit( + batches: &[Vec], + schema: Arc, +) -> Result> { + let assembler = Assembler::default(); + let mut writer = RowWriter::new(&schema); + let mut lengths = vec![]; + register_write_functions(&assembler)?; + let gen_func = gen_write_row(&schema, &assembler)?; + let mut jit = assembler.create_jit(); + let code_ptr = jit.compile(gen_func)?; + let code_fn = unsafe { + std::mem::transmute::<_, fn(&mut RowWriter, usize, &RecordBatch)>(code_ptr) + }; + + for batch in batches.iter().flatten() { + for cur_row in 0..batch.num_rows() { + code_fn(&mut writer, cur_row, batch); + writer.end_padding(); + lengths.push(writer.row_width); + writer.reset() + } + } + Ok(lengths) +} + +#[cfg(feature = "jit")] +/// bench code generation cost +pub fn bench_write_batch_jit_dummy(schema: Arc) -> Result<()> { + let assembler = Assembler::default(); + register_write_functions(&assembler)?; + let gen_func = gen_write_row(&schema, &assembler)?; + let mut jit = assembler.create_jit(); + let _: *const u8 = jit.compile(gen_func)?; + Ok(()) +} + macro_rules! 
set_idx { ($WIDTH: literal, $SELF: ident, $IDX: ident, $VALUE: ident) => {{ $SELF.assert_index_valid($IDX); @@ -233,7 +341,6 @@ fn write_row(row: &mut RowWriter, row_idx: usize, batch: &RecordBatch) -> usize .zip(batch.columns().iter()) { if !col.is_null(row_idx) { - row.set_non_null_at(i); write_field(i, row_idx, col, f.data_type(), row); } else { row.set_null_at(i); @@ -244,6 +351,197 @@ fn write_row(row: &mut RowWriter, row_idx: usize, batch: &RecordBatch) -> usize row.row_width } +// we could remove this function wrapper once we find a way to call the trait method directly. +#[cfg(feature = "jit")] +fn is_null(col: &Arc, row_idx: usize) -> bool { + col.is_null(row_idx) +} + +#[cfg(feature = "jit")] +fn register_write_functions(asm: &Assembler) -> Result<()> { + let reader_param = vec![PTR, I64, PTR]; + reg_fn!(asm, RecordBatch::column, vec![PTR, I64], Some(PTR)); + reg_fn!(asm, RowWriter::set_null_at, vec![PTR, I64], None); + reg_fn!(asm, RowWriter::set_non_null_at, vec![PTR, I64], None); + reg_fn!(asm, is_null, vec![PTR, I64], Some(BOOL)); + reg_fn!(asm, write_field_bool, reader_param.clone(), None); + reg_fn!(asm, write_field_u8, reader_param.clone(), None); + reg_fn!(asm, write_field_u16, reader_param.clone(), None); + reg_fn!(asm, write_field_u32, reader_param.clone(), None); + reg_fn!(asm, write_field_u64, reader_param.clone(), None); + reg_fn!(asm, write_field_i8, reader_param.clone(), None); + reg_fn!(asm, write_field_i16, reader_param.clone(), None); + reg_fn!(asm, write_field_i32, reader_param.clone(), None); + reg_fn!(asm, write_field_i64, reader_param.clone(), None); + reg_fn!(asm, write_field_f32, reader_param.clone(), None); + reg_fn!(asm, write_field_f64, reader_param.clone(), None); + reg_fn!(asm, write_field_date32, reader_param.clone(), None); + reg_fn!(asm, write_field_date64, reader_param.clone(), None); + reg_fn!(asm, write_field_utf8, reader_param.clone(), None); + reg_fn!(asm, write_field_binary, reader_param, None); + Ok(()) +} + 
+#[cfg(feature = "jit")] +fn gen_write_row( + schema: &Arc, + assembler: &Assembler, +) -> Result { + let mut builder = assembler + .new_func_builder("write_row") + .param("row", PTR) + .param("row_idx", I64) + .param("batch", PTR); + let mut b = builder.enter_block(); + for (i, f) in schema.fields().iter().enumerate() { + let dt = f.data_type(); + let arr = format!("a{}", i); + b.declare_as( + &arr, + b.call("column", vec![b.id("batch")?, b.lit_i(i as i64)])?, + )?; + if f.is_nullable() { + b.if_block( + |c| c.call("is_null", vec![c.id(&arr)?, c.id("row_idx")?]), + |t| { + t.call_stmt("set_null_at", vec![t.id("row")?, t.lit_i(i as i64)])?; + Ok(()) + }, + |e| { + e.call_stmt( + "set_non_null_at", + vec![e.id("row")?, e.lit_i(i as i64)], + )?; + let params = vec![ + e.id("row")?, + e.id(&arr)?, + e.lit_i(i as i64), + e.id("row_idx")?, + ]; + write_typed_field_stmt(dt, e, params)?; + Ok(()) + }, + )?; + } else { + b.call_stmt("set_non_null_at", vec![b.id("row")?, b.lit_i(i as i64)])?; + let params = vec![ + b.id("row")?, + b.id(&arr)?, + b.lit_i(i as i64), + b.id("row_idx")?, + ]; + write_typed_field_stmt(dt, &mut b, params)?; + } + } + Ok(b.build()) +} + +#[cfg(feature = "jit")] +fn write_typed_field_stmt<'a>( + dt: &DataType, + b: &mut CodeBlock<'a>, + params: Vec, +) -> Result<()> { + use DataType::*; + match dt { + Boolean => b.call_stmt("write_field_bool", params)?, + UInt8 => b.call_stmt("write_field_u8", params)?, + UInt16 => b.call_stmt("write_field_u16", params)?, + UInt32 => b.call_stmt("write_field_u32", params)?, + UInt64 => b.call_stmt("write_field_u64", params)?, + Int8 => b.call_stmt("write_field_i8", params)?, + Int16 => b.call_stmt("write_field_i16", params)?, + Int32 => b.call_stmt("write_field_i32", params)?, + Int64 => b.call_stmt("write_field_i64", params)?, + Float32 => b.call_stmt("write_field_f32", params)?, + Float64 => b.call_stmt("write_field_f64", params)?, + Date32 => b.call_stmt("write_field_date32", params)?, + Date64 => 
b.call_stmt("write_field_date64", params)?, + Utf8 => b.call_stmt("write_field_utf8", params)?, + Binary => b.call_stmt("write_field_binary", params)?, + _ => unimplemented!(), + } + Ok(()) +} + +macro_rules! fn_write_field { + ($NATIVE: ident, $ARRAY: ident) => { + paste::item! { + fn [](to: &mut RowWriter, from: &Arc, col_idx: usize, row_idx: usize) { + let from = from + .as_any() + .downcast_ref::<$ARRAY>() + .unwrap(); + to.[](col_idx, from.value(row_idx)); + } + } + }; +} + +fn_write_field!(bool, BooleanArray); +fn_write_field!(u8, UInt8Array); +fn_write_field!(u16, UInt16Array); +fn_write_field!(u32, UInt32Array); +fn_write_field!(u64, UInt64Array); +fn_write_field!(i8, Int8Array); +fn_write_field!(i16, Int16Array); +fn_write_field!(i32, Int32Array); +fn_write_field!(i64, Int64Array); +fn_write_field!(f32, Float32Array); +fn_write_field!(f64, Float64Array); + +fn write_field_date32( + to: &mut RowWriter, + from: &Arc, + col_idx: usize, + row_idx: usize, +) { + let from = from.as_any().downcast_ref::().unwrap(); + to.set_date32(col_idx, from.value(row_idx)); +} + +fn write_field_date64( + to: &mut RowWriter, + from: &Arc, + col_idx: usize, + row_idx: usize, +) { + let from = from.as_any().downcast_ref::().unwrap(); + to.set_date64(col_idx, from.value(row_idx)); +} + +fn write_field_utf8( + to: &mut RowWriter, + from: &Arc, + col_idx: usize, + row_idx: usize, +) { + let from = from.as_any().downcast_ref::().unwrap(); + let s = from.value(row_idx); + let new_width = to.current_width() + s.as_bytes().len(); + if new_width > to.data.capacity() { + // double the capacity to avoid repeated resize + to.data.resize(max(to.data.capacity() * 2, new_width), 0); + } + to.set_utf8(col_idx, s); +} + +fn write_field_binary( + to: &mut RowWriter, + from: &Arc, + col_idx: usize, + row_idx: usize, +) { + let from = from.as_any().downcast_ref::().unwrap(); + let s = from.value(row_idx); + let new_width = to.current_width() + s.len(); + if new_width > to.data.capacity() { + // 
double the capacity to avoid repeated resize + to.data.resize(max(to.data.capacity() * 2, new_width), 0); + } + to.set_binary(col_idx, s); +} + fn write_field( col_idx: usize, row_idx: usize, @@ -251,82 +549,24 @@ fn write_field( dt: &DataType, row: &mut RowWriter, ) { - // TODO: JIT compile this - use arrow::array::*; use DataType::*; + row.set_non_null_at(col_idx); match dt { - Boolean => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_bool(col_idx, c.value(row_idx)); - } - UInt8 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_u8(col_idx, c.value(row_idx)); - } - UInt16 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_u16(col_idx, c.value(row_idx)); - } - UInt32 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_u32(col_idx, c.value(row_idx)); - } - UInt64 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_u64(col_idx, c.value(row_idx)); - } - Int8 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_i8(col_idx, c.value(row_idx)); - } - Int16 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_i16(col_idx, c.value(row_idx)); - } - Int32 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_i32(col_idx, c.value(row_idx)); - } - Int64 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_i64(col_idx, c.value(row_idx)); - } - Float32 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_f32(col_idx, c.value(row_idx)); - } - Float64 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_f64(col_idx, c.value(row_idx)); - } - Date32 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_date32(col_idx, c.value(row_idx)); - } - Date64 => { - let c = col.as_any().downcast_ref::().unwrap(); - row.set_date64(col_idx, c.value(row_idx)); - } - Utf8 => { - let c = col.as_any().downcast_ref::().unwrap(); - let s = c.value(row_idx); - let new_width = row.current_width() + s.as_bytes().len(); - if new_width > row.data.capacity() 
{ - // double the capacity to avoid repeated resize - row.data.resize(max(row.data.capacity() * 2, new_width), 0); - } - row.set_utf8(col_idx, s); - } - Binary => { - let c = col.as_any().downcast_ref::().unwrap(); - let binary = c.value(row_idx); - let new_width = row.current_width() + binary.len(); - if new_width > row.data.capacity() { - // double the capacity to avoid repeated resize - row.data.resize(max(row.data.capacity() * 2, new_width), 0); - } - row.set_binary(col_idx, binary); - } + Boolean => write_field_bool(row, col, col_idx, row_idx), + UInt8 => write_field_u8(row, col, col_idx, row_idx), + UInt16 => write_field_u16(row, col, col_idx, row_idx), + UInt32 => write_field_u32(row, col, col_idx, row_idx), + UInt64 => write_field_u64(row, col, col_idx, row_idx), + Int8 => write_field_i8(row, col, col_idx, row_idx), + Int16 => write_field_i16(row, col, col_idx, row_idx), + Int32 => write_field_i32(row, col, col_idx, row_idx), + Int64 => write_field_i64(row, col, col_idx, row_idx), + Float32 => write_field_f32(row, col, col_idx, row_idx), + Float64 => write_field_f64(row, col, col_idx, row_idx), + Date32 => write_field_date32(row, col, col_idx, row_idx), + Date64 => write_field_date64(row, col, col_idx, row_idx), + Utf8 => write_field_utf8(row, col, col_idx, row_idx), + Binary => write_field_binary(row, col, col_idx, row_idx), _ => unimplemented!(), } }