Skip to content

Commit

Permalink
Fix sequence of array support for most dtype (#5948)
Browse files Browse the repository at this point in the history
* remove array.tolist() from ArrayXD encode

* fix insert

* Update tests
  • Loading branch information
qgallouedec authored Jun 14, 2023
1 parent f1911ff commit 650a86e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
5 changes: 2 additions & 3 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,6 @@ def __call__(self):
return pa_type

def encode_example(self, value):
if isinstance(value, np.ndarray):
value = value.tolist()
return value


Expand Down Expand Up @@ -1397,7 +1395,8 @@ def numpy_to_pyarrow_listarray(arr: np.ndarray, type: pa.DataType = None) -> pa.


def list_of_pa_arrays_to_pyarrow_listarray(l_arr: List[Optional[pa.Array]]) -> pa.ListArray:
null_indices = [i for i, arr in enumerate(l_arr) if arr is None]
null_mask = np.array([arr is None for arr in l_arr])
null_indices = np.arange(len(null_mask))[null_mask] - np.arange(np.sum(null_mask))
l_arr = [arr for arr in l_arr if arr is not None]
offsets = np.cumsum(
[0] + [len(arr) for arr in l_arr], dtype=object
Expand Down
39 changes: 28 additions & 11 deletions tests/features/test_array_xd.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,21 +397,38 @@ def test_array_xd_with_none():
assert np.isnan(arr[1]) and np.isnan(arr[3]) # a single np.nan value - np.all not needed


@pytest.mark.parametrize("seq_type", ["no_sequence", "sequence", "sequence_of_sequence"])
@pytest.mark.parametrize(
"data, feature, expected",
"dtype",
[
(np.zeros((2, 2)), None, [[0.0, 0.0], [0.0, 0.0]]),
(np.zeros((2, 3)), datasets.Array2D(shape=(2, 3), dtype="float32"), [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),
([np.zeros(2)], datasets.Array2D(shape=(1, 2), dtype="float32"), [[0.0, 0.0]]),
(
[np.zeros((2, 3))],
datasets.Array3D(shape=(1, 2, 3), dtype="float32"),
[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
),
"bool",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float16",
"float32",
"float64",
],
)
def test_array_xd_with_np(data, feature, expected):
ds = datasets.Dataset.from_dict({"col": [data]}, features=datasets.Features({"col": feature}) if feature else None)
@pytest.mark.parametrize("shape, feature_class", [((2, 3), datasets.Array2D), ((2, 3, 4), datasets.Array3D)])
def test_array_xd_with_np(seq_type, dtype, shape, feature_class):
feature = feature_class(dtype=dtype, shape=shape)
data = np.zeros(shape, dtype=dtype)
expected = data.tolist()
if seq_type == "sequence":
feature = datasets.Sequence(feature)
data = [data]
expected = [expected]
elif seq_type == "sequence_of_sequence":
feature = datasets.Sequence(datasets.Sequence(feature))
data = [[data]]
expected = [[expected]]
ds = datasets.Dataset.from_dict({"col": [data]}, features=datasets.Features({"col": feature}))
assert ds[0]["col"] == expected


Expand Down

0 comments on commit 650a86e

Please sign in to comment.