Skip to content

Commit

Permalink
feat: Allow reshaping empty array to Array dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
hexane360 committed Sep 26, 2024
1 parent 71a8b05 commit 47d670c
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 29 deletions.
44 changes: 23 additions & 21 deletions crates/polars-core/src/series/ops/reshape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,43 +95,42 @@ impl Series {
let size = leaf_array.len();

let mut total_dim_size = 1;
let mut num_infers = 0;
let mut infer_dim_index = None;
for (index, &dim) in dimensions.iter().enumerate() {
match dim {
ReshapeDimension::Infer => {
polars_ensure!(
num_infers == 0,
infer_dim_index.replace(index).is_none(),
InvalidOperation: "can only specify one inferred dimension"
);
num_infers += 1;
},
ReshapeDimension::Specified(dim) => {
let dim = dim.get();

if dim > 0 {
total_dim_size *= dim as usize
} else {
polars_ensure!(
index == 0,
InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}",
format_tuple!(dimensions)
);
total_dim_size = 0;
// We can early exit here, as empty arrays will error with multiple dimensions,
// and non-empty arrays will error when the first dimension is zero.
break;
}
polars_ensure!(
dim != 0 || index == 0,
InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}",
format_tuple!(dimensions)
);

total_dim_size *= dim as usize;
},
}
}

if size == 0 {
if dimensions.len() > 1 || (num_infers == 0 && total_dim_size != 0) {
// we need to infer a zero but can't
if total_dim_size != 0 && infer_dim_index != Some(0)
// we need to infer a non-zero but don't have enough information
|| total_dim_size == 0 && infer_dim_index.is_some()
{
polars_bail!(InvalidOperation: "cannot reshape empty array into shape {}", format_tuple!(dimensions))
}
} else if total_dim_size == 0 {
polars_bail!(InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dimensions))
} else {
polars_ensure!(
total_dim_size > 0,
InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dimensions)
);
polars_ensure!(
size % total_dim_size == 0,
InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)
Expand All @@ -146,8 +145,11 @@ impl Series {
for idx in (1..dimensions.len()).rev() {
// Infer dimension if needed
let dim = dimensions[idx].get_or_infer_with(|| {
debug_assert!(num_infers > 0);
(size / total_dim_size) as u64
if total_dim_size == 0 {
0
} else {
(size / total_dim_size) as u64
}
});
prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);

Expand Down
4 changes: 3 additions & 1 deletion py-polars/tests/unit/constructors/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def test_series_init_np_2d_zero_zero_shape() -> None:
arr = np.array([]).reshape(0, 0)
with pytest.raises(
InvalidOperationError,
match=re.escape("cannot reshape empty array into shape (0, 0)"),
match=re.escape(
"cannot reshape array into shape containing a zero dimension after the first: (0, 0)"
),
):
pl.Series(arr)

Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/interop/numpy/test_from_numpy_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,12 @@ def test_from_numpy_timedelta(time_unit: TimeUnit) -> None:
assert s.name == "name"
assert s.dt[0] == timedelta(days=1)
assert s.dt[1] == timedelta(seconds=1)


def test_from_zero_length_array() -> None:
a = np.zeros(dtype=np.int32, shape=(0, 4))
s = pl.Series("name", a)

assert s.dtype == pl.Array(pl.Int32, 4)
assert s.name == "name"
assert s.len() == 0
28 changes: 21 additions & 7 deletions py-polars/tests/unit/operations/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,25 @@ def test_reshape_invalid_dimension_size(shape: tuple[int, ...]) -> None:
s.reshape(shape)


def test_reshape_invalid_zero_dimension() -> None:
@pytest.mark.parametrize("shape", [(0, 5), (0, 1, -1)])
def test_reshape_invalid_into_empty(shape: tuple[int, ...]) -> None:
s = pl.Series("a", [1, 2, 3, 4])
shape = (-1, 0)
with pytest.raises(
InvalidOperationError,
match=re.escape(
f"cannot reshape array into shape containing a zero dimension after the first: {display_shape(shape)}"
f"cannot reshape non-empty array into shape containing a zero dimension: {display_shape(shape)}"
),
):
s.reshape(shape)


@pytest.mark.parametrize("shape", [(0, -1), (0, 4), (0, 0)])
def test_reshape_invalid_zero_dimension2(shape: tuple[int, ...]) -> None:
@pytest.mark.parametrize("shape", [(0, 0), (-1, 0)])
def test_reshape_invalid_zero_dimension(shape: tuple[int, ...]) -> None:
s = pl.Series("a", [1, 2, 3, 4])
with pytest.raises(
InvalidOperationError,
match=re.escape(
f"cannot reshape non-empty array into shape containing a zero dimension: {display_shape(shape)}"
f"cannot reshape array into shape containing a zero dimension after the first: {display_shape(shape)}"
),
):
s.reshape(shape)
Expand All @@ -100,7 +100,7 @@ def test_reshape_empty_valid_1d(shape: tuple[int, ...]) -> None:
assert_series_equal(out, s)


@pytest.mark.parametrize("shape", [(0, 1), (1, -1), (-1, 1)])
@pytest.mark.parametrize("shape", [(1, -1), (0, -2), (0, 3, -1)])
def test_reshape_empty_invalid_2d(shape: tuple[int, ...]) -> None:
s = pl.Series("a", [], dtype=pl.Int64)
with pytest.raises(
Expand All @@ -112,6 +112,20 @@ def test_reshape_empty_invalid_2d(shape: tuple[int, ...]) -> None:
s.reshape(shape)


@pytest.mark.parametrize(
("shape", "out_dtype"),
[
((0, 5), pl.Array(pl.Int64, 5)),
((-1, 3), pl.Array(pl.Int64, 3)),
((-1, 3, 2), pl.Array(pl.Int64, (3, 2))),
],
)
def test_reshape_empty_valid_2d(shape: tuple[int, ...], out_dtype: pl.DataType) -> None:
s = pl.Series("a", [], dtype=pl.Int64)
out = s.reshape(shape)
assert out.dtype == out_dtype


@pytest.mark.parametrize("shape", [(1,), (2,)])
def test_reshape_empty_invalid_1d(shape: tuple[int, ...]) -> None:
s = pl.Series("a", [], dtype=pl.Int64)
Expand Down

0 comments on commit 47d670c

Please sign in to comment.