From 2597a2624654b389f7aa468ab1929b1a8c528de6 Mon Sep 17 00:00:00 2001 From: Mike Boss Date: Mon, 15 Apr 2024 15:13:05 +0200 Subject: [PATCH 1/3] add allow_primitive_to_str and allow_decimal_to_str instead of allow_number_to_str --- src/datasets/arrow_writer.py | 8 ++-- src/datasets/table.py | 78 +++++++++++++++++++++++++++--------- tests/test_table.py | 38 +++++++++++++++++- 3 files changed, 99 insertions(+), 25 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 82e72a91ecc..f84134e78b8 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -205,7 +205,9 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None): # We use cast_array_to_feature to support casting to custom types like Audio and Image # Also, when trying type "string", we don't want to convert integers or floats to "string". # We only do it if trying_type is False - since this is what the user asks for. - out = cast_array_to_feature(out, type, allow_number_to_str=not self.trying_type) + out = cast_array_to_feature( + out, type, allow_primitive_to_str=not self.trying_type, allow_decimal_to_str=not self.trying_type + ) return out except ( TypeError, @@ -241,7 +243,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None): cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False) ) if type is not None: - out = cast_array_to_feature(out, type, allow_number_to_str=True) + out = cast_array_to_feature(out, type, allow_primitive_to_str=True) return out else: raise @@ -256,7 +258,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None): elif trying_cast_to_python_objects and "Could not convert" in str(e): out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False)) if type is not None: - out = cast_array_to_feature(out, type, allow_number_to_str=True) + out = cast_array_to_feature(out, type, allow_primitive_to_str=True) return out else: raise diff --git a/src/datasets/table.py b/src/datasets/table.py index 366f8bc185d..047925f1aca 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -1838,20 +1838,26 @@ def _storage_type(type: pa.DataType) -> pa.DataType: @_wrap_for_chunked_arrays -def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True): +def array_cast( + array: pa.Array, pa_type: pa.DataType, allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True +) -> Union[pa.Array, pa.FixedSizeListArray, pa.ListArray, pa.StructArray, pa.ExtensionArray]: """Improved version of `pa.Array.cast` It supports casting `pa.StructArray` objects to re-order the fields. It also let you control certain aspects of the casting, e.g. whether - to disable numbers (`floats` or `ints`) to strings. + to disable casting primitives (`booleans`, `floats` or `ints`) or + disable casting decimals to strings. Args: array (`pa.Array`): PyArrow array to cast pa_type (`pa.DataType`): Target PyArrow type - allow_number_to_str (`bool`, defaults to `True`): - Whether to allow casting numbers to strings. + allow_primitive_to_str (`bool`, defaults to `True`): + Whether to allow casting primitives to strings. + Defaults to `True`. + allow_decimal_to_str (`bool`, defaults to `True`): + Whether to allow casting decimals to strings. Defaults to `True`. Raises: @@ -1859,12 +1865,13 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True): `TypeError`: if the target type is not supported according, e.g. - if a field is missing - - if casting from numbers to strings and `allow_number_to_str` is `False` + - if casting from primitives to strings and `allow_primitive_to_str` is `False` + - if casting from decimals to strings and `allow_decimal_to_str` is `False` Returns: `List[pyarrow.Array]`: the casted array """ - _c = partial(array_cast, allow_number_to_str=allow_number_to_str) + _c = partial(array_cast, allow_primitive_to_str=allow_primitive_to_str, allow_decimal_to_str=allow_decimal_to_str) if isinstance(array, pa.ExtensionArray): array = array.storage if isinstance(pa_type, pa.ExtensionType): @@ -1933,13 +1940,14 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True): array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size return pa.ListArray.from_arrays(array_offsets, _c(array.values, pa_type.value_type), mask=array.is_null()) else: - if ( - not allow_number_to_str - and pa.types.is_string(pa_type) - and (pa.types.is_floating(array.type) or pa.types.is_integer(array.type)) + if pa.types.is_string(pa_type) and ( + (not allow_primitive_to_str and pa.types.is_primitive(array.type)) + or (not allow_decimal_to_str and pa.types.is_decimal(array.type)) ): raise TypeError( - f"Couldn't cast array of type {array.type} to {pa_type} since allow_number_to_str is set to {allow_number_to_str}" + f"Couldn't cast array of type {array.type} to {pa_type} " + f"since allow_primitive_to_str is set to {allow_primitive_to_str} " + f"and allow_decimal_to_str is set to {allow_decimal_to_str}." ) if pa.types.is_null(pa_type) and not pa.types.is_null(array.type): raise TypeError(f"Couldn't cast array of type {array.type} to {pa_type}") @@ -1948,7 +1956,9 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True): @_wrap_for_chunked_arrays -def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_to_str=True): +def cast_array_to_feature( + array: pa.Array, feature: "FeatureType", allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True +) -> pa.Array: """Cast an array to the arrow type that corresponds to the requested feature type. For custom features like [`Audio`] or [`Image`], it takes into account the "cast_storage" methods they defined to enable casting from other arrow types. @@ -1958,8 +1968,11 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_ The PyArrow array to cast. feature (`datasets.features.FeatureType`): The target feature type. - allow_number_to_str (`bool`, defaults to `True`): - Whether to allow casting numbers to strings. + allow_primitive_to_str (`bool`, defaults to `True`): + Whether to allow casting primitives to strings. + Defaults to `True`. + allow_decimal_to_str (`bool`, defaults to `True`): + Whether to allow casting decimals to strings. Defaults to `True`. Raises: @@ -1967,14 +1980,19 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_ `TypeError`: if the target type is not supported according, e.g. - if a field is missing - - if casting from numbers to strings and `allow_number_to_str` is `False` + - if casting from primitives and `allow_primitive_to_str` is `False` + - if casting from decimals and `allow_decimal_to_str` is `False` Returns: array (`pyarrow.Array`): the casted array """ from .features.features import Sequence, get_nested_type - _c = partial(cast_array_to_feature, allow_number_to_str=allow_number_to_str) + _c = partial( + cast_array_to_feature, + allow_primitive_to_str=allow_primitive_to_str, + allow_decimal_to_str=allow_decimal_to_str, + ) if isinstance(array, pa.ExtensionArray): array = array.storage @@ -2011,9 +2029,19 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_ storage_type = _storage_type(array_type) if array_type != storage_type: # Temporarily convert to the storage type to support extension types in the slice operation - array = array_cast(array, storage_type, allow_number_to_str=allow_number_to_str) + array = array_cast( + array, + storage_type, + allow_primitive_to_str=allow_primitive_to_str, + allow_decimal_to_str=allow_decimal_to_str, + ) array = pc.list_slice(array, 0, feature.length, return_fixed_size_list=True) - array = array_cast(array, array_type, allow_number_to_str=allow_number_to_str) + array = array_cast( + array, + array_type, + allow_primitive_to_str=allow_primitive_to_str, + allow_decimal_to_str=allow_decimal_to_str, + ) else: array = pc.list_slice(array, 0, feature.length, return_fixed_size_list=True) array_values = array.values @@ -2069,9 +2097,19 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_ array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size return pa.ListArray.from_arrays(array_offsets, _c(array.values, feature.feature), mask=array.is_null()) if pa.types.is_null(array.type): - return array_cast(array, get_nested_type(feature), allow_number_to_str=allow_number_to_str) + return array_cast( + array, + get_nested_type(feature), + allow_primitive_to_str=allow_primitive_to_str, + allow_decimal_to_str=allow_decimal_to_str, + ) elif not isinstance(feature, (Sequence, dict, list, tuple)): - return array_cast(array, feature(), allow_number_to_str=allow_number_to_str) + return array_cast( + array, + feature(), + allow_primitive_to_str=allow_primitive_to_str, + allow_decimal_to_str=allow_decimal_to_str, + ) raise TypeError(f"Couldn't cast array of type\n{array.type}\nto\n{feature}") diff --git a/tests/test_table.py b/tests/test_table.py index 9ac6b7ee7b9..b63d25bdd11 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1,5 +1,6 @@ import copy import pickle +from decimal import Decimal from typing import List, Union import numpy as np @@ -1098,11 +1099,44 @@ def test_indexed_table_mixin(): assert table.fast_slice(2, 13) == pa_table.slice(2, 13) -def test_cast_array_to_features(): +def test_cast_integer_array_to_features(): arr = pa.array([[0, 1]]) assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string()) + assert cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False).type == pa.list_( + pa.string() + ) + with pytest.raises(TypeError): + cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False) + + +def test_cast_float_array_to_features(): + arr = pa.array([[0.0, 1.0]]) + assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string()) + assert cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False).type == pa.list_( + pa.string() + ) + with pytest.raises(TypeError): + cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False) + + +def test_cast_boolean_array_to_features(): + arr = pa.array([[False, True]]) + assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string()) + assert cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False).type == pa.list_( + pa.string() + ) + with pytest.raises(TypeError): + cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False) + + +def test_cast_decimal_array_to_features(): + arr = pa.array([[Decimal(0), Decimal(1)]]) + assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string()) + assert cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False).type == pa.list_( + pa.string() + ) with pytest.raises(TypeError): - cast_array_to_feature(arr, Sequence(Value("string")), allow_number_to_str=False) + cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False) def test_cast_array_to_features_nested(): From b3ba93c418568db5951f59a9dc3042187b9b43e0 Mon Sep 17 00:00:00 2001 From: Mike Boss Date: Mon, 15 Apr 2024 22:35:17 +0200 Subject: [PATCH 2/3] add missing allow_decimal_str and split typerrors --- src/datasets/arrow_writer.py | 4 ++-- src/datasets/table.py | 20 +++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index f84134e78b8..174dc28df69 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -243,7 +243,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None): cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False) ) if type is not None: - out = cast_array_to_feature(out, type, allow_primitive_to_str=True) + out = cast_array_to_feature(out, type, allow_primitive_to_str=True, allow_decimal_to_str=True) return out else: raise @@ -258,7 +258,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None): elif trying_cast_to_python_objects and "Could not convert" in str(e): out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False)) if type is not None: - out = cast_array_to_feature(out, type, allow_primitive_to_str=True) + out = cast_array_to_feature(out, type, allow_primitive_to_str=True, allow_decimal_to_str=True) return out else: raise diff --git a/src/datasets/table.py b/src/datasets/table.py index 047925f1aca..8ba8d0f0304 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -1940,15 +1940,17 @@ def array_cast( array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size return pa.ListArray.from_arrays(array_offsets, _c(array.values, pa_type.value_type), mask=array.is_null()) else: - if pa.types.is_string(pa_type) and ( - (not allow_primitive_to_str and pa.types.is_primitive(array.type)) - or (not allow_decimal_to_str and pa.types.is_decimal(array.type)) - ): - raise TypeError( - f"Couldn't cast array of type {array.type} to {pa_type} " - f"since allow_primitive_to_str is set to {allow_primitive_to_str} " - f"and allow_decimal_to_str is set to {allow_decimal_to_str}." - ) + if pa.types.is_string(pa_type): + if not allow_primitive_to_str and pa.types.is_primitive(array.type): + raise TypeError( + f"Couldn't cast array of type {array.type} to {pa_type} " + f"since allow_primitive_to_str is set to {allow_primitive_to_str} " + ) + if not allow_decimal_to_str and pa.types.is_decimal(array.type): + raise TypeError( + f"Couldn't cast array of type {array.type} to {pa_type} " + f"and allow_decimal_to_str is set to {allow_decimal_to_str}" + ) if pa.types.is_null(pa_type) and not pa.types.is_null(array.type): raise TypeError(f"Couldn't cast array of type {array.type} to {pa_type}") return array.cast(pa_type) From 85a729f0e37bdc8eec48ed2afa79abde5fbbc33e Mon Sep 17 00:00:00 2001 From: mariosasko Date: Tue, 16 Apr 2024 17:56:45 +0200 Subject: [PATCH 3/3] Style --- src/datasets/arrow_writer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 174dc28df69..ea12d094e63 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -243,7 +243,9 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None): cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False) ) if type is not None: - out = cast_array_to_feature(out, type, allow_primitive_to_str=True, allow_decimal_to_str=True) + out = cast_array_to_feature( + out, type, allow_primitive_to_str=True, allow_decimal_to_str=True + ) return out else: raise