Add TableProvider::insert_into into FFI Bindings (#14391)
* Wrap TableProvider::insert_into

This method was missing from the FFI bindings and is needed by
datafusion-python extensions.

* Switch from passing the runtime around to its handle, and attach the handle of the runtime running the current async call to the foreign execution plan in the insert_into operation

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
Co-authored-by: Tim Saucer <timsaucer@gmail.com>
3 people authored Feb 1, 2025
1 parent a0d42ed commit 7f9a8c0
Showing 6 changed files with 207 additions and 18 deletions.
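
Taken together, these changes let a TableProvider exported through the FFI layer accept INSERT statements from the consuming side. Below is a condensed sketch of that round trip; it mirrors the new test at the end of the table_provider.rs diff, with the crate path datafusion_ffi assumed from the repository layout and the table name and values purely illustrative.

use std::sync::Arc;
use datafusion::arrow::array::Float32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};

async fn round_trip_insert() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Float32Array::from(vec![2.0_f32, 4.0]))],
    )?;

    // Producer side: wrap any TableProvider for export across the FFI boundary.
    let ffi_provider = FFI_TableProvider::new(
        Arc::new(MemTable::try_new(schema, vec![vec![batch]])?),
        true,
        None,
    );

    // Consumer side: re-import the provider and write to it with a plain INSERT,
    // which is routed through the new insert_into FFI call.
    let foreign_provider: ForeignTableProvider = (&ffi_provider).into();
    let ctx = SessionContext::new();
    ctx.register_table("t", Arc::new(foreign_provider))?;
    ctx.sql("INSERT INTO t VALUES (128.0)").await?.collect().await?;
    Ok(())
}
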
8 changes: 4 additions & 4 deletions datafusion/ffi/src/execution_plan.rs
@@ -27,7 +27,7 @@ use datafusion::{
execution::{SendableRecordBatchStream, TaskContext},
physical_plan::{DisplayAs, ExecutionPlan, PlanProperties},
};
use tokio::runtime::Runtime;
use tokio::runtime::Handle;

use crate::{
plan_properties::FFI_PlanProperties, record_batch_stream::FFI_RecordBatchStream,
@@ -72,7 +72,7 @@ unsafe impl Sync for FFI_ExecutionPlan {}
pub struct ExecutionPlanPrivateData {
pub plan: Arc<dyn ExecutionPlan>,
pub context: Arc<TaskContext>,
pub runtime: Option<Arc<Runtime>>,
pub runtime: Option<Handle>,
}

unsafe extern "C" fn properties_fn_wrapper(
@@ -110,7 +110,7 @@ unsafe extern "C" fn execute_fn_wrapper(
let private_data = plan.private_data as *const ExecutionPlanPrivateData;
let plan = &(*private_data).plan;
let ctx = &(*private_data).context;
let runtime = (*private_data).runtime.as_ref().map(Arc::clone);
let runtime = (*private_data).runtime.clone();

match plan.execute(partition, Arc::clone(ctx)) {
Ok(rbs) => RResult::ROk(FFI_RecordBatchStream::new(rbs, runtime)),
@@ -153,7 +153,7 @@ impl FFI_ExecutionPlan {
pub fn new(
plan: Arc<dyn ExecutionPlan>,
context: Arc<TaskContext>,
runtime: Option<Arc<Runtime>>,
runtime: Option<Handle>,
) -> Self {
let private_data = Box::new(ExecutionPlanPrivateData {
plan,
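
The practical effect of the Runtime to Handle switch above is that callers no longer hand the FFI plan shared ownership of a runtime, only a cheaply clonable handle to one. A minimal sketch of how a wrapper might pick up that handle, following the same pattern ForeignTableProvider::insert_into uses further down (the function name and the datafusion_ffi crate path are assumptions):

use std::sync::Arc;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_ffi::execution_plan::FFI_ExecutionPlan;
use tokio::runtime::Handle;

// Wrap a plan for export, attaching the current runtime's handle when one exists.
fn wrap_plan_for_ffi(
    plan: Arc<dyn ExecutionPlan>,
    task_ctx: Arc<TaskContext>,
) -> FFI_ExecutionPlan {
    // Handle::try_current() succeeds only inside a running Tokio runtime;
    // outside one, the plan is exported without a handle.
    let handle = Handle::try_current().ok();
    FFI_ExecutionPlan::new(plan, task_ctx, handle)
}

When the exporting code owns its own runtime instead, runtime.handle().clone() produces an equivalent Handle without giving up ownership of the runtime, which is exactly the change made to the async provider test helper at the bottom of this commit.
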
49 changes: 49 additions & 0 deletions datafusion/ffi/src/insert_op.rs
@@ -0,0 +1,49 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use abi_stable::StableAbi;
use datafusion::logical_expr::logical_plan::dml::InsertOp;

/// FFI safe version of [`InsertOp`].
#[repr(C)]
#[derive(StableAbi)]
#[allow(non_camel_case_types)]
pub enum FFI_InsertOp {
Append,
Overwrite,
Replace,
}

impl From<FFI_InsertOp> for InsertOp {
fn from(value: FFI_InsertOp) -> Self {
match value {
FFI_InsertOp::Append => InsertOp::Append,
FFI_InsertOp::Overwrite => InsertOp::Overwrite,
FFI_InsertOp::Replace => InsertOp::Replace,
}
}
}

impl From<InsertOp> for FFI_InsertOp {
fn from(value: InsertOp) -> Self {
match value {
InsertOp::Append => FFI_InsertOp::Append,
InsertOp::Overwrite => FFI_InsertOp::Overwrite,
InsertOp::Replace => FFI_InsertOp::Replace,
}
}
}
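
Since both From implementations are total, the conversion round-trips losslessly. A trivial usage sketch, assuming the crate path datafusion_ffi::insert_op implied by the file location:

use datafusion::logical_expr::logical_plan::dml::InsertOp;
use datafusion_ffi::insert_op::FFI_InsertOp;

fn main() {
    // Each variant maps one-to-one, so converting out and back preserves the value.
    let ffi_op: FFI_InsertOp = InsertOp::Overwrite.into();
    let op: InsertOp = ffi_op.into();
    assert!(matches!(op, InsertOp::Overwrite));
}
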
1 change: 1 addition & 0 deletions datafusion/ffi/src/lib.rs
@@ -20,6 +20,7 @@

pub mod arrow_wrappers;
pub mod execution_plan;
pub mod insert_op;
pub mod plan_properties;
pub mod record_batch_stream;
pub mod session_config;
8 changes: 4 additions & 4 deletions datafusion/ffi/src/record_batch_stream.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::{ffi::c_void, sync::Arc, task::Poll};
use std::{ffi::c_void, task::Poll};

use abi_stable::{
std_types::{ROption, RResult, RString},
@@ -33,7 +33,7 @@ use datafusion::{
execution::{RecordBatchStream, SendableRecordBatchStream},
};
use futures::{Stream, TryStreamExt};
use tokio::runtime::Runtime;
use tokio::runtime::Handle;

use crate::arrow_wrappers::{WrappedArray, WrappedSchema};

@@ -61,7 +61,7 @@ pub struct FFI_RecordBatchStream {

pub struct RecordBatchStreamPrivateData {
pub rbs: SendableRecordBatchStream,
pub runtime: Option<Arc<Runtime>>,
pub runtime: Option<Handle>,
}

impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
@@ -71,7 +71,7 @@ impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
}

impl FFI_RecordBatchStream {
pub fn new(stream: SendableRecordBatchStream, runtime: Option<Arc<Runtime>>) -> Self {
pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
rbs: stream,
runtime,
151 changes: 145 additions & 6 deletions datafusion/ffi/src/table_provider.rs
@@ -28,8 +28,8 @@ use datafusion::{
catalog::{Session, TableProvider},
datasource::TableType,
error::DataFusionError,
execution::session_state::SessionStateBuilder,
logical_expr::TableProviderFilterPushDown,
execution::{session_state::SessionStateBuilder, TaskContext},
logical_expr::{logical_plan::dml::InsertOp, TableProviderFilterPushDown},
physical_plan::ExecutionPlan,
prelude::{Expr, SessionContext},
};
@@ -40,7 +40,7 @@ use datafusion_proto::{
protobuf::LogicalExprList,
};
use prost::Message;
use tokio::runtime::Runtime;
use tokio::runtime::Handle;

use crate::{
arrow_wrappers::WrappedSchema,
@@ -50,6 +50,7 @@ use crate::{

use super::{
execution_plan::{FFI_ExecutionPlan, ForeignExecutionPlan},
insert_op::FFI_InsertOp,
session_config::FFI_SessionConfig,
};
use datafusion::error::Result;
@@ -133,6 +134,14 @@ pub struct FFI_TableProvider {
-> RResult<RVec<FFI_TableProviderFilterPushDown>, RString>,
>,

pub insert_into:
unsafe extern "C" fn(
provider: &Self,
session_config: &FFI_SessionConfig,
input: &FFI_ExecutionPlan,
insert_op: FFI_InsertOp,
) -> FfiFuture<RResult<FFI_ExecutionPlan, RString>>,

/// Used to create a clone on the provider of the execution plan. This should
/// only need to be called by the receiver of the plan.
pub clone: unsafe extern "C" fn(plan: &Self) -> Self,
@@ -153,7 +162,7 @@ unsafe impl Sync for FFI_TableProvider {}

struct ProviderPrivateData {
provider: Arc<dyn TableProvider + Send>,
runtime: Option<Arc<Runtime>>,
runtime: Option<Handle>,
}

unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema {
@@ -276,6 +285,53 @@ unsafe extern "C" fn scan_fn_wrapper(
.into_ffi()
}

unsafe extern "C" fn insert_into_fn_wrapper(
provider: &FFI_TableProvider,
session_config: &FFI_SessionConfig,
input: &FFI_ExecutionPlan,
insert_op: FFI_InsertOp,
) -> FfiFuture<RResult<FFI_ExecutionPlan, RString>> {
let private_data = provider.private_data as *mut ProviderPrivateData;
let internal_provider = &(*private_data).provider;
let session_config = session_config.clone();
let input = input.clone();
let runtime = &(*private_data).runtime;

async move {
let config = match ForeignSessionConfig::try_from(&session_config) {
Ok(c) => c,
Err(e) => return RResult::RErr(e.to_string().into()),
};
let session = SessionStateBuilder::new()
.with_default_features()
.with_config(config.0)
.build();
let ctx = SessionContext::new_with_state(session);

let input = match ForeignExecutionPlan::try_from(&input) {
Ok(input) => Arc::new(input),
Err(e) => return RResult::RErr(e.to_string().into()),
};

let insert_op = InsertOp::from(insert_op);

let plan = match internal_provider
.insert_into(&ctx.state(), input, insert_op)
.await
{
Ok(p) => p,
Err(e) => return RResult::RErr(e.to_string().into()),
};

RResult::ROk(FFI_ExecutionPlan::new(
plan,
ctx.task_ctx(),
runtime.clone(),
))
}
.into_ffi()
}

unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_TableProvider) {
let private_data = Box::from_raw(provider.private_data as *mut ProviderPrivateData);
drop(private_data);
@@ -295,6 +351,7 @@ unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_Table
scan: scan_fn_wrapper,
table_type: table_type_fn_wrapper,
supports_filters_pushdown: provider.supports_filters_pushdown,
insert_into: provider.insert_into,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
version: super::version,
Expand All @@ -313,7 +370,7 @@ impl FFI_TableProvider {
pub fn new(
provider: Arc<dyn TableProvider + Send>,
can_support_pushdown_filters: bool,
runtime: Option<Arc<Runtime>>,
runtime: Option<Handle>,
) -> Self {
let private_data = Box::new(ProviderPrivateData { provider, runtime });

@@ -325,6 +382,7 @@ impl FFI_TableProvider {
true => Some(supports_filters_pushdown_fn_wrapper),
false => None,
},
insert_into: insert_into_fn_wrapper,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
version: super::version,
@@ -443,6 +501,37 @@ impl TableProvider for ForeignTableProvider {
}
}
}

async fn insert_into(
&self,
session: &dyn Session,
input: Arc<dyn ExecutionPlan>,
insert_op: InsertOp,
) -> Result<Arc<dyn ExecutionPlan>> {
let session_config: FFI_SessionConfig = session.config().into();

let rc = Handle::try_current().ok();
let input =
FFI_ExecutionPlan::new(input, Arc::new(TaskContext::from(session)), rc);
let insert_op: FFI_InsertOp = insert_op.into();

let plan = unsafe {
let maybe_plan =
(self.0.insert_into)(&self.0, &session_config, &input, insert_op).await;

match maybe_plan {
RResult::ROk(p) => ForeignExecutionPlan::try_from(&p)?,
RResult::RErr(e) => {
return Err(DataFusionError::Internal(format!(
"Unable to perform insert_into via FFI: {}",
e
)))
}
}
};

Ok(Arc::new(plan))
}
}

#[cfg(test)]
@@ -453,7 +542,7 @@ mod tests {
use super::*;

#[tokio::test]
async fn test_round_trip_ffi_table_provider() -> Result<()> {
async fn test_round_trip_ffi_table_provider_scan() -> Result<()> {
use arrow::datatypes::Field;
use datafusion::arrow::{
array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
@@ -493,4 +582,54 @@

Ok(())
}

#[tokio::test]
async fn test_round_trip_ffi_table_provider_insert_into() -> Result<()> {
use arrow::datatypes::Field;
use datafusion::arrow::{
array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
};
use datafusion::datasource::MemTable;

let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));

// define data in two partitions
let batch1 = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
)?;
let batch2 = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(Float32Array::from(vec![64.0]))],
)?;

let ctx = SessionContext::new();

let provider =
Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?);

let ffi_provider = FFI_TableProvider::new(provider, true, None);

let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into();

ctx.register_table("t", Arc::new(foreign_table_provider))?;

let result = ctx
.sql("INSERT INTO t VALUES (128.0);")
.await?
.collect()
.await?;

assert!(result.len() == 1 && result[0].num_rows() == 1);

ctx.table("t")
.await?
.select(vec![col("a")])?
.filter(col("a").gt(lit(3.0)))?
.show()
.await?;

Ok(())
}
}
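
The commit message names datafusion-python extensions as the motivating consumer. For orientation only, here is a rough sketch of what an exporting crate might look like; the no_mangle symbol, the empty MemTable, and the cdylib framing are illustrative assumptions rather than an interface defined by this commit.

use std::sync::Arc;
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::catalog::TableProvider;
use datafusion::datasource::MemTable;
use datafusion_ffi::table_provider::FFI_TableProvider;

fn example_schema() -> SchemaRef {
    Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]))
}

// Hypothetical cdylib entry point: build a provider and hand it to the host
// process as the stable-ABI FFI_TableProvider struct.
#[no_mangle]
pub extern "C" fn create_example_table_provider() -> FFI_TableProvider {
    let provider: Arc<dyn TableProvider + Send> = Arc::new(
        MemTable::try_new(example_schema(), vec![vec![]]).expect("schema is valid"),
    );
    // `true` advertises filter pushdown support; with None no runtime handle is
    // attached, so execution is driven by whatever context polls on the host side.
    FFI_TableProvider::new(provider, true, None)
}
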
8 changes: 4 additions & 4 deletions datafusion/ffi/src/tests/async_provider.rs
@@ -41,7 +41,7 @@ use datafusion::{
};
use futures::Stream;
use tokio::{
runtime::Runtime,
runtime::Handle,
sync::{broadcast, mpsc},
};

@@ -59,7 +59,7 @@ fn async_table_provider_thread(
mut shutdown: mpsc::Receiver<bool>,
mut batch_request: mpsc::Receiver<bool>,
batch_sender: broadcast::Sender<Option<RecordBatch>>,
tokio_rt: mpsc::Sender<Arc<Runtime>>,
tokio_rt: mpsc::Sender<Handle>,
) {
let runtime = Arc::new(
tokio::runtime::Builder::new_current_thread()
@@ -68,7 +68,7 @@ );
);
let _runtime_guard = runtime.enter();
tokio_rt
.blocking_send(Arc::clone(&runtime))
.blocking_send(runtime.handle().clone())
.expect("Unable to send tokio runtime back to main thread");

runtime.block_on(async move {
@@ -91,7 +91,7 @@ fn async_table_provider_thread(
let _ = shutdown.blocking_recv();
}

pub fn start_async_provider() -> (AsyncTableProvider, Arc<Runtime>) {
pub fn start_async_provider() -> (AsyncTableProvider, Handle) {
let (batch_request_tx, batch_request_rx) = mpsc::channel(10);
let (record_batch_tx, record_batch_rx) = broadcast::channel(10);
let (tokio_rt_tx, mut tokio_rt_rx) = mpsc::channel(10);
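
A final note on the test helper change directly above: the dedicated provider thread now sends back a Handle to the runtime it owns rather than sharing the Runtime itself. A minimal sketch of that hand-off pattern in isolation, with channel and function names chosen for illustration:

use tokio::runtime::Handle;
use tokio::sync::mpsc;

fn provider_thread(handle_tx: mpsc::Sender<Handle>) {
    // The runtime stays owned by this thread; only a cheap, clonable Handle
    // crosses the channel back to the caller.
    let runtime = tokio::runtime::Builder::new_current_thread()
        .build()
        .expect("failed to build runtime");
    handle_tx
        .blocking_send(runtime.handle().clone())
        .expect("receiver dropped");

    // The owning thread keeps driving the async work, as async_table_provider_thread does.
    runtime.block_on(async {
        // ... serve record batch requests until shutdown ...
    });
}
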
