diff --git a/datafusion/jit/src/api.rs b/datafusion/jit/src/api.rs index d95f9ccc7ac5..7020985a733a 100644 --- a/datafusion/jit/src/api.rs +++ b/datafusion/jit/src/api.rs @@ -153,6 +153,7 @@ impl FunctionBuilder { } /// Add one more parameter to the function. + #[must_use] pub fn param(mut self, name: impl Into, ty: JITType) -> Self { let name = name.into(); assert!(!self.fields.back().unwrap().contains_key(&name)); @@ -163,6 +164,7 @@ impl FunctionBuilder { /// Set return type for the function. Functions are of `void` type by default if /// you do not set the return type. + #[must_use] pub fn ret(mut self, name: impl Into, ty: JITType) -> Self { let name = name.into(); assert!(!self.fields.back().unwrap().contains_key(&name)); @@ -604,6 +606,17 @@ impl<'a> CodeBlock<'a> { internal_err!("No func with the name {} exist", fn_name) } } + + /// Return the value pointed to by the ptr stored in `ptr` + pub fn load(&self, ptr: Expr, ty: JITType) -> Result { + Ok(Expr::Load(Box::new(ptr), ty)) + } + + /// Store the value in `value` to the address in `ptr` + pub fn store(&mut self, value: Expr, ptr: Expr) -> Result<()> { + self.stmts.push(Stmt::Store(Box::new(value), Box::new(ptr))); + Ok(()) + } } impl Display for GeneratedFunction { diff --git a/datafusion/jit/src/ast.rs b/datafusion/jit/src/ast.rs index fd10a909e783..55731a650548 100644 --- a/datafusion/jit/src/ast.rs +++ b/datafusion/jit/src/ast.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::datatypes::DataType; use cranelift::codegen::ir; use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; use std::fmt::{Display, Formatter}; @@ -32,6 +33,8 @@ pub enum Stmt { Call(String, Vec), /// declare a new variable of type Declare(String, JITType), + /// store value (the first expr) to an address (the second expr) + Store(Box, Box), } #[derive(Copy, Clone, Debug, PartialEq)] @@ -54,6 +57,8 @@ pub enum Expr { Binary(BinaryExpr), /// call function expression Call(String, Vec, JITType), + /// Load a value from pointer + Load(Box, JITType), } impl Expr { @@ -63,6 +68,7 @@ impl Expr { Expr::Identifier(_, ty) => *ty, Expr::Binary(bin) => bin.get_type(), Expr::Call(_, _, ty) => *ty, + Expr::Load(_, ty) => *ty, } } } @@ -174,19 +180,7 @@ impl TryFrom<(datafusion_expr::Expr, DFSchemaRef)> for Expr { let field = schema.field_from_column(col)?; let ty = field.data_type(); - let jit_type = match ty { - arrow::datatypes::DataType::Int64 => I64, - arrow::datatypes::DataType::Float32 => F32, - arrow::datatypes::DataType::Float64 => F64, - arrow::datatypes::DataType::Boolean => BOOL, - - _ => { - return Err(DataFusionError::NotImplemented(format!( - "Compiling Expression with type {} not yet supported in JIT mode", - ty - ))) - } - }; + let jit_type = JITType::try_from(ty)?; Ok(Expr::Identifier(field.qualified_name(), jit_type)) } @@ -272,12 +266,28 @@ pub const R64: JITType = JITType { native: ir::types::R64, code: 0x7f, }; +pub const PTR_SIZE: usize = std::mem::size_of::(); /// The pointer type to use based on our currently target. -pub const PTR: JITType = if std::mem::size_of::() == 8 { - R64 -} else { - R32 -}; +pub const PTR: JITType = if PTR_SIZE == 8 { R64 } else { R32 }; + +impl TryFrom<&DataType> for JITType { + type Error = DataFusionError; + + /// Try to convert DataFusion's [DataType] to [JITType] + fn try_from(df_type: &DataType) -> Result { + match df_type { + DataType::Int64 => Ok(I64), + DataType::Float32 => Ok(F32), + DataType::Float64 => Ok(F64), + DataType::Boolean => Ok(BOOL), + + _ => Err(DataFusionError::NotImplemented(format!( + "Compiling Expression with type {} not yet supported in JIT mode", + df_type + ))), + } + } +} impl Stmt { /// print the statement with indentation @@ -323,6 +333,9 @@ impl Stmt { Stmt::Declare(name, ty) => { writeln!(f, "{}let {}: {};", ident_str, name, ty) } + Stmt::Store(value, ptr) => { + writeln!(f, "{}*({}) = {}", ident_str, ptr, value) + } } } } @@ -352,6 +365,7 @@ impl Display for Expr { .join(", ") ) } + Expr::Load(ptr, _) => write!(f, "*({})", ptr,), } } } diff --git a/datafusion/jit/src/compile.rs b/datafusion/jit/src/compile.rs new file mode 100644 index 000000000000..4e68b52104c0 --- /dev/null +++ b/datafusion/jit/src/compile.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Compile DataFusion Expr to JIT'd function. + +use datafusion_common::Result; + +use crate::api::Assembler; +use crate::ast::{JITType, I32}; +use crate::{ + api::GeneratedFunction, + ast::{Expr as JITExpr, I64, PTR_SIZE}, +}; + +/// Wrap JIT Expr to array compute function. +pub fn build_calc_fn( + assembler: &Assembler, + jit_expr: JITExpr, + inputs: Vec<(String, JITType)>, + ret_type: JITType, +) -> Result { + // Alias pointer type. + // The raw pointer `R64` or `R32` is not compatible with integers. + const PTR_TYPE: JITType = if PTR_SIZE == 8 { I64 } else { I32 }; + + let mut builder = assembler.new_func_builder("calc_fn"); + // Declare in-param. + // Each input takes one position, following by a pointer to place result, + // and the last is the length of inputs/output arrays. + for (name, _) in &inputs { + builder = builder.param(format!("{}_array", name), PTR_TYPE); + } + let mut builder = builder.param("result", ret_type).param("len", I64); + + // Start build function body. + // It's loop that calculates the result one by one. + let mut fn_body = builder.enter_block(); + fn_body.declare_as("index", fn_body.lit_i(0))?; + fn_body.while_block( + |cond| cond.lt(cond.id("index")?, cond.id("len")?), + |w| { + w.declare_as("offset", w.mul(w.id("index")?, w.lit_i(PTR_SIZE as i64))?)?; + for (name, ty) in &inputs { + w.declare_as( + format!("{}_ptr", name), + w.add(w.id(format!("{}_array", name))?, w.id("offset")?)?, + )?; + w.declare_as(name, w.load(w.id(format!("{}_ptr", name))?, *ty)?)?; + } + w.declare_as("res_ptr", w.add(w.id("result")?, w.id("offset")?)?)?; + w.declare_as("res", jit_expr.clone())?; + w.store(w.id("res")?, w.id("res_ptr")?)?; + + w.assign("index", w.add(w.id("index")?, w.lit_i(1))?)?; + Ok(()) + }, + )?; + + let gen_func = fn_body.build(); + Ok(gen_func) +} + +#[cfg(test)] +mod test { + use std::{collections::HashMap, sync::Arc}; + + use arrow::{ + array::{Array, PrimitiveArray}, + datatypes::{DataType, Int64Type}, + }; + use datafusion_common::{DFField, DFSchema, DataFusionError}; + use datafusion_expr::Expr as DFExpr; + + use crate::ast::BinaryExpr; + + use super::*; + + fn run_df_expr( + df_expr: DFExpr, + schema: Arc, + lhs: PrimitiveArray, + rhs: PrimitiveArray, + ) -> Result> { + if lhs.null_count() != 0 || rhs.null_count() != 0 { + return Err(DataFusionError::NotImplemented( + "Computing on nullable array not yet supported".to_string(), + )); + } + if lhs.len() != rhs.len() { + return Err(DataFusionError::NotImplemented( + "Computing on different length arrays not yet supported".to_string(), + )); + } + + // translate DF Expr to JIT Expr + let input_fields = schema + .fields() + .iter() + .map(|field| { + Ok(( + field.qualified_name(), + JITType::try_from(field.data_type())?, + )) + }) + .collect::>>()?; + let jit_expr: JITExpr = (df_expr, schema).try_into()?; + + // allocate memory for calc result + let len = lhs.len(); + let result = vec![0i64; len]; + + // compile and run JIT code + let assembler = Assembler::default(); + let gen_func = build_calc_fn(&assembler, jit_expr, input_fields, I64)?; + let mut jit = assembler.create_jit(); + let code_ptr = jit.compile(gen_func)?; + let code_fn = unsafe { + core::mem::transmute::<_, fn(*const i64, *const i64, *const i64, i64) -> ()>( + code_ptr, + ) + }; + code_fn( + lhs.values().as_ptr(), + rhs.values().as_ptr(), + result.as_ptr(), + len as i64, + ); + + let result_array = PrimitiveArray::::from_iter(result); + Ok(result_array) + } + + #[test] + fn array_add() { + let array_a: PrimitiveArray = + PrimitiveArray::from_iter_values((0..10).map(|x| x + 1)); + let array_b: PrimitiveArray = + PrimitiveArray::from_iter_values((10..20).map(|x| x + 1)); + let expected = + arrow::compute::kernels::arithmetic::add(&array_a, &array_b).unwrap(); + + let df_expr = datafusion_expr::col("a") + datafusion_expr::col("b"); + let schema = Arc::new( + DFSchema::new_with_metadata( + vec![ + DFField::new(Some("table1"), "a", DataType::Int64, false), + DFField::new(Some("table1"), "b", DataType::Int64, false), + ], + HashMap::new(), + ) + .unwrap(), + ); + + let result = run_df_expr(df_expr, schema, array_a, array_b).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn calc_fn_builder() { + let expr = JITExpr::Binary(BinaryExpr::Add( + Box::new(JITExpr::Identifier("table1.a".to_string(), I64)), + Box::new(JITExpr::Identifier("table1.b".to_string(), I64)), + )); + let fields = vec![("table1.a".to_string(), I64), ("table1.b".to_string(), I64)]; + + let expected = r#"fn calc_fn_0(table1.a_array: i64, table1.b_array: i64, result: i64, len: i64) -> () { + let index: i64; + index = 0; + while index < len { + let offset: i64; + offset = index * 8; + let table1.a_ptr: i64; + table1.a_ptr = table1.a_array + offset; + let table1.a: i64; + table1.a = *(table1.a_ptr); + let table1.b_ptr: i64; + table1.b_ptr = table1.b_array + offset; + let table1.b: i64; + table1.b = *(table1.b_ptr); + let res_ptr: i64; + res_ptr = result + offset; + let res: i64; + res = table1.a + table1.b; + *(res_ptr) = res + index = index + 1; + } +}"#; + + let assembler = Assembler::default(); + let gen_func = build_calc_fn(&assembler, expr, fields, I64).unwrap(); + assert_eq!(format!("{}", &gen_func), expected); + } +} diff --git a/datafusion/jit/src/jit.rs b/datafusion/jit/src/jit.rs index 0460cc805d65..21b0d44fb0b5 100644 --- a/datafusion/jit/src/jit.rs +++ b/datafusion/jit/src/jit.rs @@ -263,6 +263,7 @@ impl<'a> FunctionTranslator<'a> { Ok(()) } Stmt::Declare(_, _) => Ok(()), + Stmt::Store(value, ptr) => self.translate_store(*ptr, *value), } } @@ -289,6 +290,7 @@ impl<'a> FunctionTranslator<'a> { } Expr::Binary(b) => self.translate_binary_expr(b), Expr::Call(name, args, ret) => self.translate_call_expr(name, args, ret), + Expr::Load(ptr, ty) => self.translate_deref(*ptr, ty), } } @@ -462,6 +464,18 @@ impl<'a> FunctionTranslator<'a> { Ok(()) } + fn translate_deref(&mut self, ptr: Expr, ty: JITType) -> Result { + let ptr = self.translate_expr(ptr)?; + Ok(self.builder.ins().load(ty.native, MemFlags::new(), ptr, 0)) + } + + fn translate_store(&mut self, ptr: Expr, value: Expr) -> Result<()> { + let ptr = self.translate_expr(ptr)?; + let value = self.translate_expr(value)?; + self.builder.ins().store(MemFlags::new(), value, ptr, 0); + Ok(()) + } + fn translate_icmp(&mut self, cmp: IntCC, lhs: Expr, rhs: Expr) -> Result { let lhs = self.translate_expr(lhs)?; let rhs = self.translate_expr(rhs)?; diff --git a/datafusion/jit/src/lib.rs b/datafusion/jit/src/lib.rs index dff27da317e4..377d32d8a37d 100644 --- a/datafusion/jit/src/lib.rs +++ b/datafusion/jit/src/lib.rs @@ -19,6 +19,7 @@ pub mod api; pub mod ast; +pub mod compile; pub mod jit; #[cfg(test)] diff --git a/datafusion/row/src/lib.rs b/datafusion/row/src/lib.rs index c05cbcd0ef1c..d77c37063e92 100644 --- a/datafusion/row/src/lib.rs +++ b/datafusion/row/src/lib.rs @@ -30,10 +30,12 @@ //! we append their actual content to the end of the var length region and //! store their offset relative to row base and their length, packed into an 8-byte word. //! +//! ```plaintext //! ┌────────────────┬──────────────────────────┬───────────────────────┐ ┌───────────────────────┬────────────┐ //! │Validity Bitmask│ Fixed Width Field │ Variable Width Field │ ... │ vardata area │ padding │ //! │ (byte aligned) │ (native type width) │(vardata offset + len) │ │ (variable length) │ bytes │ //! └────────────────┴──────────────────────────┴───────────────────────┘ └───────────────────────┴────────────┘ +//! ``` //! //! For example, given the schema (Int8, Utf8, Float32, Utf8) //! @@ -41,10 +43,12 @@ //! //! Requires 32 bytes (31 bytes payload and 1 byte padding to make each tuple 8-bytes aligned): //! +//! ```plaintext //! ┌──────────┬──────────┬──────────────────────┬──────────────┬──────────────────────┬───────────────────────┬──────────┐ //! │0b00001011│ 0x01 │0x00000016 0x00000006│ 0x00000000 │0x0000001C 0x00000003│ FooBarbaz │ 0x00 │ //! └──────────┴──────────┴──────────────────────┴──────────────┴──────────────────────┴───────────────────────┴──────────┘ //! 0 1 2 10 14 22 31 32 +//! ``` //! use arrow::array::{make_builder, ArrayBuilder, ArrayRef};