From 7fb5fae8f79b3db4a94013aa2af7c63796ef2d64 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Wed, 11 Oct 2023 23:14:02 +0200 Subject: [PATCH] Fix ArrayXD cast --- src/datasets/table.py | 2 +- tests/test_table.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/datasets/table.py b/src/datasets/table.py index e021dea1092..e85a64227af 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -1964,7 +1964,7 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True): if isinstance(array, pa.ExtensionArray): array = array.storage if isinstance(pa_type, pa.ExtensionType): - return pa_type.wrap_array(array) + return pa_type.wrap_array(_c(array, pa_type.storage_type)) elif array.type == pa_type: return array elif pa.types.is_struct(array.type): diff --git a/tests/test_table.py b/tests/test_table.py index b20e509d1b8..ba6dfbfb5f2 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -9,7 +9,7 @@ import datasets from datasets import Sequence, Value -from datasets.features.features import Array2DExtensionType, ClassLabel, Features, Image +from datasets.features.features import Array2D, Array2DExtensionType, ClassLabel, Features, Image from datasets.table import ( ConcatenationTable, InMemoryTable, @@ -1165,6 +1165,16 @@ def test_cast_array_to_features_to_null_type(): cast_array_to_feature(arr, Sequence(Value("null"))) +def test_cast_array_to_features_array_xd(): + # same storage type + arr = pa.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], pa.list_(pa.list_(pa.int32(), 2), 2)) + casted_array = cast_array_to_feature(arr, Array2D(shape=(2, 2), dtype="int32")) + assert casted_array.type == Array2DExtensionType(shape=(2, 2), dtype="int32") + # different storage type + casted_array = cast_array_to_feature(arr, Array2D(shape=(2, 2), dtype="float32")) + assert casted_array.type == Array2DExtensionType(shape=(2, 2), dtype="float32") + + def test_cast_array_to_features_sequence_classlabel(): arr = pa.array([[], [1], [0, 1]], pa.list_(pa.int64())) assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())