Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add allow_primitive_to_str and allow_decimal_to_str instead of allow_number_to_str #6811

Merged
merged 4 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None):
# We use cast_array_to_feature to support casting to custom types like Audio and Image
# Also, when trying type "string", we don't want to convert integers or floats to "string".
# We only do it if trying_type is False - since this is what the user asks for.
out = cast_array_to_feature(out, type, allow_number_to_str=not self.trying_type)
out = cast_array_to_feature(
out, type, allow_primitive_to_str=not self.trying_type, allow_decimal_to_str=not self.trying_type
)
return out
except (
TypeError,
Expand Down Expand Up @@ -241,7 +243,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None):
cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False)
)
if type is not None:
out = cast_array_to_feature(out, type, allow_number_to_str=True)
out = cast_array_to_feature(out, type, allow_primitive_to_str=True)
Modexus marked this conversation as resolved.
Show resolved Hide resolved
return out
else:
raise
Expand All @@ -256,7 +258,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None):
elif trying_cast_to_python_objects and "Could not convert" in str(e):
out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False))
if type is not None:
out = cast_array_to_feature(out, type, allow_number_to_str=True)
out = cast_array_to_feature(out, type, allow_primitive_to_str=True)
Modexus marked this conversation as resolved.
Show resolved Hide resolved
return out
else:
raise
Expand Down
78 changes: 58 additions & 20 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,33 +1838,40 @@ def _storage_type(type: pa.DataType) -> pa.DataType:


@_wrap_for_chunked_arrays
def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True):
def array_cast(
array: pa.Array, pa_type: pa.DataType, allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True
) -> Union[pa.Array, pa.FixedSizeListArray, pa.ListArray, pa.StructArray, pa.ExtensionArray]:
"""Improved version of `pa.Array.cast`

It supports casting `pa.StructArray` objects to re-order the fields.
It also lets you control certain aspects of the casting, e.g. whether
to disable numbers (`floats` or `ints`) to strings.
to disable casting primitives (`booleans`, `floats` or `ints`) or
disable casting decimals to strings.

Args:
array (`pa.Array`):
PyArrow array to cast
pa_type (`pa.DataType`):
Target PyArrow type
allow_number_to_str (`bool`, defaults to `True`):
Whether to allow casting numbers to strings.
allow_primitive_to_str (`bool`, defaults to `True`):
Whether to allow casting primitives to strings.
Defaults to `True`.
allow_decimal_to_str (`bool`, defaults to `True`):
Whether to allow casting decimals to strings.
Defaults to `True`.

Raises:
`pa.ArrowInvalidError`: if the arrow data casting fails
`TypeError`: if the target type is not supported, e.g.

- if a field is missing
- if casting from numbers to strings and `allow_number_to_str` is `False`
- if casting from primitives to strings and `allow_primitive_to_str` is `False`
- if casting from decimals to strings and `allow_decimal_to_str` is `False`

Returns:
`List[pyarrow.Array]`: the casted array
"""
_c = partial(array_cast, allow_number_to_str=allow_number_to_str)
_c = partial(array_cast, allow_primitive_to_str=allow_primitive_to_str, allow_decimal_to_str=allow_decimal_to_str)
if isinstance(array, pa.ExtensionArray):
array = array.storage
if isinstance(pa_type, pa.ExtensionType):
Expand Down Expand Up @@ -1933,13 +1940,14 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True):
array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size
return pa.ListArray.from_arrays(array_offsets, _c(array.values, pa_type.value_type), mask=array.is_null())
else:
if (
not allow_number_to_str
and pa.types.is_string(pa_type)
and (pa.types.is_floating(array.type) or pa.types.is_integer(array.type))
if pa.types.is_string(pa_type) and (
(not allow_primitive_to_str and pa.types.is_primitive(array.type))
or (not allow_decimal_to_str and pa.types.is_decimal(array.type))
):
raise TypeError(
f"Couldn't cast array of type {array.type} to {pa_type} since allow_number_to_str is set to {allow_number_to_str}"
f"Couldn't cast array of type {array.type} to {pa_type} "
f"since allow_primitive_to_str is set to {allow_primitive_to_str} "
f"and allow_decimal_to_str is set to {allow_decimal_to_str}."
Modexus marked this conversation as resolved.
Show resolved Hide resolved
)
if pa.types.is_null(pa_type) and not pa.types.is_null(array.type):
raise TypeError(f"Couldn't cast array of type {array.type} to {pa_type}")
Expand All @@ -1948,7 +1956,9 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True):


