From 8e6eb77dde6611e319267caa12f04ff35c9454d0 Mon Sep 17 00:00:00 2001
From: Jon Mease
Date: Sat, 18 Mar 2023 15:48:16 -0400
Subject: [PATCH] Re-enable zero copy from pyarrow to arrow-rs (#264)

The previously failing test (test_pre_transform_multi_partition) now passes.
---
 python/vegafusion/vegafusion/runtime.py   | 25 +++++-----
 vegafusion-common/src/data/table.rs       | 56 ++++++++++++++++++++++
 vegafusion-python-embed/src/connection.rs | 58 ++---------------------
 vegafusion-python-embed/src/lib.rs        | 36 ++++++++------
 4 files changed, 94 insertions(+), 81 deletions(-)

diff --git a/python/vegafusion/vegafusion/runtime.py b/python/vegafusion/vegafusion/runtime.py
index 95d6122f3..5c7ce3389 100644
--- a/python/vegafusion/vegafusion/runtime.py
+++ b/python/vegafusion/vegafusion/runtime.py
@@ -11,7 +11,7 @@
 import pyarrow as pa
 from typing import Union
 from .connection import SqlConnection
-from .transformer import import_pyarrow_interchange
+from .transformer import import_pyarrow_interchange, to_arrow_table
 
 try:
     from duckdb import DuckDBPyConnection
@@ -109,10 +109,10 @@ def process_request_bytes(self, request):
             # No grpc channel, get or initialize an embedded runtime
             return self.embedded_runtime.process_request_bytes(request)
 
-    def _serialize_or_register_inline_datasets(self, inline_datasets=None):
+    def _arrowify_or_register_inline_datasets(self, inline_datasets=None):
         from .transformer import to_arrow_ipc_bytes, arrow_table_to_ipc_bytes
         inline_datasets = inline_datasets or dict()
-        inline_dataset_bytes = dict()
+        inline_arrow_datasets = dict()
         for name, value in inline_datasets.items():
             if isinstance(value, pa.Table):
                 if self._connection is not None:
@@ -123,8 +123,7 @@ def _serialize_or_register_inline_datasets(self, inline_datasets=None):
                     except ValueError:
                         pass
 
-                table_bytes = arrow_table_to_ipc_bytes(value, stream=True)
-                inline_dataset_bytes[name] = table_bytes
+                inline_arrow_datasets[name] = value
             elif isinstance(value, pd.DataFrame):
                 if self._connection is not None:
                     try:
@@ -134,8 +133,7 @@ def _serialize_or_register_inline_datasets(self, inline_datasets=None):
                     except ValueError:
                         pass
 
-                table_bytes = to_arrow_ipc_bytes(value, stream=True)
-                inline_dataset_bytes[name] = table_bytes
+                inline_arrow_datasets[name] = to_arrow_table(value)
             elif hasattr(value, "__dataframe__"):
                 pi = import_pyarrow_interchange()
                 value = pi.from_dataframe(value)
@@ -147,12 +145,11 @@ def _serialize_or_register_inline_datasets(self, inline_datasets=None):
                     except ValueError:
                         pass
 
-                table_bytes = arrow_table_to_ipc_bytes(value, stream=True)
-                inline_dataset_bytes[name] = table_bytes
+                inline_arrow_datasets[name] = value
             else:
                 raise ValueError(f"Unsupported DataFrame type: {type(value)}")
 
-        return inline_dataset_bytes
+        return inline_arrow_datasets
 
     def pre_transform_spec(
         self,
@@ -202,7 +199,7 @@ def pre_transform_spec(
         if self._grpc_channel:
             raise ValueError("pre_transform_spec not yet supported over gRPC")
         else:
-            inline_dataset_bytes = self._serialize_or_register_inline_datasets(inline_datasets)
+            inline_arrow_dataset = self._arrowify_or_register_inline_datasets(inline_datasets)
             try:
                 new_spec, warnings = self.embedded_runtime.pre_transform_spec(
                     spec,
@@ -210,7 +207,7 @@ def pre_transform_spec(
                     default_input_tz=default_input_tz,
                     row_limit=row_limit,
                     preserve_interactivity=preserve_interactivity,
-                    inline_datasets=inline_dataset_bytes
+                    inline_datasets=inline_arrow_dataset
                 )
             finally:
                 # Clean up temporary tables
@@ -267,7 +264,7 @@ def pre_transform_datasets(self, spec, datasets, local_tz, default_input_tz=None
                 raise ValueError(err_msg)
 
         # Serialize inline datasets
-        inline_dataset_bytes = self._serialize_or_register_inline_datasets(inline_datasets)
+        inline_arrow_dataset = self._arrowify_or_register_inline_datasets(inline_datasets)
         try:
             values, warnings = self.embedded_runtime.pre_transform_datasets(
                 spec,
@@ -275,7 +272,7 @@ def pre_transform_datasets(self, spec, datasets, local_tz, default_input_tz=None
                 local_tz=local_tz,
                 default_input_tz=default_input_tz,
                 row_limit=row_limit,
-                inline_datasets=inline_dataset_bytes
+                inline_datasets=inline_arrow_dataset
             )
         finally:
             # Clean up registered tables (both inline and internal temporary tables)
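With this change, `_arrowify_or_register_inline_datasets` hands pyarrow Tables to the embedded runtime as objects instead of serializing them to Arrow IPC bytes first. A minimal usage sketch of the public API this feeds; the chart spec and dataset name here are hypothetical, and the `vegafusion+dataset://` URL convention is assumed:

    import pandas as pd
    import pyarrow as pa
    import vegafusion as vf

    chart_spec = {
        "$schema": "https://vega.github.io/schema/vega/v5.json",
        "data": [{"name": "source", "url": "vegafusion+dataset://source"}],
    }

    df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    table = pa.Table.from_pandas(df)

    # The Table object itself crosses into Rust; no IPC copy is made.
    new_spec, warnings = vf.runtime.pre_transform_spec(
        chart_spec,
        local_tz="UTC",
        inline_datasets={"source": table},
    )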
diff --git a/vegafusion-common/src/data/table.rs b/vegafusion-common/src/data/table.rs
index 2397cacc9..b3696a7ba 100644
--- a/vegafusion-common/src/data/table.rs
+++ b/vegafusion-common/src/data/table.rs
@@ -29,6 +29,16 @@ use {
     std::{borrow::Cow, convert::TryFrom},
 };
 
+#[cfg(feature = "pyarrow")]
+use {
+    arrow::pyarrow::PyArrowConvert,
+    pyo3::{
+        prelude::PyModule,
+        types::{PyList, PyTuple},
+        PyAny, PyErr, PyObject, Python,
+    },
+};
+
 #[derive(Clone, Debug)]
 pub struct VegaFusionTable {
     pub schema: SchemaRef,
@@ -271,6 +281,52 @@ impl VegaFusionTable {
         }
     }
 
+    #[cfg(feature = "pyarrow")]
+    pub fn from_pyarrow(py: Python, pyarrow_table: &PyAny) -> std::result::Result<Self, PyErr> {
+        // Extract table.schema as a Rust Schema
+        let getattr_args = PyTuple::new(py, vec!["schema"]);
+        let schema_object = pyarrow_table.call_method1("__getattribute__", getattr_args)?;
+        let schema = Schema::from_pyarrow(schema_object)?;
+
+        // Extract table.to_batches() as a Rust Vec<RecordBatch>
+        let batches_object = pyarrow_table.call_method0("to_batches")?;
+        let batches_list = batches_object.downcast::<PyList>()?;
+        let batches = batches_list
+            .iter()
+            .map(|batch_any| Ok(RecordBatch::from_pyarrow(batch_any)?))
+            .collect::<std::result::Result<Vec<RecordBatch>, PyErr>>()?;
+
+        Ok(VegaFusionTable::try_new(Arc::new(schema), batches)?)
+    }
+
+    #[cfg(feature = "pyarrow")]
+    pub fn to_pyarrow(&self, py: Python) -> std::result::Result<PyObject, PyErr> {
+        // Convert table's record batches into a Python list of pyarrow batches
+        let pyarrow_module = PyModule::import(py, "pyarrow")?;
+        let table_cls = pyarrow_module.getattr("Table")?;
+        let batch_objects = self
+            .batches
+            .iter()
+            .map(|batch| Ok(batch.to_pyarrow(py)?))
+            .collect::<std::result::Result<Vec<PyObject>, PyErr>>()?;
+        let batches_list = PyList::new(py, batch_objects);
+
+        // Convert table's schema into a pyarrow schema
+        let schema = if let Some(batch) = self.batches.get(0) {
+            // Get schema from first batch if present
+            batch.schema()
+        } else {
+            self.schema.clone()
+        };
+
+        let schema_object = schema.to_pyarrow(py)?;
+
+        // Build pyarrow table
+        let args = PyTuple::new(py, vec![batches_list.as_ref(), schema_object.as_ref(py)]);
+        let pa_table = table_cls.call_method1("from_batches", args)?;
+        Ok(PyObject::from(pa_table))
+    }
+
     // Serialize to bytes using Arrow IPC format
     pub fn to_ipc_bytes(&self) -> Result<Vec<u8>> {
         let buffer: Vec<u8> = Vec::new();
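The new `from_pyarrow`/`to_pyarrow` methods move the pyarrow round trip into `VegaFusionTable` itself. The conversion they perform mirrors a handful of pyarrow calls on the Python side; a small sketch of the equivalent operations:

    import pyarrow as pa

    table = pa.table({"a": [1, 2, 3]})

    # from_pyarrow: pull out the schema and the list of record batches
    schema = table.schema
    batches = table.to_batches()

    # to_pyarrow: rebuild a Table from the batches. The explicit schema
    # matters when the batch list is empty, which is why the Rust code
    # falls back to self.schema when there is no first batch.
    roundtrip = pa.Table.from_batches(batches, schema=schema)
    assert roundtrip.equals(table)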
diff --git a/vegafusion-python-embed/src/connection.rs b/vegafusion-python-embed/src/connection.rs
index fb418db39..e19ccd2e1 100644
--- a/vegafusion-python-embed/src/connection.rs
+++ b/vegafusion-python-embed/src/connection.rs
@@ -5,12 +5,9 @@ use std::sync::Arc;
 use arrow::pyarrow::PyArrowConvert;
 use async_trait::async_trait;
-use pyo3::types::{IntoPyDict, PyDict, PyList, PyString, PyTuple};
+use pyo3::types::{IntoPyDict, PyDict, PyString, PyTuple};
 use vegafusion_common::data::table::VegaFusionTable;
-use vegafusion_core::{
-    arrow::{datatypes::Schema, record_batch::RecordBatch},
-    error::Result,
-};
+use vegafusion_core::{arrow::datatypes::Schema, error::Result};
 use vegafusion_sql::connection::datafusion_conn::DataFusionConnection;
 use vegafusion_sql::connection::{Connection, SqlConnection};
 use vegafusion_sql::dataframe::{CsvReadOptions, DataFrame, SqlDataFrame};
@@ -108,41 +105,12 @@ impl Connection for PySqlConnection {
         let random_id = uuid::Uuid::new_v4().to_string().replace('-', "_");
         let table_name = format!("arrow_{random_id}");
         Python::with_gil(|py| -> std::result::Result<_, PyErr> {
-            // Convert table's record batches into Python list of pyarrow batches
-            let pyarrow_module = PyModule::import(py, "pyarrow")?;
-            let table_cls = pyarrow_module.getattr("Table")?;
-            let batch_objects = table
-                .batches
-                .iter()
-                .map(|batch| Ok(batch.to_pyarrow(py)?))
-                .collect::<std::result::Result<Vec<PyObject>, PyErr>>()?;
-            let batches_list = PyList::new(py, batch_objects);
-
-            // Convert table's schema into pyarrow schema
-            let schema = if let Some(batch) = table.batches.get(0) {
-                // Get schema from first batch if present
-                batch.schema()
-            } else {
-                table.schema.clone()
-            };
-
-            let schema_object = schema.to_pyarrow(py)?;
-
-            // Build pyarrow table
-            let args = PyTuple::new(py, vec![batches_list.as_ref(), schema_object.as_ref(py)]);
-            let pa_table = table_cls.call_method1("from_batches", args)?;
+            let pa_table = table.to_pyarrow(py)?;
 
             // Register table with Python connection
             let table_name_object = table_name.clone().into_py(py);
             let is_temporary_object = true.into_py(py);
-            let args = PyTuple::new(
-                py,
-                vec![
-                    table_name_object.as_ref(py),
-                    pa_table,
-                    is_temporary_object.as_ref(py),
-                ],
-            );
+            let args = PyTuple::new(py, vec![table_name_object, pa_table, is_temporary_object]);
             self.conn.call_method1(py, "register_arrow", args)?;
             Ok(())
         })?;
@@ -219,27 +187,11 @@ impl SqlConnection for PySqlConnection {
         let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> {
             let query_object = PyString::new(py, query);
             let query_object = query_object.as_ref();
             let schema_object = schema.to_pyarrow(py)?;
             let schema_object = schema_object.as_ref(py);
             let args = PyTuple::new(py, vec![query_object, schema_object]);
             let table_object = self.conn.call_method(py, "fetch_query", args, None)?;
-
-            // Extract table.schema as a Rust Schema
-            let getattr_args = PyTuple::new(py, vec!["schema"]);
-            let schema_object = table_object.call_method1(py, "__getattribute__", getattr_args)?;
-            let schema = Schema::from_pyarrow(schema_object.as_ref(py))?;
-
-            // Extract table.to_batches() as a Rust Vec<RecordBatch>
-            let batches_object = table_object.call_method0(py, "to_batches")?;
-            let batches_list = batches_object.downcast::<PyList>(py)?;
-            let batches = batches_list
-                .iter()
-                .map(|batch_any| Ok(RecordBatch::from_pyarrow(batch_any)?))
-                .collect::<std::result::Result<Vec<RecordBatch>, PyErr>>()?;
-
-            Ok(VegaFusionTable::try_new(Arc::new(schema), batches)?)
+            VegaFusionTable::from_pyarrow(py, table_object.as_ref(py))
         })?;
         Ok(table)
     }
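`register_arrow` and `fetch_query` are the methods the Rust side expects on the wrapped Python connection object; after this change they exchange pyarrow Tables directly rather than IPC bytes. A hypothetical sketch of what such a connection could look like when backed by DuckDB (this is an illustration of the protocol, not VegaFusion's actual DuckDB integration):

    import duckdb

    class DuckDbSqlConnection:
        def __init__(self):
            self.conn = duckdb.connect()

        def register_arrow(self, name, table, temporary):
            # DuckDB can scan a registered pyarrow Table in place
            self.conn.register(name, table)

        def fetch_query(self, query, schema):
            # Return a pyarrow.Table; the Rust side converts it with
            # VegaFusionTable::from_pyarrow
            return self.conn.execute(query).fetch_arrow_table()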
diff --git a/vegafusion-python-embed/src/lib.rs b/vegafusion-python-embed/src/lib.rs
index 8972ddda9..554016a8b 100644
--- a/vegafusion-python-embed/src/lib.rs
+++ b/vegafusion-python-embed/src/lib.rs
@@ -2,7 +2,7 @@ pub mod connection;
 
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
-use pyo3::types::{PyBytes, PyDict, PyList, PyString};
+use pyo3::types::{PyBytes, PyDict, PyList, PyTuple};
 use std::collections::HashMap;
 use std::sync::{Arc, Once};
 use tokio::runtime::Runtime;
@@ -15,6 +15,7 @@ use crate::connection::PySqlConnection;
 use env_logger::{Builder, Target};
 use pythonize::depythonize;
 use serde::{Deserialize, Serialize};
+use vegafusion_common::data::table::VegaFusionTable;
 use vegafusion_core::proto::gen::tasks::Variable;
 use vegafusion_core::spec::chart::ChartSpec;
 use vegafusion_core::task_graph::graph::ScopedVariable;
@@ -48,20 +49,27 @@
 struct PyVegaFusionRuntime {
     tokio_runtime: Runtime,
 }
 
-fn deserialize_inline_datasets(
+fn process_inline_datasets(
     inline_datasets: Option<&PyDict>,
 ) -> PyResult<HashMap<String, VegaFusionDataset>> {
     if let Some(inline_datasets) = inline_datasets {
-        inline_datasets
-            .iter()
-            .map(|(name, table_bytes)| {
-                let name = name.downcast::<PyString>()?;
-                let ipc_bytes = table_bytes.downcast::<PyBytes>()?;
-                let ipc_bytes = ipc_bytes.as_bytes();
-                let dataset = VegaFusionDataset::from_table_ipc_bytes(ipc_bytes)?;
-                Ok((name.to_string(), dataset))
-            })
-            .collect::<PyResult<HashMap<_, _>>>()
+        Python::with_gil(|py| -> PyResult<_> {
+            let builtins_module = PyModule::import(py, "builtins")?;
+            let id_fun = builtins_module.getattr("id")?;
+
+            inline_datasets
+                .iter()
+                .map(|(name, pyarrow_table)| {
+                    // Use the object id of the pyarrow table as the dataset's fingerprint
+                    let args = PyTuple::new(py, vec![pyarrow_table]);
+                    let id_object = id_fun.call(args, None)?;
+                    let id = id_object.extract::<u64>()?;
+                    let table = VegaFusionTable::from_pyarrow(py, pyarrow_table)?;
+                    let dataset = VegaFusionDataset::Table { table, hash: id };
+                    Ok((name.to_string(), dataset))
+                })
+                .collect::<PyResult<HashMap<_, _>>>()
+        })
     } else {
         Ok(Default::default())
     }
 }
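`process_inline_datasets` fingerprints each inline dataset with the Python builtin `id()` rather than hashing serialized bytes. That makes the fingerprint free to compute, at the cost of being identity-based: re-sending the same Table object keeps the same fingerprint, while an equal but distinct Table gets a new one (and an `id()` value may be reused after its object is garbage collected). Illustrated in Python:

    import pyarrow as pa

    table = pa.table({"a": [1, 2, 3]})
    same_data = pa.table({"a": [1, 2, 3]})

    assert id(table) == id(table)      # stable for a given live object...
    assert id(table) != id(same_data)  # ...but equal data fingerprints differently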
@@ -121,7 +129,7 @@ impl PyVegaFusionRuntime {
         preserve_interactivity: Option<bool>,
         inline_datasets: Option<&PyDict>,
     ) -> PyResult<(PyObject, PyObject)> {
-        let inline_datasets = deserialize_inline_datasets(inline_datasets)?;
+        let inline_datasets = process_inline_datasets(inline_datasets)?;
         let spec = parse_json_spec(spec)?;
 
         let preserve_interactivity = preserve_interactivity.unwrap_or(false);
@@ -181,7 +189,7 @@ impl PyVegaFusionRuntime {
         row_limit: Option<u32>,
         inline_datasets: Option<&PyDict>,
     ) -> PyResult<(PyObject, PyObject)> {
-        let inline_datasets = deserialize_inline_datasets(inline_datasets)?;
+        let inline_datasets = process_inline_datasets(inline_datasets)?;
         let spec = parse_json_spec(spec)?;
 
         // Build variables
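For contrast, the removed code path wrote every inline dataset through the Arrow IPC stream format and parsed it back on the Rust side. Roughly what that serialization costs per request (a sketch of the generic pyarrow IPC path, not VegaFusion's exact arrow_table_to_ipc_bytes implementation):

    import pyarrow as pa

    def table_to_ipc_stream_bytes(table: pa.Table) -> bytes:
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, table.schema) as writer:
            writer.write_table(table)
        # to_pybytes() copies the entire buffer into a Python bytes object
        return sink.getvalue().to_pybytes()

    table = pa.table({"a": list(range(1_000_000))})
    ipc_bytes = table_to_ipc_stream_bytes(table)  # full data copy, now avoided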