diff --git a/tensorflow_datasets/core/features/top_level_feature.py b/tensorflow_datasets/core/features/top_level_feature.py index 14ba79a0ff3..2a9bb762909 100644 --- a/tensorflow_datasets/core/features/top_level_feature.py +++ b/tensorflow_datasets/core/features/top_level_feature.py @@ -15,6 +15,10 @@ """Wrapper around FeatureDict to allow better control over decoding.""" +from typing import Union + +import tensorflow as tf +from tensorflow_datasets.core import utils from tensorflow_datasets.core.features import feature as feature_lib @@ -28,7 +32,7 @@ class TopLevelFeature(feature_lib.FeatureConnector): eventually better support for augmentations. """ - def decode_example(self, serialized_example, decoders=None): + def decode_example(self, serialized_example, *, decoders=None): # pylint: disable=line-too-long """Decode the serialize examples. @@ -69,6 +73,54 @@ def decode_example(self, serialized_example, decoders=None): nested_decoded = self._nest(flatten_decoded) return nested_decoded + def serialize_example(self, example_data) -> bytes: + """Encodes nested data values into `tf.train.Example` bytes. + + See `deserialize_example` to decode the proto into `tf.Tensor`. + + Args: + example_data: Example data to encode (numpy-like nested dict) + + Returns: + The serialized `tf.train.Example`. + """ + example_data = self.encode_example(example_data) + return self._example_serializer.serialize_example(example_data) + + def deserialize_example( + self, + serialized_example: Union[tf.Tensor, bytes], + *, + decoders=None, + ) -> utils.TensorDict: + """Decodes the `tf.train.Example` data into `tf.Tensor`. + + See `serialize_example` to encode the data into proto. + + Args: + serialized_example: The tensor-like object containing the serialized + `tf.train.Example` proto. + decoders: Eventual decoders to apply (see + [documentation](https://www.tensorflow.org/datasets/decode)) + + Returns: + The decoded features tensors. + """ + example_data = self._example_parser.parse_example(serialized_example) + return self.decode_example(example_data, decoders=decoders) + + @utils.memoized_property + def _example_parser(self): + from tensorflow_datasets.core import example_parser # pytype: disable=import-error # pylint: disable=g-import-not-at-top + example_specs = self.get_serialized_info() + return example_parser.ExampleParser(example_specs) + + @utils.memoized_property + def _example_serializer(self): + from tensorflow_datasets.core import example_serializer # pytype: disable=import-error # pylint: disable=g-import-not-at-top + example_specs = self.get_serialized_info() + return example_serializer.ExampleSerializer(example_specs) + def _decode_feature(feature, example, serialized_info, decoder): """Decode a single feature.""" diff --git a/tensorflow_datasets/core/utils/type_utils.py b/tensorflow_datasets/core/utils/type_utils.py index 309cd487c44..163063055e5 100644 --- a/tensorflow_datasets/core/utils/type_utils.py +++ b/tensorflow_datasets/core/utils/type_utils.py @@ -45,6 +45,9 @@ Tensor = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor] +# Nested dict of tensor +TensorDict = TreeDict[Tensor] + Dim = Optional[int] Shape = TupleOrList[Dim] diff --git a/tensorflow_datasets/testing/feature_test_case.py b/tensorflow_datasets/testing/feature_test_case.py index 4987e4d9dd3..c6abc62014a 100644 --- a/tensorflow_datasets/testing/feature_test_case.py +++ b/tensorflow_datasets/testing/feature_test_case.py @@ -16,17 +16,15 @@ """Test case util to test `tfds.features.FeatureConnector`.""" import contextlib +import dataclasses import functools from typing import Any, Optional, Type -import dataclasses import dill import numpy as np import tensorflow.compat.v2 as tf from tensorflow_datasets.core import dataset_utils -from tensorflow_datasets.core import example_parser -from tensorflow_datasets.core import example_serializer from tensorflow_datasets.core import features from tensorflow_datasets.core import utils from tensorflow_datasets.testing import test_case @@ -345,23 +343,14 @@ def _test_repr( def features_encode_decode(features_dict, example, decoders): """Runs the full pipeline: encode > write > tmp files > read > decode.""" - # Encode example - encoded_example = features_dict.encode_example(example) - # Serialize/deserialize the example - specs = features_dict.get_serialized_info() - serializer = example_serializer.ExampleSerializer(specs) - parser = example_parser.ExampleParser(specs) + serialized_example = features_dict.serialize_example(example) - serialized_example = serializer.serialize_example(encoded_example) - ds = tf.data.Dataset.from_tensors(serialized_example) - ds = ds.map(parser.parse_example) - - # Decode the example decode_fn = functools.partial( - features_dict.decode_example, + features_dict.deserialize_example, decoders=decoders, ) + ds = tf.data.Dataset.from_tensors(serialized_example) ds = ds.map(decode_fn) if tf.executing_eagerly():