Skip to content

Commit

Permalink
feat(python): Handle textio even if not correct
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 15, 2024
1 parent 9a3e032 commit 967be51
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 19 deletions.
4 changes: 2 additions & 2 deletions crates/polars-io/src/json/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ where
/// Because JSON values specify their types (number, string, etc), no upcasting or conversion is performed between
/// incompatible types in the input. In the event that a column contains mixed dtypes, it is unspecified whether an
/// error is returned or whether elements of incompatible dtypes are replaced with `null`.
fn finish(self) -> PolarsResult<DataFrame> {
let rb: ReaderBytes = (&self.reader).into();
fn finish(mut self) -> PolarsResult<DataFrame> {
let rb: ReaderBytes = (&mut self.reader).into();

let out = match self.json_format {
JsonFormat::Json => {
Expand Down
26 changes: 20 additions & 6 deletions crates/polars-io/src/mmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::Mutex;

use memmap::Mmap;
use once_cell::sync::Lazy;
use polars_core::config::verbose;
use polars_error::{polars_bail, PolarsResult};
use polars_utils::create_file;

Expand Down Expand Up @@ -130,14 +131,27 @@ impl std::ops::Deref for ReaderBytes<'_> {
}
}

impl<'a, T: 'a + MmapBytesReader> From<&'a T> for ReaderBytes<'a> {
fn from(m: &'a T) -> Self {
impl<'a, T: 'a + MmapBytesReader> From<&'a mut T> for ReaderBytes<'a> {
fn from(m: &'a mut T) -> Self {
match m.to_bytes() {
Some(s) => ReaderBytes::Borrowed(s),
// SAFETY: the slice is borrowed from `m`, which is valid for 'a, but somehow the borrow checker doesn't see that the lifetime is 'a.
Some(s) => {
let s = unsafe { std::mem::transmute::<&[u8], &'a [u8]>(s) };
ReaderBytes::Borrowed(s)
},
None => {
let f = m.to_file().unwrap();
let mmap = unsafe { memmap::Mmap::map(f).unwrap() };
ReaderBytes::Mapped(mmap, f)
if let Some(f) = m.to_file() {
let f = unsafe { std::mem::transmute::<&File, &'a File>(f) };
let mmap = unsafe { memmap::Mmap::map(f).unwrap() };
ReaderBytes::Mapped(mmap, f)
} else {
if verbose() {
eprintln!("could not memory map file; read to buffer.")
}
let mut buf = vec![];
m.read_to_end(&mut buf).expect("could not read");
ReaderBytes::Owned(buf)
}
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/parquet/read/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ pub fn read_parquet<R: MmapBytesReader>(
parallel = ParallelStrategy::None;
}

let reader = ReaderBytes::from(&reader);
let reader = ReaderBytes::from(&mut reader);
let bytes = reader.deref();
let store = mmap::ColumnStore::Local(bytes);

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/cat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ pub struct CategoricalNameSpace(pub(crate) Expr);
impl CategoricalNameSpace {
pub fn get_categories(self) -> Expr {
self.0
.map_private(CategoricalFunction::GetCategories.into())
.apply_private(CategoricalFunction::GetCategories.into())
}
}
4 changes: 2 additions & 2 deletions py-polars/src/dataframe/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,10 @@ impl PyDataFrame {
pub fn deserialize(py: Python, mut py_f: Bound<PyAny>) -> PyResult<Self> {
use crate::file::read_if_bytesio;
py_f = read_if_bytesio(py_f);
let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?;
let mut mmap_bytes_r = get_mmap_bytes_reader(&py_f)?;

py.allow_threads(move || {
let mmap_read: ReaderBytes = (&mmap_bytes_r).into();
let mmap_read: ReaderBytes = (&mut mmap_bytes_r).into();
let bytes = mmap_read.deref();
match serde_json::from_slice::<DataFrame>(bytes) {
Ok(df) => Ok(df.into()),
Expand Down
23 changes: 16 additions & 7 deletions py-polars/src/file.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::fs::File;
use std::io;
use std::io::{BufReader, Cursor, Read, Seek, SeekFrom, Write};
use std::io::{BufReader, Cursor, ErrorKind, Read, Seek, SeekFrom, Write};
use std::path::PathBuf;

use polars::io::mmap::MmapBytesReader;
use polars_error::polars_warn;
use polars_error::{polars_err, polars_warn};
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyString};
Expand Down Expand Up @@ -104,13 +104,22 @@ impl Read for PyFileLikeObject {
.call_method_bound(py, "read", (buf.len(),), None)
.map_err(pyerr_to_io_err)?;

let bytes: &Bound<'_, PyBytes> = bytes
.downcast_bound(py)
.expect("Expecting to be able to downcast into bytes from read result.");
let opt_bytes = bytes.downcast_bound::<PyBytes>(py);

buf.write_all(bytes.as_bytes())?;
if let Ok(bytes) = opt_bytes {
buf.write_all(bytes.as_bytes())?;

bytes.len().map_err(pyerr_to_io_err)
bytes.len().map_err(pyerr_to_io_err)
} else if let Ok(s) = bytes.downcast_bound::<PyString>(py) {
let s = s.to_cow().map_err(pyerr_to_io_err)?;
buf.write_all(s.as_bytes())?;
Ok(s.len())
} else {
Err(io::Error::new(
ErrorKind::InvalidInput,
polars_err!(InvalidOperation: "could not read from input"),
))
}
})
}
}
Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/unit/io/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

import io
import json
import typing
from collections import OrderedDict
from io import BytesIO
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pathlib import Path

import pytest

Expand Down Expand Up @@ -287,3 +292,19 @@ def test_ndjson_null_inference_13183() -> None:
"start_time": [0.795, 1.6239999999999999, 2.184, None],
"end_time": [1.495, 2.0540000000000003, 2.645, None],
}


@pytest.mark.write_disk()
@typing.no_type_check
def test_json_wrong_input_handle_textio(tmp_path: Path) -> None:
# A text-mode file handle shouldn't be passed here, but we still test that it is handled gracefully.
df = pl.DataFrame(
{
"x": [1, 2, 3],
"y": ["a", "b", "c"],
}
)
file_path = tmp_path / "test.ndjson"
df.write_ndjson(file_path)
with open(file_path) as f: # noqa: PTH123
assert_frame_equal(pl.read_ndjson(f), df)

0 comments on commit 967be51

Please sign in to comment.