From 88f646c418b408ace2494c02b9502f516a565e2b Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 26 Aug 2024 06:26:02 +0200 Subject: [PATCH] Rename LargeList.dtype to LargeList.feature (#7106) --- src/datasets/features/features.py | 38 +++++++++++++++---------------- src/datasets/table.py | 8 ++++--- tests/features/test_features.py | 8 +++---- tests/test_info.py | 2 +- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 4f0a75c4753..1d241e0b7b7 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1175,11 +1175,11 @@ class LargeList: It is backed by `pyarrow.LargeListType`, which is like `pyarrow.ListType` but with 64-bit rather than 32-bit offsets. Args: - dtype ([`FeatureType`]): + feature ([`FeatureType`]): Child feature data type of each item within the large list. """ - dtype: Any + feature: Any id: Optional[str] = None # Automatically constructed pa_type: ClassVar[Any] = None @@ -1218,8 +1218,6 @@ def _check_non_null_non_empty_recursive(obj, schema: Optional[FeatureType] = Non pass elif isinstance(schema, (list, tuple)): schema = schema[0] - elif isinstance(schema, LargeList): - schema = schema.dtype else: schema = schema.feature return _check_non_null_non_empty_recursive(obj[0], schema) @@ -1252,7 +1250,7 @@ def get_nested_type(schema: FeatureType) -> pa.DataType: value_type = get_nested_type(schema[0]) return pa.list_(value_type) elif isinstance(schema, LargeList): - value_type = get_nested_type(schema.dtype) + value_type = get_nested_type(schema.feature) return pa.large_list(value_type) elif isinstance(schema, Sequence): value_type = get_nested_type(schema.feature) @@ -1303,7 +1301,7 @@ def encode_nested_example(schema, obj, level=0): return None else: if len(obj) > 0: - sub_schema = schema.dtype + sub_schema = schema.feature for first_elmt in obj: if _check_non_null_non_empty_recursive(first_elmt, sub_schema): break @@ -1384,7 +1382,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni if obj is None: return None else: - sub_schema = schema.dtype + sub_schema = schema.feature if len(obj) > 0: for first_elmt in obj: if _check_non_null_non_empty_recursive(first_elmt, sub_schema): @@ -1463,8 +1461,8 @@ def generate_from_dict(obj: Any): raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}") if class_type == LargeList: - dtype = obj.pop("dtype") - return LargeList(generate_from_dict(dtype), **obj) + feature = obj.pop("feature") + return LargeList(feature=generate_from_dict(feature), **obj) if class_type == Sequence: feature = obj.pop("feature") return Sequence(feature=generate_from_dict(feature), **obj) @@ -1493,8 +1491,8 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType: return [feature] return Sequence(feature=feature) elif isinstance(pa_type, pa.LargeListType): - dtype = generate_from_arrow_type(pa_type.value_type) - return LargeList(dtype) + feature = generate_from_arrow_type(pa_type.value_type) + return LargeList(feature=feature) elif isinstance(pa_type, _ArrayXDExtensionType): array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims] return array_feature(shape=pa_type.shape, dtype=pa_type.value_type) @@ -1601,7 +1599,7 @@ def _visit(feature: FeatureType, func: Callable[[FeatureType], Optional[FeatureT elif isinstance(feature, (list, tuple)): out = func([_visit(feature[0], func)]) elif isinstance(feature, LargeList): - out = func(LargeList(_visit(feature.dtype, func))) + out = func(LargeList(_visit(feature.feature, func))) elif isinstance(feature, Sequence): out = func(Sequence(_visit(feature.feature, func), length=feature.length)) else: @@ -1624,7 +1622,7 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False elif isinstance(feature, (list, tuple)): return require_decoding(feature[0]) elif isinstance(feature, LargeList): - return require_decoding(feature.dtype) + return require_decoding(feature.feature) elif isinstance(feature, Sequence): return require_decoding(feature.feature) else: @@ -1644,7 +1642,7 @@ def require_storage_cast(feature: FeatureType) -> bool: elif isinstance(feature, (list, tuple)): return require_storage_cast(feature[0]) elif isinstance(feature, LargeList): - return require_storage_cast(feature.dtype) + return require_storage_cast(feature.feature) elif isinstance(feature, Sequence): return require_storage_cast(feature.feature) else: @@ -1664,7 +1662,7 @@ def require_storage_embed(feature: FeatureType) -> bool: elif isinstance(feature, (list, tuple)): return require_storage_cast(feature[0]) elif isinstance(feature, LargeList): - return require_storage_cast(feature.dtype) + return require_storage_cast(feature.feature) elif isinstance(feature, Sequence): return require_storage_cast(feature.feature) else: @@ -1876,8 +1874,8 @@ def to_yaml_inner(obj: Union[dict, list]) -> dict: if isinstance(obj, dict): _type = obj.pop("_type", None) if _type == "LargeList": - value_type = obj.pop("dtype") - return simplify({"large_list": to_yaml_inner(value_type), **obj}) + _feature = obj.pop("feature") + return simplify({"large_list": to_yaml_inner(_feature), **obj}) elif _type == "Sequence": _feature = obj.pop("feature") return simplify({"sequence": to_yaml_inner(_feature), **obj}) @@ -1947,8 +1945,8 @@ def from_yaml_inner(obj: Union[dict, list]) -> Union[dict, list]: return {} _type = next(iter(obj)) if _type == "large_list": - _dtype = unsimplify(obj).pop(_type) - return {"dtype": from_yaml_inner(_dtype), **obj, "_type": "LargeList"} + _feature = unsimplify(obj).pop(_type) + return {"feature": from_yaml_inner(_feature), **obj, "_type": "LargeList"} if _type == "sequence": _feature = unsimplify(obj).pop(_type) return {"feature": from_yaml_inner(_feature), **obj, "_type": "Sequence"} @@ -2180,7 +2178,7 @@ def recursive_reorder(source, target, stack=""): elif isinstance(source, LargeList): if not isinstance(target, LargeList): raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position) - return LargeList(recursive_reorder(source.dtype, target.dtype, stack)) + return LargeList(recursive_reorder(source.feature, target.feature, stack)) else: return source diff --git a/src/datasets/table.py b/src/datasets/table.py index e9e27544220..bfa63173ef9 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -2017,7 +2017,7 @@ def cast_array_to_feature( array_offsets = _combine_list_array_offsets_with_mask(array) return pa.ListArray.from_arrays(array_offsets, casted_array_values) elif isinstance(feature, LargeList): - casted_array_values = _c(array.values, feature.dtype) + casted_array_values = _c(array.values, feature.feature) if pa.types.is_large_list(array.type) and casted_array_values.type == array.values.type: # Both array and feature have equal large_list type and values (within the list) type return array @@ -2075,7 +2075,9 @@ def cast_array_to_feature( return pa.ListArray.from_arrays(array_offsets, _c(array.values, feature[0]), mask=array.is_null()) elif isinstance(feature, LargeList): array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size - return pa.LargeListArray.from_arrays(array_offsets, _c(array.values, feature.dtype), mask=array.is_null()) + return pa.LargeListArray.from_arrays( + array_offsets, _c(array.values, feature.feature), mask=array.is_null() + ) elif isinstance(feature, Sequence): if feature.length > -1: if feature.length == array.type.list_size: @@ -2155,7 +2157,7 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"): # feature must be LargeList(subfeature) # Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError array_offsets = _combine_list_array_offsets_with_mask(array) - return pa.LargeListArray.from_arrays(array_offsets, _e(array.values, feature.dtype)) + return pa.LargeListArray.from_arrays(array_offsets, _e(array.values, feature.feature)) elif pa.types.is_fixed_size_list(array.type): # feature must be Sequence(subfeature) if isinstance(feature, Sequence) and feature.length > -1: diff --git a/tests/features/test_features.py b/tests/features/test_features.py index 8ab4baced7a..6234d7ede62 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -726,7 +726,7 @@ def test_features_flatten_with_list_types(features_dict, expected_features_dict) {"col": [Value("int32")]}, ), ( - {"col": {"dtype": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}}, + {"col": {"feature": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}}, {"col": LargeList(Value("int32"))}, ), ( @@ -738,7 +738,7 @@ def test_features_flatten_with_list_types(features_dict, expected_features_dict) {"col": [{"sub_col": Value("int32")}]}, ), ( - {"col": {"dtype": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}}, + {"col": {"feature": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}}, {"col": LargeList({"sub_col": Value("int32")})}, ), ( @@ -760,7 +760,7 @@ def test_features_from_dict_with_list_types(deserialized_features_dict, expected [Value("int32")], ), ( - {"dtype": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}, + {"feature": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}, LargeList(Value("int32")), ), ( @@ -772,7 +772,7 @@ def test_features_from_dict_with_list_types(deserialized_features_dict, expected [{"sub_col": Value("int32")}], ), ( - {"dtype": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}, + {"feature": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}, LargeList({"sub_col": Value("int32")}), ), ( diff --git a/tests/test_info.py b/tests/test_info.py index 1657753b9ec..cefc39419b0 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -170,7 +170,7 @@ def test_dataset_info_from_dict_with_large_list(): dataset_info_dict = { "citation": "", "description": "", - "features": {"col_1": {"dtype": {"dtype": "int64", "_type": "Value"}, "_type": "LargeList"}}, + "features": {"col_1": {"feature": {"dtype": "int64", "_type": "Value"}, "_type": "LargeList"}}, "homepage": "", "license": "", }