Skip to content

Commit

Permalink
Uniformly use TFXIO's GetBatchElementsKwargs.
Browse files Browse the repository at this point in the history
tfx_bsl: Also introduced a cap for the batch size. This is to make sure Beam's auto batching does not produce a batch that's too large that ListArray/BinaryArray cannot fit. This cap can be removed when LargeList/LargeBinary are produced.

tfdv: also removed TFExampleDecoder from the public API.
PiperOrigin-RevId: 294956495
  • Loading branch information
brills authored and tfx-copybara committed Feb 13, 2020
1 parent 9e0273e commit e7dceaf
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 162 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
## Breaking Changes

## Deprecations
TFExampleDecoder

# Release 0.21.1

Expand Down
12 changes: 8 additions & 4 deletions g3doc/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,16 @@ those examples.
TFDV also provides the `validate_instance` function for identifying whether an
individual example exhibits anomalies when matched against a schema. To use this
function, the example must be a dict mapping feature names to numpy arrays of
feature values. You can use the `TFExampleDecoder` to decode serialized
feature values. You can use the the decoder in `tfx_bsl` to decode serialized
`tf.train.Example`s into this format. For example:

```python
decoder = tfdv.TFExampleDecoder()
example = decoder.decode(serialized_tfexample)
import tensorflow_data_validation as tfdv
import tfx_bsl
import pyarrow as pa
decoder = tfx_bsl.coders.example_coder.ExamplesToRecordBatchDecoder()
example = pa.Table.from_batches(
[decoder.DecodeBatch([serialized_tfexample])])
options = tfdv.StatsOptions(schema=schema)
anomalies = tfdv.validate_instance(example, options)
```
Expand Down Expand Up @@ -408,7 +412,7 @@ Once you have implemented the custom data connector that batches your
input examples in an Arrow table, you need to connect it with the
`tfdv.GenerateStatistics` API for computing the data statistics. Take `TFRecord`
of `tf.train.Example`'s for example. We provide the
[TFExampleDecoder](https://github.com/tensorflow/data-validation/tree/master/tensorflow_data_validation/coders/tf_example_decoder.py)
[DecodeTFExample](https://github.com/tensorflow/data-validation/tree/master/tensorflow_data_validation/coders/tf_example_decoder.py)
data connector, and below is an example of how to connect it with the
`tfdv.GenerateStatistics` API.

Expand Down
1 change: 0 additions & 1 deletion tensorflow_data_validation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
# Import coders.
from tensorflow_data_validation.coders.csv_decoder import DecodeCSV
from tensorflow_data_validation.coders.tf_example_decoder import DecodeTFExample
from tensorflow_data_validation.coders.tf_example_decoder import TFExampleDecoder

# Import stats generators.
from tensorflow_data_validation.statistics.generators.lift_stats_generator import LiftStatsGenerator
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_data_validation/api/stats_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class GenerateStatistics(beam.PTransform):
with beam.Pipeline(runner=...) as p:
_ = (p
| 'ReadData' >> beam.io.ReadFromTFRecord(data_location)
| 'DecodeData' >> beam.Map(TFExampleDecoder().decode)
| 'DecodeData' >> tfdv.DecodeTFExample()
| 'GenerateStatistics' >> GenerateStatistics()
| 'WriteStatsOutput' >> beam.io.WriteToTFRecord(
output_path, shard_name_template='',
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_data_validation/coders/csv_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import pyarrow as pa
from tensorflow_data_validation import constants
from tensorflow_data_validation import types
from tensorflow_data_validation.utils import batch_util
from tfx_bsl.coders import csv_decoder as csv_decoder
from tfx_bsl.tfxio import record_based_tfxio

This comment has been minimized.

Copy link
@Ark-kun

Ark-kun Feb 20, 2020

As of now, record_based_tfxio is not available in any released version of the tfx_bsl package.

from typing import List, Iterable, Optional, Text

from tensorflow_metadata.proto.v0 import schema_pb2
Expand Down Expand Up @@ -101,7 +101,8 @@ def expand(self, lines: beam.pvalue.PCollection):
# Do second pass to generate the in-memory dict representation.
return (csv_lines
| 'BatchCSVLines' >> beam.BatchElements(
**batch_util.GetBeamBatchKwargs(self._desired_batch_size))
**record_based_tfxio.GetBatchElementsKwargs(
self._desired_batch_size))
| 'BatchedCSVRowsToArrow' >> beam.ParDo(
_BatchedCSVRowsToArrow(skip_blank_lines=self._skip_blank_lines),
column_infos))
Expand Down
12 changes: 0 additions & 12 deletions tensorflow_data_validation/coders/tf_example_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,7 @@
import apache_beam as beam
import pyarrow as pa
from tensorflow_data_validation import constants
from tensorflow_data_validation import types
from tensorflow_data_validation.utils import batch_util
from tfx_bsl.coders import example_coder


# TODO(pachristopher): Deprecate this in 0.16.
class TFExampleDecoder(object):
"""A decoder for decoding TF examples into tf data validation datasets.
"""

def decode(self, serialized_example_proto: bytes) -> types.Example:
"""Decodes serialized tf.Example to tf data validation input dict."""
return example_coder.ExampleToNumpyDict(serialized_example_proto)


@beam.ptransform_fn
Expand Down
37 changes: 0 additions & 37 deletions tensorflow_data_validation/coders/tf_example_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
from __future__ import division
from __future__ import print_function

import sys
from absl.testing import absltest
from absl.testing import parameterized
import apache_beam as beam
from apache_beam.testing import util
import numpy as np
import tensorflow as tf
from tensorflow_data_validation.coders import tf_example_decoder
from tensorflow_data_validation.coders import tf_example_decoder_test_data
Expand All @@ -33,25 +31,6 @@
class TFExampleDecoderTest(parameterized.TestCase):
"""Tests for TFExampleDecoder."""

def _check_decoding_results(self, actual, expected):
# Check that the numpy array dtypes match.
self.assertEqual(len(actual), len(expected))
for key in actual:
if expected[key] is None:
self.assertEqual(actual[key], None)
else:
self.assertEqual(actual[key].dtype, expected[key].dtype)
np.testing.assert_equal(actual, expected)

@parameterized.named_parameters(
*tf_example_decoder_test_data.TF_EXAMPLE_DECODER_TESTS)
def test_decode_example(self, example_proto_text, decoded_example):
example = tf.train.Example()
text_format.Merge(example_proto_text, example)
decoder = tf_example_decoder.TFExampleDecoder()
self._check_decoding_results(
decoder.decode(example.SerializeToString()), decoded_example)

@parameterized.named_parameters(
*tf_example_decoder_test_data.BEAM_TF_EXAMPLE_DECODER_TESTS)
def test_decode_example_with_beam_pipeline(self, example_proto_text,
Expand All @@ -66,21 +45,5 @@ def test_decode_example_with_beam_pipeline(self, example_proto_text,
result,
test_util.make_arrow_tables_equal_fn(self, [decoded_table]))

def test_decode_example_none_ref_count(self):
example = text_format.Parse(
'''
features {
feature {
key: 'x'
value { }
}
}
''', tf.train.Example())
before_refcount = sys.getrefcount(None)
_ = tf_example_decoder.TFExampleDecoder().decode(
example.SerializeToString())
after_refcount = sys.getrefcount(None)
self.assertEqual(before_refcount + 1, after_refcount)

if __name__ == '__main__':
absltest.main()
94 changes: 0 additions & 94 deletions tensorflow_data_validation/coders/tf_example_decoder_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,103 +16,9 @@
from __future__ import division
from __future__ import print_function

import numpy as np
import pyarrow as pa


TF_EXAMPLE_DECODER_TESTS = [
{
'testcase_name': 'empty_input',
'example_proto_text': '''features {}''',
'decoded_example': {}
},
{
'testcase_name': 'int_feature_non_empty',
'example_proto_text': '''
features {
feature {
key: 'x'
value { int64_list { value: [ 1, 2, 3 ] } }
}
}
''',
'decoded_example': {'x': np.array([1, 2, 3], dtype=np.int64)}
},
{
'testcase_name': 'float_feature_non_empty',
'example_proto_text': '''
features {
feature {
key: 'x'
value { float_list { value: [ 4.0, 5.0 ] } }
}
}
''',
'decoded_example': {'x': np.array([4.0, 5.0], dtype=np.float32)}
},
{
'testcase_name': 'str_feature_non_empty',
'example_proto_text': '''
features {
feature {
key: 'x'
value { bytes_list { value: [ 'string', 'list' ] } }
}
}
''',
'decoded_example': {'x': np.array([b'string', b'list'],
dtype=np.object)}
},
{
'testcase_name': 'int_feature_empty',
'example_proto_text': '''
features {
feature {
key: 'x'
value { int64_list { } }
}
}
''',
'decoded_example': {'x': np.array([], dtype=np.int64)}
},
{
'testcase_name': 'float_feature_empty',
'example_proto_text': '''
features {
feature {
key: 'x'
value { float_list { } }
}
}
''',
'decoded_example': {'x': np.array([], dtype=np.float32)}
},
{
'testcase_name': 'str_feature_empty',
'example_proto_text': '''
features {
feature {
key: 'x'
value { bytes_list { } }
}
}
''',
'decoded_example': {'x': np.array([], dtype=np.object)}
},
{
'testcase_name': 'feature_missing',
'example_proto_text': '''
features {
feature {
key: 'x'
value { }
}
}
''',
'decoded_example': {'x': None}
},
]

BEAM_TF_EXAMPLE_DECODER_TESTS = [
{
'testcase_name': 'beam_test',
Expand Down
19 changes: 8 additions & 11 deletions tensorflow_data_validation/utils/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,12 @@
from tensorflow_data_validation import types
from tensorflow_data_validation.arrow import decoded_examples_to_arrow
from tfx_bsl.coders import example_coder
from typing import Dict, List, Iterable, Optional, Text
from tfx_bsl.tfxio import record_based_tfxio
from typing import List, Iterable, Optional


def GetBeamBatchKwargs(desired_batch_size: Optional[int]) -> Dict[Text, int]:
"""Returns the kwargs to be passed to beam.BatchElements."""
if desired_batch_size is None:
return {}
return {
"min_batch_size": desired_batch_size,
"max_batch_size": desired_batch_size,
}
# DEPRECATED. Use the TFXIO util instead.
GetBeamBatchKwargs = record_based_tfxio.GetBatchElementsKwargs


# TODO(pachristopher): Deprecate this.
Expand Down Expand Up @@ -65,7 +60,8 @@ def BatchExamplesToArrowTables(
return (
examples
| "BatchBeamExamples" >>
beam.BatchElements(**GetBeamBatchKwargs(desired_batch_size))
beam.BatchElements(**record_based_tfxio.GetBatchElementsKwargs(
desired_batch_size))
| "DecodeExamplesToTable" >>
# pylint: disable=unnecessary-lambda
beam.Map(lambda x: decoded_examples_to_arrow.DecodedExamplesToTable(x)))
Expand All @@ -92,7 +88,8 @@ def BatchSerializedExamplesToArrowTables(
"""
return (examples
| "BatchSerializedExamples" >>
beam.BatchElements(**GetBeamBatchKwargs(desired_batch_size))
beam.BatchElements(
**record_based_tfxio.GetBatchElementsKwargs(desired_batch_size))
| "BatchDecodeExamples" >> beam.ParDo(_BatchDecodeExamplesDoFn()))


Expand Down

0 comments on commit e7dceaf

Please sign in to comment.