Commit

Internal change
PiperOrigin-RevId: 304015056
tfx-copybara authored and tensorflow-extended-team committed Mar 31, 2020
1 parent facdbf2 commit a321c4e
Showing 6 changed files with 451 additions and 0 deletions.
79 changes: 79 additions & 0 deletions tfx/scripts/ai_platform_entrypoint_utils.py
@@ -0,0 +1,79 @@
# Lint as: python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions used in mp_run_executor.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
from typing import Any, Dict, List, Text
from google.protobuf import json_format
from ml_metadata.proto import metadata_store_pb2
from tfx.types import artifact
from tfx.types import artifact_utils


def parse_raw_artifact_dict(
inputs_dict: Dict[Text, Any]) -> Dict[Text, List[artifact.Artifact]]:
"""Parses a dict from key to a list of a single Artifact from a nested dict."""
result = {}
for k, v in inputs_dict.items():
result[k] = [
_parse_raw_artifact(single_artifact)
for single_artifact in v['artifacts']
]
return result


def parse_execution_properties(dict_data: Dict[Text, Any]) -> Dict[Text, Any]:
"""Parses a dict from key to Value proto as execution properties."""
result = {}
for k, v in dict_data.items():
# Translate each field from Value pb to plain value.
value_pb = metadata_store_pb2.Value()
json_format.Parse(json.dumps(v), value_pb)
    value_type = value_pb.WhichOneof('value')
    if value_type is None:
      raise TypeError('Unrecognized type encountered at field %s of execution'
                      ' properties %s' % (k, dict_data))
    result[k] = getattr(value_pb, value_type)

return result
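
# For illustration (the keys here are hypothetical), a serialized properties
# dict such as
#   {'input_config': {'string_value': 'foo'}, 'num_steps': {'int_value': 5}}
# parses to
#   {'input_config': 'foo', 'num_steps': 5}.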


def _parse_raw_artifact(dict_data: Dict[Text, Any]) -> artifact.Artifact:
"""Parses json serialized version of artifact without artifact_type."""
  # This parser can only preserve what's inside the Artifact pb message.
artifact_pb = metadata_store_pb2.Artifact()
# TODO(b/152444458): For compatibility, current TFX serialization assumes
# there is no type field in Artifact pb message.
type_name = dict_data.pop('type')
json_format.Parse(json.dumps(dict_data), artifact_pb)

# Make an ArtifactType pb according to artifact_pb
type_pb = metadata_store_pb2.ArtifactType()
type_pb.name = type_name
for k, v in artifact_pb.properties.items():
if v.HasField('int_value'):
type_pb.properties[k] = metadata_store_pb2.PropertyType.INT
elif v.HasField('string_value'):
type_pb.properties[k] = metadata_store_pb2.PropertyType.STRING
elif v.HasField('double_value'):
type_pb.properties[k] = metadata_store_pb2.PropertyType.DOUBLE
else:
raise ValueError('Unrecognized type encountered at field %s' % (k))

result = artifact_utils.deserialize_artifact(type_pb, artifact_pb)
return result
89 changes: 89 additions & 0 deletions tfx/scripts/ai_platform_entrypoint_utils_test.py
@@ -0,0 +1,89 @@
# Lint as: python2, python3
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.scripts.entrypoint_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
from typing import Any, Dict, Text
import tensorflow as tf

from tfx.scripts import ai_platform_entrypoint_utils
from tfx.types import standard_artifacts

_ARTIFACT_1 = standard_artifacts.StringType()
_KEY_1 = 'input_1'

_ARTIFACT_2 = standard_artifacts.ModelBlessing()
_KEY_2 = 'input_2'

_EXEC_PROPERTIES = {
'input_config': 'input config string',
'output_config':
'{ \"split_config\": { \"splits\": [ { \"hash_buckets\": 2, \"name\": '
'\"train\" }, { \"hash_buckets\": 1, \"name\": \"eval\" } ] } }',
}


class EntrypointUtilsTest(tf.test.TestCase):

def setUp(self):
super(EntrypointUtilsTest, self).setUp()
_ARTIFACT_1.type_id = 1
_ARTIFACT_1.uri = 'gs://root/string/'
_ARTIFACT_2.type_id = 2
_ARTIFACT_2.uri = 'gs://root/model/'
self._expected_dict = {
_KEY_1: [_ARTIFACT_1],
_KEY_2: [_ARTIFACT_2],
}
source_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
with open(os.path.join(source_data_dir,
'artifacts.json')) as artifact_json_file:
self._artifacts = json.load(artifact_json_file)

with open(os.path.join(source_data_dir,
'exec_properties.json')) as properties_json_file:
self._properties = json.load(properties_json_file)

def testParseRawArtifactDict(self):
# TODO(b/131417512): Add equal comparison to types.Artifact class so we
# can use asserters.
def _convert_artifact_to_str(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
"""Convert artifact to its string representation."""
result = {}
for k, artifacts in inputs.items():
result[k] = [str(artifact.to_json_dict()) for artifact in artifacts]

return result

self.assertDictEqual(
_convert_artifact_to_str(self._expected_dict),
_convert_artifact_to_str(
ai_platform_entrypoint_utils.parse_raw_artifact_dict(
self._artifacts)))

def testParseExecutionProperties(self):
self.assertDictEqual(
_EXEC_PROPERTIES,
ai_platform_entrypoint_utils.parse_execution_properties(
self._properties))


if __name__ == '__main__':
tf.test.main()
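
The testdata fixtures themselves are not included in this diff. A plausible shape for artifacts.json, reverse-engineered from the expected dict in setUp above (the type names and field spellings are assumptions), is:

artifacts_fixture = {
    'input_1': {
        'artifacts': [{'type': 'String', 'type_id': 1,
                       'uri': 'gs://root/string/'}],
    },
    'input_2': {
        'artifacts': [{'type': 'ModelBlessing', 'type_id': 2,
                       'uri': 'gs://root/model/'}],
    },
}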
130 changes: 130 additions & 0 deletions tfx/scripts/ai_platform_run_executor.py
@@ -0,0 +1,130 @@
# Lint as: python2, python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entrypoint for invoking TFX components in CAIP managed pipelines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
from typing import List, Text
import absl

import tensorflow as tf
from google.protobuf import json_format
from tensorflow.python.platform import app # pylint: disable=g-direct-tensorflow-import
from tfx.components.base import base_executor
from tfx.proto.orchestration import execution_result_pb2
from tfx.scripts import ai_platform_entrypoint_utils
from tfx.types import artifact_utils
from tfx.utils import import_utils


def _run_executor(args: argparse.Namespace, beam_args: List[Text]) -> None:
"""Selects a particular executor and run it based on name.
Args:
args:
--executor_class_path: The import path of the executor class.
--json_serialized_metadata: Full JSON-serialized metadata for this
execution. See go/mp-alpha-placeholder for details.
beam_args: Optional parameter that maps to the optional_pipeline_args
parameter in the pipeline, which provides additional configuration options
for apache-beam and tensorflow.logging.
      For more about the beam arguments, please refer to:
https://cloud.google.com/dataflow/docs/guides/specifying-exec-params
"""
absl.logging.set_verbosity(absl.logging.INFO)

# Rehydrate inputs/outputs/exec_properties from the serialized metadata.
full_metadata_dict = json.loads(args.json_serialized_metadata)

inputs_dict = full_metadata_dict['inputs']
outputs_dict = full_metadata_dict['outputs']
exec_properties_dict = full_metadata_dict['execution_properties']

inputs = ai_platform_entrypoint_utils.parse_raw_artifact_dict(inputs_dict)
outputs = ai_platform_entrypoint_utils.parse_raw_artifact_dict(outputs_dict)
exec_properties = ai_platform_entrypoint_utils.parse_execution_properties(
exec_properties_dict)
  absl.logging.info(
      'Running executor %s with inputs: %s, outputs: %s, exec_properties: %s'
      % (args.executor_class_path, inputs, outputs, exec_properties))
executor_cls = import_utils.import_class_by_path(args.executor_class_path)
executor_context = base_executor.BaseExecutor.Context(
beam_pipeline_args=beam_args, unique_id='')
executor = executor_cls(executor_context)
absl.logging.info('Starting executor')
executor.Do(inputs, outputs, exec_properties)

  # Log the output metadata to a file so that it can be picked up by MP.
metadata_uri = full_metadata_dict['output_metadata_uri']
output_metadata = execution_result_pb2.ExecutorOutput()
for key, output_artifacts in outputs.items():
# Assuming each output is a singleton artifact.
output_metadata.output_dict[key].CopyFrom(
artifact_utils.get_single_instance(output_artifacts).mlmd_artifact)

tf.io.gfile.GFile(metadata_uri,
'wb').write(json_format.MessageToJson(output_metadata))
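
  # The JSON written above follows protobuf JSON naming conventions and would
  # look roughly like (with a hypothetical 'model' output key):
  #   {"outputDict": {"model": {"uri": "gs://...", ...}}}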


def main(argv):
"""Parses the arguments for _run_executor() then invokes it.
Args:
argv: Unparsed arguments for run_executor.py. Known argument names include
--executor_class_path: Python class of executor in format of
<module>.<class>.
--json_serialized_metadata: Full JSON-serialized metadata for this
execution. See go/mp-alpha-placeholder for details.
      The remaining part of the arguments will be parsed as the beam args used
      by each component executor. Some commonly used beam args are as follows:
--runner: The beam pipeline runner environment. Can be DirectRunner (for
running locally) or DataflowRunner (for running on GCP Dataflow
service).
      --project: The GCP project ID. Needed when runner==DataflowRunner.
      --direct_num_workers: Number of threads or subprocesses executing the
        workload.
      For more about the beam arguments, please refer to:
https://cloud.google.com/dataflow/docs/guides/specifying-exec-params
Returns:
None
Raises:
None
"""

parser = argparse.ArgumentParser()
parser.add_argument(
'--executor_class_path',
type=str,
required=True,
help='Python class of executor in format of <module>.<class>.')
parser.add_argument(
'--json_serialized_metadata',
type=str,
required=True,
help='JSON-serialized metadata for this execution. '
'See go/mp-alpha-placeholder for details.')

args, beam_args = parser.parse_known_args(argv)
_run_executor(args, beam_args)


if __name__ == '__main__':
app.run(main=main)
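
To make the wiring concrete, here is a minimal sketch of how a caller might assemble the flags this entrypoint expects (the metadata keys come from _run_executor above; the executor class path, property values, and file paths are illustrative assumptions):

import json

# Hypothetical --json_serialized_metadata payload.
metadata = {
    'inputs': {},   # key -> {'artifacts': [raw artifact dicts]}
    'outputs': {},  # same shape as 'inputs'
    'execution_properties': {
        'input_config': {'string_value': 'input config string'},
    },
    'output_metadata_uri': '/tmp/executor_output.json',
}
flags = [
    '--executor_class_path='
    'tfx.components.example_gen.csv_example_gen.executor.Executor',
    '--json_serialized_metadata=' + json.dumps(metadata),
    '--runner=DirectRunner',  # unknown to the parser, so forwarded to Beam
]
print(' '.join(flags))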