Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evaluate JIT'd expression over arrays #2587

Merged
merged 5 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions datafusion/jit/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ impl FunctionBuilder {
}

/// Add one more parameter to the function.
#[must_use]
pub fn param(mut self, name: impl Into<String>, ty: JITType) -> Self {
let name = name.into();
assert!(!self.fields.back().unwrap().contains_key(&name));
Expand All @@ -163,6 +164,7 @@ impl FunctionBuilder {

/// Set return type for the function. Functions are of `void` type by default if
/// you do not set the return type.
#[must_use]
pub fn ret(mut self, name: impl Into<String>, ty: JITType) -> Self {
let name = name.into();
assert!(!self.fields.back().unwrap().contains_key(&name));
Expand Down Expand Up @@ -604,6 +606,17 @@ impl<'a> CodeBlock<'a> {
internal_err!("No func with the name {} exist", fn_name)
}
}

/// Build an expression that reads a value of type `ty` from the address
/// produced by evaluating `ptr`.
///
/// No statement is added to the block; the returned [`Expr::Load`] only takes
/// effect once it is used in a statement.
pub fn load(&self, ptr: Expr, ty: JITType) -> Result<Expr> {
    let address = Box::new(ptr);
    let load_expr = Expr::Load(address, ty);
    Ok(load_expr)
}

/// Store the value in `value` to the address in `ptr`.
///
/// Appends a [`Stmt::Store`] to this block's statement list. The call itself
/// never fails; it always returns `Ok(())`.
pub fn store(&mut self, value: Expr, ptr: Expr) -> Result<()> {
waynexia marked this conversation as resolved.
Show resolved Hide resolved
self.stmts.push(Stmt::Store(Box::new(value), Box::new(ptr)));
Ok(())
}
}

impl Display for GeneratedFunction {
Expand Down
50 changes: 32 additions & 18 deletions datafusion/jit/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::DataType;
use cranelift::codegen::ir;
use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
use std::fmt::{Display, Formatter};
Expand All @@ -32,6 +33,8 @@ pub enum Stmt {
Call(String, Vec<Expr>),
/// declare a new variable of type
Declare(String, JITType),
/// store value (the first expr) to an address (the second expr)
Store(Box<Expr>, Box<Expr>),
}

#[derive(Copy, Clone, Debug, PartialEq)]
Expand All @@ -54,6 +57,8 @@ pub enum Expr {
Binary(BinaryExpr),
/// call function expression
Call(String, Vec<Expr>, JITType),
/// Load a value from pointer
Load(Box<Expr>, JITType),
}

impl Expr {
Expand All @@ -63,6 +68,7 @@ impl Expr {
Expr::Identifier(_, ty) => *ty,
Expr::Binary(bin) => bin.get_type(),
Expr::Call(_, _, ty) => *ty,
Expr::Load(_, ty) => *ty,
}
}
}
Expand Down Expand Up @@ -174,19 +180,7 @@ impl TryFrom<(datafusion_expr::Expr, DFSchemaRef)> for Expr {
let field = schema.field_from_column(col)?;
let ty = field.data_type();

let jit_type = match ty {
arrow::datatypes::DataType::Int64 => I64,
arrow::datatypes::DataType::Float32 => F32,
arrow::datatypes::DataType::Float64 => F64,
arrow::datatypes::DataType::Boolean => BOOL,

_ => {
return Err(DataFusionError::NotImplemented(format!(
"Compiling Expression with type {} not yet supported in JIT mode",
ty
)))
}
};
let jit_type = JITType::try_from(ty)?;

Ok(Expr::Identifier(field.qualified_name(), jit_type))
}
Expand Down Expand Up @@ -272,12 +266,28 @@ pub const R64: JITType = JITType {
native: ir::types::R64,
code: 0x7f,
};
/// Size in bytes of `usize` on the compilation target.
pub const PTR_SIZE: usize = std::mem::size_of::<usize>();
/// The pointer type to use based on our current target.
pub const PTR: JITType = if std::mem::size_of::<usize>() == 8 {
R64
} else {
R32
};
// NOTE(review): this page renders the diff without +/- markers; the one-line
// definition below appears to be the replacement for the multi-line one above
// (same value, expressed through `PTR_SIZE`) — confirm against the diff.
pub const PTR: JITType = if PTR_SIZE == 8 { R64 } else { R32 };

impl TryFrom<&DataType> for JITType {
    type Error = DataFusionError;

    /// Map an Arrow [DataType] onto the corresponding [JITType].
    ///
    /// Only `Int64`, `Float32`, `Float64` and `Boolean` are mapped; every other
    /// type yields a `DataFusionError::NotImplemented` naming the rejected type.
    fn try_from(df_type: &DataType) -> Result<Self, Self::Error> {
        let jit_type = match df_type {
            DataType::Int64 => I64,
            DataType::Float32 => F32,
            DataType::Float64 => F64,
            DataType::Boolean => BOOL,

            unsupported => {
                let message = format!(
                    "Compiling Expression with type {} not yet supported in JIT mode",
                    unsupported
                );
                return Err(DataFusionError::NotImplemented(message));
            }
        };
        Ok(jit_type)
    }
}

impl Stmt {
/// print the statement with indentation
Expand Down Expand Up @@ -323,6 +333,9 @@ impl Stmt {
Stmt::Declare(name, ty) => {
writeln!(f, "{}let {}: {};", ident_str, name, ty)
}
Stmt::Store(value, ptr) => {
writeln!(f, "{}*({}) = {}", ident_str, ptr, value)
}
}
}
}
Expand Down Expand Up @@ -352,6 +365,7 @@ impl Display for Expr {
.join(", ")
)
}
Expr::Load(ptr, _) => write!(f, "*({})", ptr,),
}
}
}
Expand Down
207 changes: 207 additions & 0 deletions datafusion/jit/src/compile.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Compile DataFusion Expr to JIT'd function.

