Skip to content

Commit

Permalink
fix construction
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Dec 25, 2023
1 parent ebf1ea6 commit 55774d2
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 13 deletions.
37 changes: 24 additions & 13 deletions py-polars/src/series/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,19 +258,30 @@ impl PySeries {
Ok(series.into())
} else {
let val = vec_extract_wrapped(val);
let series = Series::new(name, &val);
match series.dtype() {
DataType::List(list_inner) => {
let series = series
.cast(&DataType::Array(
Box::new(inner.map(|dt| dt.0).unwrap_or(*list_inner.clone())),
width,
))
.map_err(PyPolarsErr::from)?;
Ok(series.into())
},
_ => Err(PyValueError::new_err("could not create Array from input")),
}
return if let Some(inner) = inner {
let series = Series::from_any_values_and_dtype(
name,
val.as_ref(),
&DataType::Array(Box::new(inner.0), width),
true,
)
.map_err(PyPolarsErr::from)?;
Ok(series.into())
} else {
let series = Series::new(name, &val);
match series.dtype() {
DataType::List(list_inner) => {
let series = series
.cast(&DataType::Array(
Box::new(inner.map(|dt| dt.0).unwrap_or(*list_inner.clone())),
width,
))
.map_err(PyPolarsErr::from)?;
Ok(series.into())
},
_ => Err(PyValueError::new_err("could not create Array from input")),
}
};
}
}

Expand Down
46 changes: 46 additions & 0 deletions py-polars/tests/unit/io/test_json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
import io
import json
from collections import OrderedDict
Expand Down Expand Up @@ -299,6 +300,13 @@ def test_write_json_duration() -> None:
[[[1, 2, 3], [4, None]], None, [[None, None, 2]]],
pl.List(pl.Array(pl.Int32(), width=3)),
),
(
[
[datetime.datetime(1991, 1, 1), datetime.datetime(1991, 1, 1), None],
[None, None, None],
],
pl.Array(pl.Datetime, width=3),
),
],
)
def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None:
Expand All @@ -310,6 +318,44 @@ def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None:
assert_frame_equal(deserialized_df, df)


@pytest.mark.parametrize(
("data", "dtype"),
[
(
[
[
datetime.datetime(1997, 10, 1),
datetime.datetime(2000, 1, 2, 10, 30, 1),
],
[None, None],
],
pl.Array(pl.Datetime, width=2),
),
(
[[datetime.date(1997, 10, 1), datetime.date(2000, 1, 1)], [None, None]],
pl.Array(pl.Date, width=2),
),
(
[
[datetime.timedelta(seconds=1), datetime.timedelta(seconds=10)],
[None, None],
],
pl.Array(pl.Duration, width=2),
),
],
)
def test_write_read_json_array_logical_inner_type(
data: Any, dtype: pl.DataType
) -> None:
df = pl.DataFrame({"foo": data}, schema={"foo": dtype})
buf = io.StringIO()
df.write_json(buf)
buf.seek(0)
deserialized_df = pl.read_json(buf)
assert deserialized_df.dtypes == df.dtypes
assert deserialized_df.to_dict(as_series=False) == df.to_dict(as_series=False)


def test_json_null_infer() -> None:
json = BytesIO(
bytes(
Expand Down

0 comments on commit 55774d2

Please sign in to comment.