feat: Pull based native execution
viirya committed Feb 21, 2024
1 parent 424733e commit 37f17a5
Showing 11 changed files with 405 additions and 316 deletions.
16 changes: 9 additions & 7 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -39,11 +39,14 @@ class NativeUtil {
* @param batch
* the input Comet columnar batch
* @return
* a list containing pairs of memory addresses in the format of (address of Arrow array,
* address of Arrow schema)
* a list containing the number of rows followed by pairs of memory addresses in the format of
* (address of Arrow array, address of Arrow schema)
*/
def exportBatch(batch: ColumnarBatch): Array[Long] = {
val vectors = (0 until batch.numCols()).flatMap { index =>
val exportedVectors = mutable.ArrayBuffer.empty[Long]
exportedVectors += batch.numRows()

(0 until batch.numCols()).foreach { index =>
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector
@@ -63,17 +66,16 @@
arrowArray,
arrowSchema)

Seq((arrowArray, arrowSchema))
exportedVectors += arrowArray.memoryAddress()
exportedVectors += arrowSchema.memoryAddress()
case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}

vectors.flatMap { pair =>
Seq(pair._1.memoryAddress(), pair._2.memoryAddress())
}.toArray
exportedVectors.toArray
}

/**
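With this change the exported array has the layout `[numRows, arrayAddr_0, schemaAddr_0, arrayAddr_1, schemaAddr_1, ...]`. A minimal sketch of how the native (Rust) side might split that flat array back into a row count and per-column address pairs; the helper name and the example addresses below are illustrative, not part of this commit:

```rust
/// Hypothetical helper: splits the flat address array produced by
/// `NativeUtil.exportBatch` into the row count and per-column
/// (Arrow array address, Arrow schema address) pairs.
fn split_exported_batch(addresses: &[i64]) -> (i64, Vec<(i64, i64)>) {
    // First element is the number of rows in the batch.
    let num_rows = addresses[0];
    // Remaining elements come in (array address, schema address) pairs.
    let pairs = addresses[1..]
        .chunks_exact(2)
        .map(|pair| (pair[0], pair[1]))
        .collect();
    (num_rows, pairs)
}

fn main() {
    // Example: a 4-row batch with two exported columns (addresses are made up).
    let exported = vec![4_i64, 0x7f00_1000, 0x7f00_2000, 0x7f00_3000, 0x7f00_4000];
    let (num_rows, columns) = split_exported_batch(&exported);
    assert_eq!(num_rows, 4);
    assert_eq!(columns.len(), 2);
}
```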
9 changes: 9 additions & 0 deletions core/src/errors.rs
@@ -159,6 +159,15 @@ impl From<CometError> for DataFusionError {
}
}

impl From<CometError> for ExecutionError {
fn from(value: CometError) -> Self {
match value {
CometError::Execution { source } => source,
_ => ExecutionError::GeneralError(value.to_string()),
}
}
}

impl jni::errors::ToException for CometError {
fn to_exception(&self) -> Exception {
match self {
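The new `From<CometError> for ExecutionError` impl mirrors the existing conversion into `DataFusionError` and lets planner code that returns `Result<_, ExecutionError>` propagate `CometError` values with the `?` operator. A self-contained sketch of that pattern, using simplified stand-in enums and a hypothetical `read_arrow_ffi` step rather than the real Comet types:

```rust
use std::fmt;

#[derive(Debug)]
enum ExecutionError {
    GeneralError(String),
}

#[derive(Debug)]
enum CometError {
    Execution { source: ExecutionError },
    ArrowError(String),
}

impl fmt::Display for CometError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{self:?}")
    }
}

// Same shape as the conversion added in this commit.
impl From<CometError> for ExecutionError {
    fn from(value: CometError) -> Self {
        match value {
            CometError::Execution { source } => source,
            _ => ExecutionError::GeneralError(value.to_string()),
        }
    }
}

// A hypothetical native step that fails with a CometError.
fn read_arrow_ffi() -> Result<(), CometError> {
    Err(CometError::ArrowError("bad FFI pointer".to_string()))
}

// `?` converts the CometError into an ExecutionError via the impl above.
fn build_scan() -> Result<(), ExecutionError> {
    read_arrow_ffi()?;
    Ok(())
}

fn main() {
    assert!(matches!(build_scan(), Err(ExecutionError::GeneralError(_))));
}
```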
83 changes: 57 additions & 26 deletions core/src/execution/datafusion/planner.rs
@@ -47,6 +47,7 @@ use datafusion_physical_expr::{
AggregateExpr, ScalarFunctionExpr,
};
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};

use crate::{
@@ -70,7 +71,7 @@ use crate::{
operators::expand::CometExpandExec,
shuffle_writer::ShuffleWriterExec,
},
operators::{CopyExec, ExecutionError, InputBatch, ScanExec},
operators::{CopyExec, ExecutionError, ScanExec},
serde::to_arrow_datatype,
spark_expression,
spark_expression::{
@@ -88,6 +89,8 @@ type PhyAggResult = Result<Vec<Arc<dyn AggregateExpr>>, ExecutionError>;
type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>, ExecutionError>;
type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError>;

pub const TEST_EXEC_CONTEXT_ID: i64 = -1;

/// The query planner for converting Spark query plans to DataFusion query plans.
pub struct PhysicalPlanner {
// The execution context id of this planner.
@@ -105,7 +108,7 @@ impl PhysicalPlanner {
pub fn new() -> Self {
let execution_props = ExecutionProps::new();
Self {
exec_context_id: -1,
exec_context_id: TEST_EXEC_CONTEXT_ID,
execution_props,
}
}
@@ -612,24 +615,28 @@

/// Create a DataFusion physical plan from Spark physical plan.
///
/// Note that we need `input_batches` parameter because we need to know the exact schema (not
/// only data type but also dictionary-encoding) at `ScanExec`s. It is because some DataFusion
/// operators, e.g., `ProjectionExec`, gets child operator schema during initialization and
/// uses it later for `RecordBatch`. We may be able to get rid of it once `RecordBatch`
/// relaxes schema check.
/// `inputs` is a vector of input source IDs. It is used to create `ScanExec`s. Each `ScanExec`
/// will be assigned a unique ID from `inputs`, and the ID will be used to identify the input
/// source at the JNI API.
///
/// Note that `ScanExec` will pull the initial input batch during initialization. This is
/// because we need to know the exact schema (not only the data type but also the dictionary
/// encoding) at each `ScanExec`: some DataFusion operators, e.g., `ProjectionExec`, get the
/// child operator schema during initialization and use it later for `RecordBatch`. We may be
/// able to get rid of this once `RecordBatch` relaxes its schema check.
///
/// Note that we return the created `Scan`s, which will be kept at the JNI API. JNI calls will
/// use them to feed new input batches from the Spark JVM side.
pub fn create_plan<'a>(
&'a self,
spark_plan: &'a Operator,
input_batches: &mut Vec<InputBatch>,
inputs: &mut Vec<Arc<GlobalRef>>,
) -> Result<(Vec<ScanExec>, Arc<dyn ExecutionPlan>), ExecutionError> {
let children = &spark_plan.children;
match spark_plan.op_struct.as_ref().unwrap() {
OpStruct::Projection(project) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;
let exprs: PhyExprResult = project
.project_list
.iter()
@@ -643,15 +650,15 @@
}
OpStruct::Filter(filter) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;
let predicate =
self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?;

Ok((scans, Arc::new(FilterExec::try_new(predicate, child)?)))
}
OpStruct::HashAgg(agg) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

let group_exprs: PhyExprResult = agg
.grouping_exprs
@@ -716,13 +723,13 @@
OpStruct::Limit(limit) => {
assert!(children.len() == 1);
let num = limit.limit;
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

Ok((scans, Arc::new(LocalLimitExec::new(child, num as usize))))
}
OpStruct::Sort(sort) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

let exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = sort
.sort_orders
@@ -741,21 +748,32 @@
}
OpStruct::Scan(scan) => {
let fields = scan.fields.iter().map(to_arrow_datatype).collect_vec();
if input_batches.is_empty() {

// If this is not the test execution context used by unit tests, we should have at
// least one input source
if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
return Err(ExecutionError::GeneralError(
"No input batch for scan".to_string(),
"No input for scan".to_string(),
));
}
// Consumes the first input batch source for the scan
let input_batch = input_batches.remove(0);

// Consumes the first input source for the scan
let input_source = if self.exec_context_id == TEST_EXEC_CONTEXT_ID
&& inputs.is_empty()
{
// For unit tests, we set the input batch on the scan directly via `set_input_batch`.
None
} else {
Some(inputs.remove(0))
};

// The `ScanExec` operator will take actual arrays from Spark during execution
let scan = ScanExec::new(input_batch, fields);
let scan = ScanExec::new(self.exec_context_id, input_source, fields)?;
Ok((vec![scan.clone()], Arc::new(scan)))
}
OpStruct::ShuffleWriter(writer) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

let partitioning = self
.create_partitioning(writer.partitioning.as_ref().unwrap(), child.schema())?;
@@ -772,7 +790,7 @@
}
OpStruct::Expand(expand) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

let mut projections = vec![];
let mut projection = vec![];
Expand Down Expand Up @@ -805,6 +823,18 @@ impl PhysicalPlanner {
.collect();
let schema = Arc::new(Schema::new(fields));

// The `Expand` operator keeps the input batch and expands it into multiple output
// batches. However, `ScanExec` will reuse the input arrays for the next
// input batch. Therefore, we need to copy the input batch to avoid
// data corruption. Note that we only need to copy the input batch
// if the child operator is `ScanExec`, because other operators after `ScanExec`
// will create new arrays for the output batch.
let child = if child.as_any().downcast_ref::<ScanExec>().is_some() {
Arc::new(CopyExec::new(child))
} else {
child
};

Ok((
scans,
Arc::new(CometExpandExec::new(projections, child, schema)),
@@ -997,9 +1027,9 @@ mod tests {
let values = Int32Array::from(vec![0, 1, 2, 3]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count);
let mut input_batches = vec![input_batch];

let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut input_batches).unwrap();
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![]).unwrap();
scans[0].set_input_batch(input_batch);

let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
@@ -1077,9 +1107,11 @@
let values = StringArray::from(vec!["foo", "bar", "hello", "comet"]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count);
let mut input_batches = vec![input_batch];

let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut input_batches).unwrap();
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![]).unwrap();

// Scan's schema is determined by the input batch, so we need to set it before execution.
scans[0].set_input_batch(input_batch);

let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
@@ -1147,8 +1179,7 @@
let op = create_filter(op_scan, 0);
let planner = PhysicalPlanner::new();

let mut input_batches = vec![InputBatch::EOF];
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut input_batches).unwrap();
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![]).unwrap();

let scan = &mut scans[0];
scan.set_input_batch(InputBatch::EOF);
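Taken together, the planner changes replace the eager `input_batches` parameter with a lazy `inputs` list of JNI input sources: outside of the test context every `Scan` consumes one input source, while with `TEST_EXEC_CONTEXT_ID` a scan may start without one and receive its batch later through `set_input_batch` (as the updated unit tests do). A self-contained sketch of that selection logic, using simplified stand-in types rather than the real Comet structs:

```rust
// Stand-ins for the real Comet types, just to illustrate the input-source
// selection added to the `Scan` arm of `create_plan`.
const TEST_EXEC_CONTEXT_ID: i64 = -1;

#[derive(Debug)]
enum ExecutionError {
    GeneralError(String),
}

/// Picks the input source for a new scan, mirroring the commit's logic:
/// - non-test context with no inputs is an error;
/// - test context with no inputs returns `None` (batch set later via `set_input_batch`);
/// - otherwise the first input source is consumed.
fn pick_input_source(
    exec_context_id: i64,
    inputs: &mut Vec<String>, // the real code holds `Arc<GlobalRef>` JNI handles
) -> Result<Option<String>, ExecutionError> {
    if exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
        return Err(ExecutionError::GeneralError("No input for scan".to_string()));
    }
    if exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
        Ok(None)
    } else {
        Ok(Some(inputs.remove(0)))
    }
}

fn main() {
    // Production-like path: one registered input source is consumed by the scan.
    let mut inputs = vec!["input_source_0".to_string()];
    assert_eq!(
        pick_input_source(7, &mut inputs).unwrap(),
        Some("input_source_0".to_string())
    );

    // Test path: no inputs registered, so the scan starts without a source.
    let mut empty: Vec<String> = vec![];
    assert_eq!(pick_input_source(TEST_EXEC_CONTEXT_ID, &mut empty).unwrap(), None);
}
```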