use datafusion_common::Result;

use crate::api::Assembler;
use crate::ast::{JITType, I32};
use crate::{
api::GeneratedFunction,
ast::{Expr as JITExpr, I64, PTR_SIZE},
};

/// Wrap a JIT expression into a function that evaluates it element-wise over arrays.
///
/// The generated function is built from the name `calc_fn` and takes, in order:
/// one address per entry of `inputs` (parameter `<name>_array`), the address of
/// the `result` buffer, and `len`, the number of elements, as an `I64`.
/// Its body loops `index` from 0 while `index < len`; each iteration loads one
/// value per input, binds it to the input's name, evaluates `jit_expr` and stores
/// the value at the matching position of `result`.
///
/// The generated code works on raw addresses and performs no length or bounds
/// checks; callers must validate the arrays before invoking it.
///
/// * `assembler` - used to create the function builder.
/// * `jit_expr` - the expression evaluated for every element.
/// * `inputs` - name and element type of every input array, in parameter order.
/// * `ret_type` - element type of the result array. It is kept for API
///   compatibility; the stored value is typed by `jit_expr` itself.
///
/// # Errors
///
/// Propagates any error returned by the code-block builder while the body is assembled.
pub fn build_calc_fn(
    assembler: &Assembler,
    jit_expr: JITExpr,
    inputs: Vec<(String, JITType)>,
    ret_type: JITType,
) -> Result<GeneratedFunction> {
    // Alias pointer type.
    // The raw pointer `R64` or `R32` is not compatible with integers.
    const PTR_TYPE: JITType = if PTR_SIZE == 8 { I64 } else { I32 };

    // `result` is an address, so it must be declared with the pointer alias and
    // not with the element type: declaring it as `ret_type` only worked when the
    // element type happened to be the pointer-sized integer.
    let _ = ret_type;

    let mut builder = assembler.new_func_builder("calc_fn");
    // Declare in-param.
    // Each input takes one position, followed by a pointer to place the result,
    // and the last is the length of inputs/output arrays.
    for (name, _) in &inputs {
        builder = builder.param(format!("{}_array", name), PTR_TYPE);
    }
    let mut builder = builder.param("result", PTR_TYPE).param("len", I64);

    // Start build function body.
    // It's a loop that calculates the result one by one.
    let mut fn_body = builder.enter_block();
    fn_body.declare_as("index", fn_body.lit_i(0))?;
    fn_body.while_block(
        |cond| cond.lt(cond.id("index")?, cond.id("len")?),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, do we need sanity check for equal array lengths?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generated code is working on pointer directly, it might be hard to do these check inside it. I check inputs' length before pass them to generated code at here.

        if lhs.len() != rhs.len() {
            return Err(DataFusionError::NotImplemented(
                "Computing on different length arrays not yet supported".to_string(),
            ));
        }

But I agree that we should consider how to improve safety of generated code when its logic get complicated.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in general the idea is that all the safety checks are done during JIT generation as suggested by @waynexia

        |w| {
            // NOTE(review): the byte offset uses `PTR_SIZE` as the element stride,
            // which only matches elements that are pointer-sized (e.g. 8-byte
            // values on a 64-bit target); other element widths would need a
            // per-type stride — confirm before supporting them.
            w.declare_as("offset", w.mul(w.id("index")?, w.lit_i(PTR_SIZE as i64))?)?;
            for (name, ty) in &inputs {
                // Address of the current element of this input array.
                w.declare_as(
                    format!("{}_ptr", name),
                    w.add(w.id(format!("{}_array", name))?, w.id("offset")?)?,
                )?;
                // Bind the loaded element to the input's own name so that
                // `jit_expr` can refer to it as an identifier.
                w.declare_as(name, w.load(w.id(format!("{}_ptr", name))?, *ty)?)?;
            }
            w.declare_as("res_ptr", w.add(w.id("result")?, w.id("offset")?)?)?;
            w.declare_as("res", jit_expr.clone())?;
            w.store(w.id("res")?, w.id("res_ptr")?)?;

            w.assign("index", w.add(w.id("index")?, w.lit_i(1))?)?;
            Ok(())
        },
    )?;

    let gen_func = fn_body.build();
    Ok(gen_func)
}

