Skip to content

Commit

Permalink
test(python): Add test for BytesIO overwritten after scan (#20240)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Dec 10, 2024
1 parent 2c76494 commit c2ef569
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 25 deletions.
18 changes: 6 additions & 12 deletions crates/polars-python/src/dataframe/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::conversion::Wrap;
use crate::error::PyPolarsErr;
use crate::file::{
get_either_file, get_file_like, get_mmap_bytes_reader, get_mmap_bytes_reader_and_path,
read_if_bytesio, EitherRustPythonFile,
EitherRustPythonFile,
};
use crate::prelude::{parse_cloud_options, PyCompatLevel};

Expand All @@ -37,7 +37,7 @@ impl PyDataFrame {
)]
pub fn read_csv(
py: Python,
mut py_f: Bound<PyAny>,
py_f: Bound<PyAny>,
infer_schema_length: Option<usize>,
chunk_size: usize,
has_header: bool,
Expand Down Expand Up @@ -92,7 +92,6 @@ impl PyDataFrame {
.collect::<Vec<_>>()
});

py_f = read_if_bytesio(py_f);
let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?;
let df = py.allow_threads(move || {
CsvReadOptions::default()
Expand Down Expand Up @@ -193,14 +192,12 @@ impl PyDataFrame {
#[pyo3(signature = (py_f, infer_schema_length, schema, schema_overrides))]
pub fn read_json(
py: Python,
mut py_f: Bound<PyAny>,
py_f: Bound<PyAny>,
infer_schema_length: Option<usize>,
schema: Option<Wrap<Schema>>,
schema_overrides: Option<Wrap<Schema>>,
) -> PyResult<Self> {
assert!(infer_schema_length != Some(0));
use crate::file::read_if_bytesio;
py_f = read_if_bytesio(py_f);
let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?;

py.allow_threads(move || {
Expand All @@ -226,12 +223,11 @@ impl PyDataFrame {
#[pyo3(signature = (py_f, ignore_errors, schema, schema_overrides))]
pub fn read_ndjson(
py: Python,
mut py_f: Bound<PyAny>,
py_f: Bound<PyAny>,
ignore_errors: bool,
schema: Option<Wrap<Schema>>,
schema_overrides: Option<Wrap<Schema>>,
) -> PyResult<Self> {
py_f = read_if_bytesio(py_f);
let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?;

let mut builder = JsonReader::new(mmap_bytes_r)
Expand All @@ -257,7 +253,7 @@ impl PyDataFrame {
#[pyo3(signature = (py_f, columns, projection, n_rows, row_index, memory_map))]
pub fn read_ipc(
py: Python,
mut py_f: Bound<PyAny>,
py_f: Bound<PyAny>,
columns: Option<Vec<String>>,
projection: Option<Vec<usize>>,
n_rows: Option<usize>,
Expand All @@ -268,7 +264,6 @@ impl PyDataFrame {
name: name.into(),
offset,
});
py_f = read_if_bytesio(py_f);
let (mmap_bytes_r, mmap_path) = get_mmap_bytes_reader_and_path(&py_f)?;

let mmap_path = if memory_map { mmap_path } else { None };
Expand All @@ -290,7 +285,7 @@ impl PyDataFrame {
#[pyo3(signature = (py_f, columns, projection, n_rows, row_index, rechunk))]
pub fn read_ipc_stream(
py: Python,
mut py_f: Bound<PyAny>,
py_f: Bound<PyAny>,
columns: Option<Vec<String>>,
projection: Option<Vec<usize>>,
n_rows: Option<usize>,
Expand All @@ -301,7 +296,6 @@ impl PyDataFrame {
name: name.into(),
offset,
});
py_f = read_if_bytesio(py_f);
let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?;
let df = py.allow_threads(move || {
IpcStreamReader::new(mmap_bytes_r)
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-python/src/dataframe/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ impl PyDataFrame {
/// Deserialize a file-like object containing JSON string data into a DataFrame.
#[staticmethod]
#[cfg(feature = "json")]
pub fn deserialize_json(py: Python, mut py_f: Bound<PyAny>) -> PyResult<Self> {
use crate::file::read_if_bytesio;
py_f = read_if_bytesio(py_f);
pub fn deserialize_json(py: Python, py_f: Bound<PyAny>) -> PyResult<Self> {
let mut mmap_bytes_r = get_mmap_bytes_reader(&py_f)?;

py.allow_threads(move || {
Expand Down
28 changes: 18 additions & 10 deletions crates/polars-python/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,12 @@ pub fn get_file_like(f: PyObject, truncate: bool) -> PyResult<Box<dyn FileLike>>
}

/// If the give file-like is a BytesIO, read its contents.
pub fn read_if_bytesio(py_f: Bound<PyAny>) -> Bound<PyAny> {
fn read_if_bytesio(py_f: Bound<PyAny>) -> Bound<PyAny> {
if py_f.getattr("read").is_ok() {
let Ok(bytes) = py_f.call_method0("getvalue") else {
return py_f;
};
if bytes.downcast::<PyBytes>().is_ok() {
if bytes.downcast::<PyBytes>().is_ok() || bytes.downcast::<PyString>().is_ok() {
return bytes.clone();
}
}
Expand All @@ -386,24 +386,32 @@ pub fn read_if_bytesio(py_f: Bound<PyAny>) -> Bound<PyAny> {

/// Create reader from PyBytes or a file-like object. To get BytesIO to have
/// better performance, use read_if_bytesio() before calling this.
pub fn get_mmap_bytes_reader<'a>(
py_f: &'a Bound<'a, PyAny>,
) -> PyResult<Box<dyn MmapBytesReader + 'a>> {
pub fn get_mmap_bytes_reader(py_f: &Bound<PyAny>) -> PyResult<Box<dyn MmapBytesReader>> {
get_mmap_bytes_reader_and_path(py_f).map(|t| t.0)
}

pub fn get_mmap_bytes_reader_and_path<'a>(
py_f: &'a Bound<'a, PyAny>,
) -> PyResult<(Box<dyn MmapBytesReader + 'a>, Option<PathBuf>)> {
pub fn get_mmap_bytes_reader_and_path(
py_f: &Bound<PyAny>,
) -> PyResult<(Box<dyn MmapBytesReader>, Option<PathBuf>)> {
let py_f = read_if_bytesio(py_f.clone());

// bytes object
if let Ok(bytes) = py_f.downcast::<PyBytes>() {
Ok((Box::new(Cursor::new(bytes.as_bytes())), None))
Ok((
Box::new(Cursor::new(MemSlice::from_arc(
bytes.as_bytes(),
Arc::new(py_f.to_object(py_f.py())),
))),
None,
))
}
// string so read file
else {
match get_either_buffer_or_path(py_f.to_object(py_f.py()), false)? {
(EitherRustPythonFile::Rust(f), path) => Ok((Box::new(f), path)),
(EitherRustPythonFile::Py(f), path) => Ok((Box::new(f), path)),
(EitherRustPythonFile::Py(f), path) => {
Ok((Box::new(Cursor::new(f.to_memslice())), path))
},
}
}
}
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/io/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,20 @@ def test_scan_in_memory(method: str) -> None:
assert_frame_equal(df.vstack(df).slice(-1, 1), result)


def test_scan_pyobject_zero_copy_buffer_mutate() -> None:
f = io.BytesIO()

df = pl.DataFrame({"x": [1, 2, 3, 4, 5]})
df.write_ipc(f)
f.seek(0)

q = pl.scan_ipc(f)
assert_frame_equal(q.collect(), df)

f.write(b"AAA")
assert_frame_equal(q.collect(), df)


@pytest.mark.parametrize(
"method",
["csv", "ndjson"],
Expand Down

0 comments on commit c2ef569

Please sign in to comment.