Re-enable zero copy from pyarrow to arrow-rs (#264)
The previously failing test (test_pre_transform_multi_partition) now passes.
  • Loading branch information
1 parent 921b240 commit 8e6eb77
Showing 4 changed files with 94 additions and 81 deletions.
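At a high level, inline datasets are now handed from Python to the embedded Rust runtime as pyarrow Tables instead of being serialized to Arrow IPC bytes first. A minimal sketch of the user-facing call, assuming the package exposes a module-level runtime instance and that the chart spec references an inline dataset named "source" (both placeholders here, not taken from this diff):

import json
import pandas as pd
from vegafusion import runtime

# Hypothetical Vega spec (dict or JSON string) that references the
# inline dataset registered below under the name "source"
spec = json.load(open("chart_spec.json"))

df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# The DataFrame is converted to a pyarrow Table and passed to the Rust
# runtime directly; no intermediate IPC byte stream is produced anymore.
new_spec, warnings = runtime.pre_transform_spec(
    spec,
    local_tz="UTC",
    inline_datasets={"source": df},
)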
25 changes: 11 additions & 14 deletions python/vegafusion/vegafusion/runtime.py
@@ -11,7 +11,7 @@
import pyarrow as pa
from typing import Union
from .connection import SqlConnection
from .transformer import import_pyarrow_interchange
from .transformer import import_pyarrow_interchange, to_arrow_table

try:
from duckdb import DuckDBPyConnection
@@ -109,10 +109,10 @@ def process_request_bytes(self, request):
# No grpc channel, get or initialize an embedded runtime
return self.embedded_runtime.process_request_bytes(request)

def _serialize_or_register_inline_datasets(self, inline_datasets=None):
def _arrowify_or_register_inline_datasets(self, inline_datasets=None):
from .transformer import to_arrow_ipc_bytes, arrow_table_to_ipc_bytes
inline_datasets = inline_datasets or dict()
inline_dataset_bytes = dict()
inline_arrow_datasets = dict()
for name, value in inline_datasets.items():
if isinstance(value, pa.Table):
if self._connection is not None:
@@ -123,8 +123,7 @@ def _serialize_or_register_inline_datasets(self, inline_datasets=None):
except ValueError:
pass

table_bytes = arrow_table_to_ipc_bytes(value, stream=True)
inline_dataset_bytes[name] = table_bytes
inline_arrow_datasets[name] = value
elif isinstance(value, pd.DataFrame):
if self._connection is not None:
try:
@@ -134,8 +133,7 @@ def _serialize_or_register_inline_datasets(self, inline_datasets=None):
except ValueError:
pass

table_bytes = to_arrow_ipc_bytes(value, stream=True)
inline_dataset_bytes[name] = table_bytes
inline_arrow_datasets[name] = to_arrow_table(value)
elif hasattr(value, "__dataframe__"):
pi = import_pyarrow_interchange()
value = pi.from_dataframe(value)
@@ -147,12 +145,11 @@ def _serialize_or_register_inline_datasets(self, inline_datasets=None):
except ValueError:
pass

table_bytes = arrow_table_to_ipc_bytes(value, stream=True)
inline_dataset_bytes[name] = table_bytes
inline_arrow_datasets[name] = value
else:
raise ValueError(f"Unsupported DataFrame type: {type(value)}")

return inline_dataset_bytes
return inline_arrow_datasets

def pre_transform_spec(
self,
@@ -202,15 +199,15 @@ def pre_transform_spec(
if self._grpc_channel:
raise ValueError("pre_transform_spec not yet supported over gRPC")
else:
inline_dataset_bytes = self._serialize_or_register_inline_datasets(inline_datasets)
inline_arrow_dataset = self._arrowify_or_register_inline_datasets(inline_datasets)
try:
new_spec, warnings = self.embedded_runtime.pre_transform_spec(
spec,
local_tz=local_tz,
default_input_tz=default_input_tz,
row_limit=row_limit,
preserve_interactivity=preserve_interactivity,
inline_datasets=inline_dataset_bytes
inline_datasets=inline_arrow_dataset
)
finally:
# Clean up temporary tables
@@ -267,15 +264,15 @@ def pre_transform_datasets(self, spec, datasets, local_tz, default_input_tz=None
raise ValueError(err_msg)

# Serialize inline datasets
inline_dataset_bytes = self._serialize_or_register_inline_datasets(inline_datasets)
inline_arrow_dataset = self._arrowify_or_register_inline_datasets(inline_datasets)
try:
values, warnings = self.embedded_runtime.pre_transform_datasets(
spec,
pre_tx_vars,
local_tz=local_tz,
default_input_tz=default_input_tz,
row_limit=row_limit,
inline_datasets=inline_dataset_bytes
inline_datasets=inline_arrow_dataset
)
finally:
# Clean up registered tables (both inline and internal temporary tables)
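For reference, the conversion the renamed helper now performs can be summarized as the standalone sketch below; it uses pa.Table.from_pandas and pyarrow.interchange in place of the package's own to_arrow_table and import_pyarrow_interchange helpers, and it omits the SQL-connection registration branches shown in the diff above:

import pandas as pd
import pyarrow as pa
from pyarrow import interchange

def arrowify_inline_datasets(inline_datasets=None):
    """Return a dict of pyarrow Tables instead of Arrow IPC bytes (sketch)."""
    inline_datasets = inline_datasets or {}
    inline_arrow_datasets = {}
    for name, value in inline_datasets.items():
        if isinstance(value, pa.Table):
            # Already a pyarrow Table: pass through unchanged
            inline_arrow_datasets[name] = value
        elif isinstance(value, pd.DataFrame):
            # pandas DataFrame: convert to a pyarrow Table
            inline_arrow_datasets[name] = pa.Table.from_pandas(value)
        elif hasattr(value, "__dataframe__"):
            # Anything implementing the DataFrame interchange protocol
            inline_arrow_datasets[name] = interchange.from_dataframe(value)
        else:
            raise ValueError(f"Unsupported DataFrame type: {type(value)}")
    return inline_arrow_datasets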
56 changes: 56 additions & 0 deletions vegafusion-common/src/data/table.rs
@@ -29,6 +29,16 @@ use {
std::{borrow::Cow, convert::TryFrom},
};

#[cfg(feature = "pyarrow")]
use {
arrow::pyarrow::PyArrowConvert,
pyo3::{
prelude::PyModule,
types::{PyList, PyTuple},
PyAny, PyErr, PyObject, Python,
},
};

#[derive(Clone, Debug)]
pub struct VegaFusionTable {
pub schema: SchemaRef,
@@ -271,6 +281,52 @@ impl VegaFusionTable {
}
}

#[cfg(feature = "pyarrow")]
pub fn from_pyarrow(py: Python, pyarrow_table: &PyAny) -> std::result::Result<Self, PyErr> {
// Extract table.schema as a Rust Schema
let getattr_args = PyTuple::new(py, vec!["schema"]);
let schema_object = pyarrow_table.call_method1("__getattribute__", getattr_args)?;
let schema = Schema::from_pyarrow(schema_object)?;

// Extract table.to_batches() as a Rust Vec<RecordBatch>
let batches_object = pyarrow_table.call_method0("to_batches")?;
let batches_list = batches_object.downcast::<PyList>()?;
let batches = batches_list
.iter()
.map(|batch_any| Ok(RecordBatch::from_pyarrow(batch_any)?))
.collect::<Result<Vec<RecordBatch>>>()?;

Ok(VegaFusionTable::try_new(Arc::new(schema), batches)?)
}

#[cfg(feature = "pyarrow")]
pub fn to_pyarrow(&self, py: Python) -> std::result::Result<PyObject, PyErr> {
// Convert table's record batches into Python list of pyarrow batches
let pyarrow_module = PyModule::import(py, "pyarrow")?;
let table_cls = pyarrow_module.getattr("Table")?;
let batch_objects = self
.batches
.iter()
.map(|batch| Ok(batch.to_pyarrow(py)?))
.collect::<Result<Vec<_>>>()?;
let batches_list = PyList::new(py, batch_objects);

// Convert table's schema into pyarrow schema
let schema = if let Some(batch) = self.batches.get(0) {
// Get schema from first batch if present
batch.schema()
} else {
self.schema.clone()
};

let schema_object = schema.to_pyarrow(py)?;

// Build pyarrow table
let args = PyTuple::new(py, vec![batches_list.as_ref(), schema_object.as_ref(py)]);
let pa_table = table_cls.call_method1("from_batches", args)?;
Ok(PyObject::from(pa_table))
}

// Serialize to bytes using Arrow IPC format
pub fn to_ipc_bytes(&self) -> Result<Vec<u8>> {
let buffer: Vec<u8> = Vec::new();
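The new from_pyarrow/to_pyarrow methods above rely on a schema-plus-batches round trip; the pyarrow calls they exercise look roughly like this from the Python side (a sketch with made-up data):

import pyarrow as pa

table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# What from_pyarrow reads from the Python object: the schema attribute and
# the list of record batches returned by to_batches()
schema = table.schema
batches = table.to_batches()

# What to_pyarrow rebuilds on the way back out via pyarrow.Table.from_batches
round_tripped = pa.Table.from_batches(batches, schema=schema)
assert round_tripped.equals(table)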
58 changes: 5 additions & 53 deletions vegafusion-python-embed/src/connection.rs
@@ -5,12 +5,9 @@ use std::sync::Arc;

use arrow::pyarrow::PyArrowConvert;
use async_trait::async_trait;
use pyo3::types::{IntoPyDict, PyDict, PyList, PyString, PyTuple};
use pyo3::types::{IntoPyDict, PyDict, PyString, PyTuple};
use vegafusion_common::data::table::VegaFusionTable;
use vegafusion_core::{
arrow::{datatypes::Schema, record_batch::RecordBatch},
error::Result,
};
use vegafusion_core::{arrow::datatypes::Schema, error::Result};
use vegafusion_sql::connection::datafusion_conn::DataFusionConnection;
use vegafusion_sql::connection::{Connection, SqlConnection};
use vegafusion_sql::dataframe::{CsvReadOptions, DataFrame, SqlDataFrame};
@@ -108,41 +105,12 @@ impl Connection for PySqlConnection {
let random_id = uuid::Uuid::new_v4().to_string().replace('-', "_");
let table_name = format!("arrow_{random_id}");
Python::with_gil(|py| -> std::result::Result<_, PyErr> {
// Convert table's record batches into Python list of pyarrow batches
let pyarrow_module = PyModule::import(py, "pyarrow")?;
let table_cls = pyarrow_module.getattr("Table")?;
let batch_objects = table
.batches
.iter()
.map(|batch| Ok(batch.to_pyarrow(py)?))
.collect::<Result<Vec<_>>>()?;
let batches_list = PyList::new(py, batch_objects);

// Convert table's schema into pyarrow schema
let schema = if let Some(batch) = table.batches.get(0) {
// Get schema from first batch if present
batch.schema()
} else {
table.schema.clone()
};

let schema_object = schema.to_pyarrow(py)?;

// Build pyarrow table
let args = PyTuple::new(py, vec![batches_list.as_ref(), schema_object.as_ref(py)]);
let pa_table = table_cls.call_method1("from_batches", args)?;
let pa_table = table.to_pyarrow(py)?;

// Register table with Python connection
let table_name_object = table_name.clone().into_py(py);
let is_temporary_object = true.into_py(py);
let args = PyTuple::new(
py,
vec![
table_name_object.as_ref(py),
pa_table,
is_temporary_object.as_ref(py),
],
);
let args = PyTuple::new(py, vec![table_name_object, pa_table, is_temporary_object]);
self.conn.call_method1(py, "register_arrow", args)?;
Ok(())
})?;
@@ -219,27 +187,11 @@ impl SqlConnection for PySqlConnection {
let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> {
let query_object = PyString::new(py, query);
let query_object = query_object.as_ref();

let schema_object = schema.to_pyarrow(py)?;
let schema_object = schema_object.as_ref(py);
let args = PyTuple::new(py, vec![query_object, schema_object]);

let table_object = self.conn.call_method(py, "fetch_query", args, None)?;

// Extract table.schema as a Rust Schema
let getattr_args = PyTuple::new(py, vec!["schema"]);
let schema_object = table_object.call_method1(py, "__getattribute__", getattr_args)?;
let schema = Schema::from_pyarrow(schema_object.as_ref(py))?;

// Extract table.to_batches() as a Rust Vec<RecordBatch>
let batches_object = table_object.call_method0(py, "to_batches")?;
let batches_list = batches_object.downcast::<PyList>(py)?;
let batches = batches_list
.iter()
.map(|batch_any| Ok(RecordBatch::from_pyarrow(batch_any)?))
.collect::<Result<Vec<RecordBatch>>>()?;

Ok(VegaFusionTable::try_new(Arc::new(schema), batches)?)
VegaFusionTable::from_pyarrow(py, table_object.as_ref(py))
})?;
Ok(table)
}
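After this simplification, the Rust connection exercises just two duck-typed methods on the Python connection object in this diff: register_arrow(table_name, table, temporary) and fetch_query(query, schema), the latter returning a pyarrow Table. A hypothetical DuckDB-backed object of that shape might look like the following sketch (class and method bodies are illustrative, not the project's actual SqlConnection implementation):

import duckdb
import pyarrow as pa

class DuckDbSqlConnection:
    """Hypothetical Python-side connection matching the calls made from Rust."""

    def __init__(self):
        self.conn = duckdb.connect()

    def register_arrow(self, table_name: str, table: pa.Table, temporary: bool):
        # Invoked from Rust as conn.register_arrow(name, pa_table, True);
        # the temporary flag is ignored in this sketch
        self.conn.register(table_name, table)

    def fetch_query(self, query: str, schema: pa.Schema) -> pa.Table:
        # Invoked from Rust; the returned pyarrow Table is converted back
        # with VegaFusionTable::from_pyarrow
        return self.conn.query(query).to_arrow_table()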
36 changes: 22 additions & 14 deletions vegafusion-python-embed/src/lib.rs
@@ -2,7 +2,7 @@ pub mod connection;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList, PyString};
use pyo3::types::{PyBytes, PyDict, PyList, PyTuple};
use std::collections::HashMap;
use std::sync::{Arc, Once};
use tokio::runtime::Runtime;
@@ -15,6 +15,7 @@ use crate::connection::PySqlConnection;
use env_logger::{Builder, Target};
use pythonize::depythonize;
use serde::{Deserialize, Serialize};
use vegafusion_common::data::table::VegaFusionTable;
use vegafusion_core::proto::gen::tasks::Variable;
use vegafusion_core::spec::chart::ChartSpec;
use vegafusion_core::task_graph::graph::ScopedVariable;
@@ -48,20 +49,27 @@ struct PyVegaFusionRuntime {
tokio_runtime: Runtime,
}

fn deserialize_inline_datasets(
fn process_inline_datasets(
inline_datasets: Option<&PyDict>,
) -> PyResult<HashMap<String, VegaFusionDataset>> {
if let Some(inline_datasets) = inline_datasets {
inline_datasets
.iter()
.map(|(name, table_bytes)| {
let name = name.downcast::<PyString>()?;
let ipc_bytes = table_bytes.downcast::<PyBytes>()?;
let ipc_bytes = ipc_bytes.as_bytes();
let dataset = VegaFusionDataset::from_table_ipc_bytes(ipc_bytes)?;
Ok((name.to_string(), dataset))
})
.collect::<PyResult<HashMap<_, _>>>()
Python::with_gil(|py| -> PyResult<_> {
let pyarrow_module = PyModule::import(py, "builtins")?;
let id_fun = pyarrow_module.getattr("id")?;

inline_datasets
.iter()
.map(|(name, pyarrow_table)| {
// Use object id of the pyarrow table as dataset's fingerprint
let args = PyTuple::new(py, vec![pyarrow_table]);
let id_object = id_fun.call(args, None)?;
let id = id_object.extract::<u64>()?;
let table = VegaFusionTable::from_pyarrow(py, pyarrow_table)?;
let dataset = VegaFusionDataset::Table { table, hash: id };
Ok((name.to_string(), dataset))
})
.collect::<PyResult<HashMap<_, _>>>()
})
} else {
Ok(Default::default())
}
Expand Down Expand Up @@ -121,7 +129,7 @@ impl PyVegaFusionRuntime {
preserve_interactivity: Option<bool>,
inline_datasets: Option<&PyDict>,
) -> PyResult<(PyObject, PyObject)> {
let inline_datasets = deserialize_inline_datasets(inline_datasets)?;
let inline_datasets = process_inline_datasets(inline_datasets)?;
let spec = parse_json_spec(spec)?;
let preserve_interactivity = preserve_interactivity.unwrap_or(false);

@@ -181,7 +189,7 @@ impl PyVegaFusionRuntime {
row_limit: Option<u32>,
inline_datasets: Option<&PyDict>,
) -> PyResult<(PyObject, PyObject)> {
let inline_datasets = deserialize_inline_datasets(inline_datasets)?;
let inline_datasets = process_inline_datasets(inline_datasets)?;
let spec = parse_json_spec(spec)?;

// Build variables
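Worth noting: the dataset fingerprint is now the CPython object id of the pyarrow Table (obtained via builtins.id) rather than a hash of its IPC bytes. In Python terms (a sketch):

import pyarrow as pa

table = pa.table({"a": [1, 2, 3]})

# What process_inline_datasets computes on the Rust side via builtins.id
fingerprint = id(table)

# id() is cheap but only identifies this particular live object; two equal
# tables constructed separately get different fingerprints.
other = pa.table({"a": [1, 2, 3]})
assert table.equals(other)
assert id(table) != id(other)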