#[cfg(test)]
mod test {
use std::{collections::HashMap, sync::Arc};

use arrow::{
array::{Array, PrimitiveArray},
datatypes::{DataType, Int64Type},
};
use datafusion_common::{DFField, DFSchema, DataFusionError};
use datafusion_expr::Expr as DFExpr;

use crate::ast::BinaryExpr;

use super::*;

/// Test helper: JIT-compile `df_expr` against `schema` and evaluate it over
/// `lhs` and `rhs`, returning the results as a new array.
///
/// Returns a `NotImplemented` error when either input contains nulls or when
/// the two inputs differ in length; conversion and compilation errors are
/// propagated.
fn run_df_expr(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the longer term I would like to see this type of logic encapsulated somehow

So we would have a function or struct that took an Expr and several ArrayRefs and then called a JIT or non-JIT version of evaluation depending on flags or options.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely!

I'm going to figure out how to play with var length param list and type casting, and then encapsulate a plan node level interface (maybe take LogicalPlan and data as input, haven't thought in detail) as next step.

    df_expr: DFExpr,
    schema: Arc<DFSchema>,
    lhs: PrimitiveArray<Int64Type>,
    rhs: PrimitiveArray<Int64Type>,
) -> Result<PrimitiveArray<Int64Type>> {
    // The generated code reads raw value buffers and knows nothing about
    // validity bitmaps, so inputs with nulls are rejected up front.
    if lhs.null_count() != 0 || rhs.null_count() != 0 {
        return Err(DataFusionError::NotImplemented(
            "Computing on nullable array not yet supported".to_string(),
        ));
    }
    // The generated loop is driven by a single `len`, so both inputs must
    // have the same number of elements.
    if lhs.len() != rhs.len() {
        return Err(DataFusionError::NotImplemented(
            "Computing on different length arrays not yet supported".to_string(),
        ));
    }

    // translate DF Expr to JIT Expr
    let input_fields = schema
        .fields()
        .iter()
        .map(|field| {
            Ok((
                field.qualified_name(),
                JITType::try_from(field.data_type())?,
            ))
        })
        .collect::<Result<Vec<_>>>()?;
    let jit_expr: JITExpr = (df_expr, schema).try_into()?;

    // allocate memory for calc result
    let len = lhs.len();
    // Must be mutable: the generated code writes every element through the
    // raw pointer handed to it below.
    let mut result = vec![0i64; len];

    // compile and run JIT code
    let assembler = Assembler::default();
    let gen_func = build_calc_fn(&assembler, jit_expr, input_fields, I64)?;
    let mut jit = assembler.create_jit();
    let code_ptr = jit.compile(gen_func)?;
    // SAFETY: `code_ptr` is what `jit.compile` returned for the function built
    // by `build_calc_fn`, which declares one address per schema field, then the
    // result address, then an i64 length. Both inputs and `result` hold `len`
    // elements (lengths are checked above).
    // NOTE(review): this signature assumes the schema has exactly two fields,
    // and that the JIT'd calling convention matches a Rust `fn` pointer — confirm.
    let code_fn = unsafe {
        core::mem::transmute::<_, fn(*const i64, *const i64, *mut i64, i64) -> ()>(
            code_ptr,
        )
    };
    code_fn(
        lhs.values().as_ptr(),
        rhs.values().as_ptr(),
        result.as_mut_ptr(),
        len as i64,
    );

    let result_array = PrimitiveArray::<Int64Type>::from_iter(result);
    Ok(result_array)
}

