Skip to content

Commit

Permalink
Add serialize_example / deserialize_example to FeatureConnector
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 389120936
  • Loading branch information
Conchylicultor authored and copybara-github committed Aug 6, 2021
1 parent 8de1f50 commit afb6bd0
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 16 deletions.
54 changes: 53 additions & 1 deletion tensorflow_datasets/core/features/top_level_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

"""Wrapper around FeatureDict to allow better control over decoding."""

from typing import Union

import tensorflow as tf
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.features import feature as feature_lib


Expand All @@ -28,7 +32,7 @@ class TopLevelFeature(feature_lib.FeatureConnector):
eventually better support for augmentations.
"""

def decode_example(self, serialized_example, decoders=None):
def decode_example(self, serialized_example, *, decoders=None):
# pylint: disable=line-too-long
"""Decode the serialize examples.
Expand Down Expand Up @@ -69,6 +73,54 @@ def decode_example(self, serialized_example, decoders=None):
nested_decoded = self._nest(flatten_decoded)
return nested_decoded

def serialize_example(self, example_data) -> bytes:
  """Serializes one nested example into `tf.train.Example` bytes.

  The inverse operation is `deserialize_example`, which turns the proto back
  into `tf.Tensor`.

  Args:
    example_data: Example data to serialize (numpy-like nested dict).

  Returns:
    The serialized `tf.train.Example`.
  """
  # Run the feature encoding before handing the result to the serializer.
  encoded = self.encode_example(example_data)
  serializer = self._example_serializer
  return serializer.serialize_example(encoded)

def deserialize_example(
    self,
    serialized_example: Union[tf.Tensor, bytes],
    *,
    decoders=None,
) -> utils.TensorDict:
  """Parses a serialized `tf.train.Example` and decodes it into `tf.Tensor`.

  The inverse operation is `serialize_example`, which produces the proto
  bytes consumed here.

  Args:
    serialized_example: The tensor-like object containing the serialized
      `tf.train.Example` proto.
    decoders: Optional decoders to apply; forwarded unchanged to
      `decode_example` (see
      [documentation](https://www.tensorflow.org/datasets/decode)).

  Returns:
    The decoded feature tensors.
  """
  # Step 1: proto -> parsed example, using the parser built from this
  # feature's serialized specs.
  parser = self._example_parser
  parsed = parser.parse_example(serialized_example)
  # Step 2: parsed example -> decoded features.
  return self.decode_example(parsed, decoders=decoders)

@utils.memoized_property
def _example_parser(self):
  """Returns an `ExampleParser` built from this feature's serialized info."""
  # Imported inside the function (lint/pytype warnings are silenced on
  # purpose) -- presumably to avoid a circular import; TODO confirm.
  from tensorflow_datasets.core import example_parser  # pytype: disable=import-error  # pylint: disable=g-import-not-at-top
  serialized_specs = self.get_serialized_info()
  return example_parser.ExampleParser(serialized_specs)

@utils.memoized_property
def _example_serializer(self):
  """Returns an `ExampleSerializer` built from this feature's serialized info."""
  # Imported inside the function (lint/pytype warnings are silenced on
  # purpose) -- presumably to avoid a circular import; TODO confirm.
  from tensorflow_datasets.core import example_serializer  # pytype: disable=import-error  # pylint: disable=g-import-not-at-top
  serialized_specs = self.get_serialized_info()
  return example_serializer.ExampleSerializer(serialized_specs)


def _decode_feature(feature, example, serialized_info, decoder):
"""Decode a single feature."""
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_datasets/core/utils/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@

# Any of the supported tensor kinds: dense, sparse or ragged.
Tensor = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]

# Nested dict of tensors
TensorDict = TreeDict[Tensor]

# A single dimension: an `int` or `None`.
Dim = Optional[int]
# A shape as a sequence of `Dim` (tuple or list, going by the `TupleOrList`
# name -- TODO confirm against its definition).
Shape = TupleOrList[Dim]

Expand Down
19 changes: 4 additions & 15 deletions tensorflow_datasets/testing/feature_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@
"""Test case util to test `tfds.features.FeatureConnector`."""

import contextlib
import dataclasses
import functools
from typing import Any, Optional, Type

import dataclasses
import dill
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow_datasets.core import dataset_utils
from tensorflow_datasets.core import example_parser
from tensorflow_datasets.core import example_serializer
from tensorflow_datasets.core import features
from tensorflow_datasets.core import utils
from tensorflow_datasets.testing import test_case
Expand Down Expand Up @@ -345,23 +343,14 @@ def _test_repr(

def features_encode_decode(features_dict, example, decoders):
"""Runs the full pipeline: encode > write > tmp files > read > decode."""
# Encode example
encoded_example = features_dict.encode_example(example)

# Serialize/deserialize the example
specs = features_dict.get_serialized_info()
serializer = example_serializer.ExampleSerializer(specs)
parser = example_parser.ExampleParser(specs)
serialized_example = features_dict.serialize_example(example)

serialized_example = serializer.serialize_example(encoded_example)
ds = tf.data.Dataset.from_tensors(serialized_example)
ds = ds.map(parser.parse_example)

# Decode the example
decode_fn = functools.partial(
features_dict.decode_example,
features_dict.deserialize_example,
decoders=decoders,
)
ds = tf.data.Dataset.from_tensors(serialized_example)
ds = ds.map(decode_fn)

if tf.executing_eagerly():
Expand Down

0 comments on commit afb6bd0

Please sign in to comment.