Default to Sequence Examples #6003

Closed
2 changes: 1 addition & 1 deletion tfx/components/example_gen/component.py
@@ -159,7 +159,7 @@ def __init__(
data_types.RuntimeParameter]] = None,
range_config: Optional[Union[range_config_pb2.RangeConfig,
data_types.RuntimeParameter]] = None,
output_data_format: Optional[int] = example_gen_pb2.FORMAT_TF_EXAMPLE,
output_data_format: Optional[int] = example_gen_pb2.FORMAT_TF_SEQUENCE_EXAMPLE,
output_file_format: Optional[int] = example_gen_pb2.FORMAT_TFRECORDS_GZIP,
custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
"""Construct a FileBasedExampleGen component.
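For reference, a minimal sketch (outside this diff) of pinning the previous payload now that every FileBasedExampleGen-derived component defaults to tf.train.SequenceExample; the ./data path is hypothetical and the executor spec is left at its default.

from tfx.components.example_gen.component import FileBasedExampleGen
from tfx.proto import example_gen_pb2

# Opt back into tf.train.Example payloads now that the default is
# FORMAT_TF_SEQUENCE_EXAMPLE; a real pipeline would also pass a
# custom_executor_spec matching its file format.
example_gen = FileBasedExampleGen(
    input_base='./data',  # hypothetical path
    output_data_format=example_gen_pb2.FORMAT_TF_EXAMPLE)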
19 changes: 16 additions & 3 deletions tfx/components/example_gen/csv_example_gen/executor.py
@@ -49,7 +49,7 @@ def _bytes_handler(cell: csv_decoder.CSVCell) -> tf.train.Feature:

@beam.typehints.with_input_types(List[csv_decoder.CSVCell],
List[csv_decoder.ColumnInfo])
@beam.typehints.with_output_types(tf.train.Example)
@beam.typehints.with_output_types(tf.train.SequenceExample)
class _ParsedCsvToTfExample(beam.DoFn):
"""A beam.DoFn to convert a parsed CSV line to a tf.Example."""

@@ -89,8 +89,21 @@ def process(
self._column_handlers):
feature[column_name] = (
handler_fn(csv_cell) if handler_fn else tf.train.Feature())

yield tf.train.Example(features=tf.train.Features(feature=feature))

sequence_features = {
k: v for k, v in feature.items()
if isinstance(v, tf.train.FeatureList)
}

context_features = {
k:v for k,v in feature.items()
if k not in sequence_features
}

yield tf.train.SequenceExample(
context=tf.train.Features(feature=context_features),
feature_lists=tf.train.FeatureLists(feature_list=sequence_features)
)


class _CsvLineBuffer:
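For reference, a small standalone sketch (outside this diff) of how the split above behaves on hand-built features: values that are tf.train.FeatureList protos land in feature_lists, plain tf.train.Feature values become context; 'id' and 'tokens' are made-up names.

import tensorflow as tf

feature = {
    'id': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    'tokens': tf.train.FeatureList(feature=[
        tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'a'])),
        tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'b'])),
    ]),
}
sequence_features = {k: v for k, v in feature.items()
                     if isinstance(v, tf.train.FeatureList)}
context_features = {k: v for k, v in feature.items()
                    if k not in sequence_features}
example = tf.train.SequenceExample(
    context=tf.train.Features(feature=context_features),
    feature_lists=tf.train.FeatureLists(feature_list=sequence_features))
# example.context carries 'id'; example.feature_lists carries 'tokens'.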
@@ -27,7 +27,7 @@

@beam.ptransform_fn
@beam.typehints.with_input_types(beam.Pipeline)
@beam.typehints.with_output_types(tf.train.Example)
@beam.typehints.with_output_types(tf.train.SequenceExample)
def _AvroToExample( # pylint: disable=invalid-name
pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
split_pattern: str) -> beam.pvalue.PCollection:
@@ -26,7 +26,7 @@

@beam.ptransform_fn
@beam.typehints.with_input_types(beam.Pipeline)
@beam.typehints.with_output_types(tf.train.Example)
@beam.typehints.with_output_types(tf.train.SequenceExample)
def _ParquetToExample( # pylint: disable=invalid-name
pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
split_pattern: str) -> beam.pvalue.PCollection:
2 changes: 1 addition & 1 deletion tfx/components/example_gen/import_example_gen/component.py
@@ -48,7 +48,7 @@ def __init__(
data_types.RuntimeParameter]] = None,
range_config: Optional[Union[range_config_pb2.RangeConfig,
data_types.RuntimeParameter]] = None,
payload_format: Optional[int] = example_gen_pb2.FORMAT_TF_EXAMPLE):
payload_format: Optional[int] = example_gen_pb2.FORMAT_TF_SEQUENCE_EXAMPLE):
"""Construct an ImportExampleGen component.

Args:
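As above, the previous payload can still be requested explicitly. A minimal sketch, assuming TFRecord files of serialized tf.train.Example under a hypothetical ./tfrecords directory:

from tfx.components import ImportExampleGen
from tfx.proto import example_gen_pb2

# Declare that the imported records are tf.train.Example, overriding the
# new FORMAT_TF_SEQUENCE_EXAMPLE default.
example_gen = ImportExampleGen(
    input_base='./tfrecords',  # hypothetical path
    payload_format=example_gen_pb2.FORMAT_TF_EXAMPLE)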
19 changes: 16 additions & 3 deletions tfx/components/example_gen/utils.py
@@ -103,7 +103,7 @@ def pyval_to_feature(pyval: List[Any]) -> feature_pb2.Feature:
)


def dict_to_example(instance: Dict[str, Any]) -> example_pb2.Example:
def dict_to_example(instance: Dict[str, Any]) -> example_pb2.SequenceExample:
"""Converts dict to tf example."""
feature = {}
for key, value in instance.items():
@@ -128,8 +128,21 @@ def dict_to_example(instance: Dict[str, Any]) -> example_pb2.Example:
feature[key] = pyval_to_feature(pyval)
else:
raise RuntimeError(f'Value type {type(value[0])} is not supported.')

return example_pb2.Example(features=feature_pb2.Features(feature=feature))

sequence_features = {
k: v for k, v in feature.items()
if isinstance(v, feature_pb2.FeatureList)
}

context_features = {
k:v for k,v in feature.items()
if k not in sequence_features
}

return example_pb2.SequenceExample(
context=feature_pb2.Features(feature=context_features),
feature_lists=feature_pb2.FeatureLists(feature_list=sequence_features)
)


def generate_output_split_names(
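A quick illustration (outside this diff) of what the rewritten helper returns for an all-scalar row: each value becomes a plain tf.train.Feature, so feature_lists stays empty and everything lands in context. The column names here are made up.

from tfx.components.example_gen import utils

row = {'id': [42], 'text': ['hello']}
sequence_example = utils.dict_to_example(row)
print(sequence_example.context.feature['id'].int64_list.value)  # [42]
print(len(sequence_example.feature_lists.feature_list))         # 0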
10 changes: 10 additions & 0 deletions tfx/components/schema_gen/executor.py
@@ -26,6 +26,11 @@
from tfx.utils import io_utils
from tfx.utils import json_utils

from tfx_bsl.tfxio.tensor_representation_util import (
InferTensorRepresentationsFromSchema,
SetTensorRepresentationsInSchema
)


# Default file name for generated schema file.
DEFAULT_FILE_NAME = 'schema.pbtxt'
@@ -89,5 +94,10 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]],
artifact_utils.get_single_uri(
output_dict[standard_component_specs.SCHEMA_KEY]),
DEFAULT_FILE_NAME)

# Add tensor representations to handle SequenceExamples downstream.
tensor_representations = InferTensorRepresentationsFromSchema(schema)
SetTensorRepresentationsInSchema(schema, tensor_representations)

io_utils.write_pbtxt_file(output_uri, schema)
logging.info('Schema written to %s.', output_uri)
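A small sketch (outside this diff) of what the two tfx_bsl helpers do, assuming they keep the signatures used above; the schema is hand-built and the feature name is made up.

from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.tfxio.tensor_representation_util import (
    InferTensorRepresentationsFromSchema,
    SetTensorRepresentationsInSchema,
)

schema = schema_pb2.Schema()
feature = schema.feature.add()
feature.name = 'tokens'
feature.type = schema_pb2.BYTES

# Derive a default TensorRepresentation per feature, then write the mapping
# back into the schema so TFXIO readers can decode SequenceExamples later.
representations = InferTensorRepresentationsFromSchema(schema)
SetTensorRepresentationsInSchema(schema, representations)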
16 changes: 14 additions & 2 deletions tfx/extensions/google_cloud_big_query/utils.py
@@ -55,7 +55,7 @@ def ReadFromBigQuery( # pylint: disable=invalid-name

def row_to_example( # pylint: disable=invalid-name
field_to_type: Dict[str, str],
field_name_to_data: Dict[str, Any]) -> tf.train.Example:
field_name_to_data: Dict[str, Any]) -> tf.train.SequenceExample:
"""Convert bigquery result row to tf example.

Args:
@@ -96,7 +96,19 @@ def row_to_example( # pylint: disable=invalid-name
'BigQuery column "{}" has non-supported type {}.'.format(key,
data_type))

return tf.train.Example(features=tf.train.Features(feature=feature))
sequence_features = {
k: v for k, v in feature.items()
if isinstance(v, tf.train.FeatureList)
}
context_features = {
k: v for k, v in feature.items()
if k not in sequence_features
}

return tf.train.SequenceExample(
context=tf.train.Features(feature=context_features),
feature_lists=tf.train.FeatureLists(feature_list=sequence_features)
)


def parse_gcp_project(beam_pipeline_args: List[str]) -> str:
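Finally, a hedged sketch (outside this diff) of how a downstream consumer might decode the new payload with stock TensorFlow; serialized is assumed to hold one record produced by any converter above, and the feature names are illustrative.

import tensorflow as tf

def parse_record(serialized: tf.Tensor):
  # Context features are per-example scalars; sequence features are the
  # variable-length FeatureLists produced by the converters above.
  context, sequences = tf.io.parse_single_sequence_example(
      serialized,
      context_features={'id': tf.io.FixedLenFeature([], tf.int64)},
      sequence_features={
          'tokens': tf.io.FixedLenSequenceFeature([], tf.string)})
  return context, sequences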