-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.rs
111 lines (94 loc) · 4.14 KB
/
utils.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use std::ffi::CString;
use std::sync::Arc;
use arrow::compute::can_cast_types;
use arrow::compute::kernels::cast;
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow_array::Array;
use arrow_schema::{ArrowError, DataType, Field, FieldRef};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};
use crate::error::PyArrowResult;
use crate::ffi::from_python::utils::import_schema_pycapsule;
use crate::ffi::to_python::ffi_stream::new_stream;
use crate::ffi::{ArrayIterator, ArrayReader};
/// Wrap a schema-like value in a PyCapsule named `"arrow_schema"` that holds an Arrow C Schema
/// pointer.
///
/// `field` may be anything convertible into an [`FFI_ArrowSchema`], for example an
/// [`arrow_schema::Schema`], [`arrow_schema::Field`], or [`arrow_schema::DataType`]. The
/// converted schema is moved into the returned capsule.
///
/// # Errors
///
/// Returns an error if the conversion into [`FFI_ArrowSchema`] fails, or if creating the
/// capsule fails.
pub fn to_schema_pycapsule(
    py: Python,
    field: impl TryInto<FFI_ArrowSchema, Error = ArrowError>,
) -> PyArrowResult<Bound<PyCapsule>> {
    let exported: FFI_ArrowSchema = field.try_into()?;
    let name = CString::new("arrow_schema").unwrap();
    PyCapsule::new_bound(py, exported, Some(name)).map_err(Into::into)
}
/// Export an [`Array`] and its [`FieldRef`] as a `(schema, array)` tuple of PyCapsules. The
/// first holds an Arrow C Schema pointer (named `"arrow_schema"`); the second holds an Arrow C
/// Array pointer (named `"arrow_array"`).
///
/// When `requested_schema` is provided, its data type is imported. If [`can_cast_types`]
/// allows casting the array's current type to it, the array is cast. It is then exported under
/// a new nullable field with an empty name that carries over the original field's metadata. If
/// the cast is not supported, the request is ignored and the array is exported unchanged.
///
/// # Errors
///
/// Returns an error if:
/// - the requested schema capsule cannot be imported or converted to a [`DataType`];
/// - the cast itself fails;
/// - the field cannot be exported as an [`FFI_ArrowSchema`];
/// - creating either capsule fails.
pub fn to_array_pycapsules<'py>(
    py: Python<'py>,
    field: FieldRef,
    array: &dyn Array,
    requested_schema: Option<Bound<'py, PyCapsule>>,
) -> PyArrowResult<Bound<'py, PyTuple>> {
    // Resolve the data type the consumer asked for, if any.
    let requested_type = match requested_schema {
        Some(capsule) => {
            let requested_ptr = import_schema_pycapsule(&capsule)?;
            // Import only the data type, not a full `Field`, because the requested schema's
            // name may be unset: https://github.com/apache/arrow-rs/issues/6251
            let requested = DataType::try_from(requested_ptr)?;
            Some(requested)
        }
        None => None,
    };

    // Honor the request only when the cast is supported; otherwise export the input as-is.
    let castable_type = requested_type.filter(|target| can_cast_types(field.data_type(), target));
    let (array_data, field) = match castable_type {
        Some(target) => {
            let cast_field =
                Arc::new(Field::new("", target, true).with_metadata(field.metadata().clone()));
            let casted = cast(array, cast_field.data_type())?;
            (casted.to_data(), cast_field)
        }
        None => (array.to_data(), field),
    };

    let ffi_schema = FFI_ArrowSchema::try_from(&field)?;
    let ffi_array = FFI_ArrowArray::new(&array_data);

    let schema_capsule_name = CString::new("arrow_schema").unwrap();
    let array_capsule_name = CString::new("arrow_array").unwrap();
    let capsules = [
        PyCapsule::new_bound(py, ffi_schema, Some(schema_capsule_name))?,
        PyCapsule::new_bound(py, ffi_array, Some(array_capsule_name))?,
    ];
    Ok(PyTuple::new_bound(py, capsules))
}
/// Export an [`ArrayReader`][crate::ffi::ArrayReader] to a PyCapsule named
/// `"arrow_array_stream"` that holds an Arrow C Stream pointer.
///
/// When `requested_schema` is provided, its data type is imported. If [`can_cast_types`]
/// allows casting the reader's current type to it, the reader is wrapped in an
/// [`ArrayIterator`][crate::ffi::ArrayIterator] that casts each array as it is pulled. The
/// wrapped stream reports a new nullable field with an empty name that carries over the
/// original field's metadata. If the cast is not supported, the request is ignored and the
/// reader is exported unchanged.
///
/// # Errors
///
/// Returns an error if the requested schema capsule cannot be imported or converted to a
/// [`DataType`], or if creating the capsule fails. Per-array cast failures are not reported
/// here. They surface as error items from the stream when the affected array is read.
pub fn to_stream_pycapsule<'py>(
    py: Python<'py>,
    mut array_reader: Box<dyn ArrayReader + Send>,
    requested_schema: Option<Bound<'py, PyCapsule>>,
) -> PyArrowResult<Bound<'py, PyCapsule>> {
    if let Some(capsule) = requested_schema {
        let requested_ptr = import_schema_pycapsule(&capsule)?;
        let source_field = array_reader.field();
        // Import only the data type, not a full `Field`, because the requested schema's
        // name may be unset: https://github.com/apache/arrow-rs/issues/6251
        let requested_type = DataType::try_from(requested_ptr)?;

        // Wrap the reader only when the type-level cast is supported.
        if can_cast_types(source_field.data_type(), &requested_type) {
            let target_field = Arc::new(
                Field::new("", requested_type, true)
                    .with_metadata(source_field.metadata().clone()),
            );
            let reported_field = Arc::clone(&target_field);
            // Cast lazily: each array is converted only when the consumer pulls it.
            let casting_iter = array_reader.map(move |item| {
                let source_array = item?;
                let casted = cast(source_array.as_ref(), target_field.data_type())?;
                Ok(casted)
            });
            array_reader = Box::new(ArrayIterator::new(casting_iter, reported_field));
        }
    }

    let stream = new_stream(array_reader);
    let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
    PyCapsule::new_bound(py, stream, Some(stream_capsule_name)).map_err(Into::into)
}