@_wrap_for_chunked_arrays
def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_to_str=True):
def cast_array_to_feature(
array: pa.Array, feature: "FeatureType", allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True
) -> pa.Array:
"""Cast an array to the arrow type that corresponds to the requested feature type.
For custom features like [`Audio`] or [`Image`], it takes into account the "cast_storage" methods
they defined to enable casting from other arrow types.
Expand All @@ -1958,23 +1968,31 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
The PyArrow array to cast.
feature (`datasets.features.FeatureType`):
The target feature type.
allow_number_to_str (`bool`, defaults to `True`):
Whether to allow casting numbers to strings.
allow_primitive_to_str (`bool`, defaults to `True`):
Whether to allow casting primitives to strings.
Defaults to `True`.
allow_decimal_to_str (`bool`, defaults to `True`):
Whether to allow casting decimals to strings.
Defaults to `True`.

Raises:
`pa.ArrowInvalidError`: if the arrow data casting fails
`TypeError`: if the target type is not supported, e.g.

- if a field is missing
- if casting from numbers to strings and `allow_number_to_str` is `False`
- if casting from primitives to strings and `allow_primitive_to_str` is `False`
- if casting from decimals to strings and `allow_decimal_to_str` is `False`

Returns:
array (`pyarrow.Array`): the casted array
"""
from .features.features import Sequence, get_nested_type

_c = partial(cast_array_to_feature, allow_number_to_str=allow_number_to_str)
_c = partial(
cast_array_to_feature,
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)

if isinstance(array, pa.ExtensionArray):
array = array.storage
Expand Down Expand Up @@ -2011,9 +2029,19 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
storage_type = _storage_type(array_type)
if array_type != storage_type:
# Temporarily convert to the storage type to support extension types in the slice operation
array = array_cast(array, storage_type, allow_number_to_str=allow_number_to_str)
array = array_cast(
array,
storage_type,
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)
array = pc.list_slice(array, 0, feature.length, return_fixed_size_list=True)
array = array_cast(array, array_type, allow_number_to_str=allow_number_to_str)
array = array_cast(
array,
array_type,
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)
else:
array = pc.list_slice(array, 0, feature.length, return_fixed_size_list=True)
array_values = array.values
Expand Down Expand Up @@ -2069,9 +2097,19 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size
return pa.ListArray.from_arrays(array_offsets, _c(array.values, feature.feature), mask=array.is_null())
if pa.types.is_null(array.type):
return array_cast(array, get_nested_type(feature), allow_number_to_str=allow_number_to_str)
return array_cast(
array,
get_nested_type(feature),
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)
elif not isinstance(feature, (Sequence, dict, list, tuple)):
return array_cast(array, feature(), allow_number_to_str=allow_number_to_str)
return array_cast(
array,
feature(),
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)
raise TypeError(f"Couldn't cast array of type\n{array.type}\nto\n{feature}")


Expand Down
38 changes: 36 additions & 2 deletions tests/test_table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import pickle
from decimal import Decimal
from typing import List, Union

import numpy as np
Expand Down Expand Up @@ -1098,11 +1099,44 @@ def test_indexed_table_mixin():
assert table.fast_slice(2, 13) == pa_table.slice(2, 13)


def test_cast_array_to_features():
def test_cast_integer_array_to_features():
arr = pa.array([[0, 1]])
assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
assert cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False).type == pa.list_(
pa.string()
)
with pytest.raises(TypeError):
cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False)


def test_cast_float_array_to_features():
arr = pa.array([[0.0, 1.0]])
assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
assert cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False).type == pa.list_(
pa.string()
)
with pytest.raises(TypeError):
cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False)


def test_cast_boolean_array_to_features():
arr = pa.array([[False, True]])
assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
assert cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False).type == pa.list_(
pa.string()
)
with pytest.raises(TypeError):
cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False)


def test_cast_decimal_array_to_features():
arr = pa.array([[Decimal(0), Decimal(1)]])
assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
assert cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False).type == pa.list_(
pa.string()
)
with pytest.raises(TypeError):
cast_array_to_feature(arr, Sequence(Value("string")), allow_number_to_str=False)
cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False)


def test_cast_array_to_features_nested():
Expand Down
Loading