/// End-to-end check: `a + b` compiled through the JIT and run over two
/// `Int64` arrays must equal the result of Arrow's `add` kernel.
#[test]
fn array_add() {
let array_a: PrimitiveArray<Int64Type> =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend using different values for array_a and array_b so issues in argument handling would be evident

Like maybe

        let array_b: PrimitiveArray<Int64Type> =
            PrimitiveArray::from_iter_values((10..20).map(|x| x + 1));

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch 😆

PrimitiveArray::from_iter_values((0..10).map(|x| x + 1));
// `array_b` deliberately holds different values from `array_a` so that
// mistakes in argument handling show up in the result.
let array_b: PrimitiveArray<Int64Type> =
PrimitiveArray::from_iter_values((10..20).map(|x| x + 1));
// Reference result computed with Arrow's own arithmetic kernel.
let expected =
arrow::compute::kernels::arithmetic::add(&array_a, &array_b).unwrap();

let df_expr = datafusion_expr::col("a") + datafusion_expr::col("b");
// Both columns are non-nullable Int64 fields of `table1`.
let schema = Arc::new(
DFSchema::new_with_metadata(
vec![
DFField::new(Some("table1"), "a", DataType::Int64, false),
DFField::new(Some("table1"), "b", DataType::Int64, false),
],
HashMap::new(),
)
.unwrap(),
);

let result = run_df_expr(df_expr, schema, array_a, array_b).unwrap();
assert_eq!(result, expected);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoo, really nice 🎉

}

/// Checks the textual form of the function `build_calc_fn` generates for
/// `table1.a + table1.b` over two `I64` inputs.
#[test]
fn calc_fn_builder() {
// Hand-built JIT expression: the sum of the two input identifiers.
let expr = JITExpr::Binary(BinaryExpr::Add(
Box::new(JITExpr::Identifier("table1.a".to_string(), I64)),
Box::new(JITExpr::Identifier("table1.b".to_string(), I64)),
));
let fields = vec![("table1.a".to_string(), I64), ("table1.b".to_string(), I64)];

// The comparison below is an exact string match against the `Display`
// output of the generated function.
let expected = r#"fn calc_fn_0(table1.a_array: i64, table1.b_array: i64, result: i64, len: i64) -> () {
let index: i64;
index = 0;
while index < len {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at this code and it looks 👍 to me

let offset: i64;
offset = index * 8;
let table1.a_ptr: i64;
table1.a_ptr = table1.a_array + offset;
let table1.a: i64;
table1.a = *(table1.a_ptr);
let table1.b_ptr: i64;
table1.b_ptr = table1.b_array + offset;
let table1.b: i64;
table1.b = *(table1.b_ptr);
let res_ptr: i64;
res_ptr = result + offset;
let res: i64;
res = table1.a + table1.b;
*(res_ptr) = res
index = index + 1;
}
}"#;

let assembler = Assembler::default();
let gen_func = build_calc_fn(&assembler, expr, fields, I64).unwrap();
assert_eq!(format!("{}", &gen_func), expected);
}
}
14 changes: 14 additions & 0 deletions datafusion/jit/src/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ impl<'a> FunctionTranslator<'a> {
Ok(())
}
Stmt::Declare(_, _) => Ok(()),
Stmt::Store(value, ptr) => self.translate_store(*ptr, *value),
}
}

Expand All @@ -289,6 +290,7 @@ impl<'a> FunctionTranslator<'a> {
}
Expr::Binary(b) => self.translate_binary_expr(b),
Expr::Call(name, args, ret) => self.translate_call_expr(name, args, ret),
Expr::Load(ptr, ty) => self.translate_deref(*ptr, ty),
}
}

Expand Down Expand Up @@ -462,6 +464,18 @@ impl<'a> FunctionTranslator<'a> {
Ok(())
}

/// Translate a pointer dereference: evaluate `ptr` to an address and emit a
/// load of a value of `ty`'s native type from that address at offset 0.
fn translate_deref(&mut self, ptr: Expr, ty: JITType) -> Result<Value> {
    let address = self.translate_expr(ptr)?;
    let flags = MemFlags::new();
    let loaded = self.builder.ins().load(ty.native, flags, address, 0);
    Ok(loaded)
}

/// Translate a store statement: evaluate `ptr` and then `value`, and emit a
/// store of the value to that address at offset 0.
fn translate_store(&mut self, ptr: Expr, value: Expr) -> Result<()> {
    // The address is translated before the value, as in the statement order
    // the instructions are emitted in.
    let address = self.translate_expr(ptr)?;
    let payload = self.translate_expr(value)?;
    let flags = MemFlags::new();
    self.builder.ins().store(flags, payload, address, 0);
    Ok(())
}

fn translate_icmp(&mut self, cmp: IntCC, lhs: Expr, rhs: Expr) -> Result<Value> {
let lhs = self.translate_expr(lhs)?;
let rhs = self.translate_expr(rhs)?;
Expand Down
1 change: 1 addition & 0 deletions datafusion/jit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

pub mod api;
pub mod ast;
pub mod compile;
pub mod jit;

#[cfg(test)]
Expand Down
Loading