From a321c4ef2f368fdd199dd05cca88e8c5114cdfb2 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Tue, 31 Mar 2020 11:59:52 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 304015056 --- tfx/scripts/ai_platform_entrypoint_utils.py | 79 +++++++++++ .../ai_platform_entrypoint_utils_test.py | 89 ++++++++++++ tfx/scripts/ai_platform_run_executor.py | 130 ++++++++++++++++++ tfx/scripts/ai_platform_run_executor_test.py | 125 +++++++++++++++++ tfx/scripts/testdata/artifacts.json | 20 +++ tfx/scripts/testdata/exec_properties.json | 8 ++ 6 files changed, 451 insertions(+) create mode 100644 tfx/scripts/ai_platform_entrypoint_utils.py create mode 100644 tfx/scripts/ai_platform_entrypoint_utils_test.py create mode 100644 tfx/scripts/ai_platform_run_executor.py create mode 100644 tfx/scripts/ai_platform_run_executor_test.py create mode 100644 tfx/scripts/testdata/artifacts.json create mode 100644 tfx/scripts/testdata/exec_properties.json diff --git a/tfx/scripts/ai_platform_entrypoint_utils.py b/tfx/scripts/ai_platform_entrypoint_utils.py new file mode 100644 index 0000000000..54a773d912 --- /dev/null +++ b/tfx/scripts/ai_platform_entrypoint_utils.py @@ -0,0 +1,79 @@ +# Lint as: python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utility functions used in mp_run_executor.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +from typing import Any, Dict, List, Text +from google.protobuf import json_format +from ml_metadata.proto import metadata_store_pb2 +from tfx.types import artifact +from tfx.types import artifact_utils + + +def parse_raw_artifact_dict( + inputs_dict: Dict[Text, Any]) -> Dict[Text, List[artifact.Artifact]]: + """Parses a dict from key to a list of a single Artifact from a nested dict.""" + result = {} + for k, v in inputs_dict.items(): + result[k] = [ + _parse_raw_artifact(single_artifact) + for single_artifact in v['artifacts'] + ] + return result + + +def parse_execution_properties(dict_data: Dict[Text, Any]) -> Dict[Text, Any]: + """Parses a dict from key to Value proto as execution properties.""" + result = {} + for k, v in dict_data.items(): + # Translate each field from Value pb to plain value. + value_pb = metadata_store_pb2.Value() + json_format.Parse(json.dumps(v), value_pb) + result[k] = getattr(value_pb, value_pb.WhichOneof('value')) + if result[k] is None: + raise TypeError('Unrecognized type encountered at field %s of execution' + ' properties %s' % (k, dict_data)) + + return result + + +def _parse_raw_artifact(dict_data: Dict[Text, Any]) -> artifact.Artifact: + """Parses json serialized version of artifact without artifact_type.""" + # This parser can only reserve what's inside artifact pb message. + artifact_pb = metadata_store_pb2.Artifact() + # TODO(b/152444458): For compatibility, current TFX serialization assumes + # there is no type field in Artifact pb message. 
+ type_name = dict_data.pop('type') + json_format.Parse(json.dumps(dict_data), artifact_pb) + + # Make an ArtifactType pb according to artifact_pb + type_pb = metadata_store_pb2.ArtifactType() + type_pb.name = type_name + for k, v in artifact_pb.properties.items(): + if v.HasField('int_value'): + type_pb.properties[k] = metadata_store_pb2.PropertyType.INT + elif v.HasField('string_value'): + type_pb.properties[k] = metadata_store_pb2.PropertyType.STRING + elif v.HasField('double_value'): + type_pb.properties[k] = metadata_store_pb2.PropertyType.DOUBLE + else: + raise ValueError('Unrecognized type encountered at field %s' % (k)) + + result = artifact_utils.deserialize_artifact(type_pb, artifact_pb) + return result diff --git a/tfx/scripts/ai_platform_entrypoint_utils_test.py b/tfx/scripts/ai_platform_entrypoint_utils_test.py new file mode 100644 index 0000000000..32ea6f6ebb --- /dev/null +++ b/tfx/scripts/ai_platform_entrypoint_utils_test.py @@ -0,0 +1,89 @@ +# Lint as: python2, python3 +# Copyright 2020 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for tfx.scripts.entrypoint_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os +from typing import Any, Dict, Text +import tensorflow as tf + +from tfx.scripts import ai_platform_entrypoint_utils +from tfx.types import standard_artifacts + +_ARTIFACT_1 = standard_artifacts.StringType() +_KEY_1 = 'input_1' + +_ARTIFACT_2 = standard_artifacts.ModelBlessing() +_KEY_2 = 'input_2' + +_EXEC_PROPERTIES = { + 'input_config': 'input config string', + 'output_config': + '{ \"split_config\": { \"splits\": [ { \"hash_buckets\": 2, \"name\": ' + '\"train\" }, { \"hash_buckets\": 1, \"name\": \"eval\" } ] } }', +} + + +class EntrypointUtilsTest(tf.test.TestCase): + + def setUp(self): + super(EntrypointUtilsTest, self).setUp() + _ARTIFACT_1.type_id = 1 + _ARTIFACT_1.uri = 'gs://root/string/' + _ARTIFACT_2.type_id = 2 + _ARTIFACT_2.uri = 'gs://root/model/' + self._expected_dict = { + _KEY_1: [_ARTIFACT_1], + _KEY_2: [_ARTIFACT_2], + } + source_data_dir = os.path.join(os.path.dirname(__file__), 'testdata') + with open(os.path.join(source_data_dir, + 'artifacts.json')) as artifact_json_file: + self._artifacts = json.load(artifact_json_file) + + with open(os.path.join(source_data_dir, + 'exec_properties.json')) as properties_json_file: + self._properties = json.load(properties_json_file) + + def testParseRawArtifactDict(self): + # TODO(b/131417512): Add equal comparison to types.Artifact class so we + # can use asserters. 
+ def _convert_artifact_to_str(inputs: Dict[Text, Any]) -> Dict[Text, Any]: + """Convert artifact to its string representation.""" + result = {} + for k, artifacts in inputs.items(): + result[k] = [str(artifact.to_json_dict()) for artifact in artifacts] + + return result + + self.assertDictEqual( + _convert_artifact_to_str(self._expected_dict), + _convert_artifact_to_str( + ai_platform_entrypoint_utils.parse_raw_artifact_dict( + self._artifacts))) + + def testParseExecutionProperties(self): + self.assertDictEqual( + _EXEC_PROPERTIES, + ai_platform_entrypoint_utils.parse_execution_properties( + self._properties)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tfx/scripts/ai_platform_run_executor.py b/tfx/scripts/ai_platform_run_executor.py new file mode 100644 index 0000000000..a4de12148d --- /dev/null +++ b/tfx/scripts/ai_platform_run_executor.py @@ -0,0 +1,130 @@ +# Lint as: python2, python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Entrypoint for invoking TFX components in CAIP managed pipelines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import json +from typing import List, Text +import absl + +import tensorflow as tf +from google.protobuf import json_format +from tensorflow.python.platform import app # pylint: disable=g-direct-tensorflow-import +from tfx.components.base import base_executor +from tfx.proto.orchestration import execution_result_pb2 +from tfx.scripts import ai_platform_entrypoint_utils +from tfx.types import artifact_utils +from tfx.utils import import_utils + + +def _run_executor(args: argparse.Namespace, beam_args: List[Text]) -> None: + """Selects a particular executor and run it based on name. + + Args: + args: + --executor_class_path: The import path of the executor class. + --json_serialized_metadata: Full JSON-serialized metadata for this + execution. See go/mp-alpha-placeholder for details. + beam_args: Optional parameter that maps to the optional_pipeline_args + parameter in the pipeline, which provides additional configuration options + for apache-beam and tensorflow.logging. + For more about the beam arguments please refer to: + https://cloud.google.com/dataflow/docs/guides/specifying-exec-params + """ + absl.logging.set_verbosity(absl.logging.INFO) + + # Rehydrate inputs/outputs/exec_properties from the serialized metadata. 
+ full_metadata_dict = json.loads(args.json_serialized_metadata) + + inputs_dict = full_metadata_dict['inputs'] + outputs_dict = full_metadata_dict['outputs'] + exec_properties_dict = full_metadata_dict['execution_properties'] + + inputs = ai_platform_entrypoint_utils.parse_raw_artifact_dict(inputs_dict) + outputs = ai_platform_entrypoint_utils.parse_raw_artifact_dict(outputs_dict) + exec_properties = ai_platform_entrypoint_utils.parse_execution_properties( + exec_properties_dict) + absl.logging.info( + 'Executor %s do: inputs: %s, outputs: %s, exec_properties: %s' % ( + args.executor_class_path, inputs, outputs, exec_properties)) + executor_cls = import_utils.import_class_by_path(args.executor_class_path) + executor_context = base_executor.BaseExecutor.Context( + beam_pipeline_args=beam_args, unique_id='') + executor = executor_cls(executor_context) + absl.logging.info('Starting executor') + executor.Do(inputs, outputs, exec_properties) + + # Log the output metadata to a file. So that it can be picked up by MP. + metadata_uri = full_metadata_dict['output_metadata_uri'] + output_metadata = execution_result_pb2.ExecutorOutput() + for key, output_artifacts in outputs.items(): + # Assuming each output is a singleton artifact. + output_metadata.output_dict[key].CopyFrom( + artifact_utils.get_single_instance(output_artifacts).mlmd_artifact) + + tf.io.gfile.GFile(metadata_uri, + 'wb').write(json_format.MessageToJson(output_metadata)) + + +def main(argv): + """Parses the arguments for _run_executor() then invokes it. + + Args: + argv: Unparsed arguments for run_executor.py. Known argument names include + --executor_class_path: Python class of executor in format of + .. + --json_serialized_metadata: Full JSON-serialized metadata for this + execution. See go/mp-alpha-placeholder for details. + The remaining part of the arguments will be parsed as the beam args used + by each component executors. 
Some commonly used beam args are as follows: + --runner: The beam pipeline runner environment. Can be DirectRunner (for + running locally) or DataflowRunner (for running on GCP Dataflow + service). + --project: The GCP project ID. Needed when runner==DataflowRunner + --direct_num_workers: Number of threads or subprocesses executing the work + load. + For more about the beam arguments please refer to: + https://cloud.google.com/dataflow/docs/guides/specifying-exec-params + + Returns: + None + + Raises: + None + """ + + parser = argparse.ArgumentParser() + parser.add_argument( + '--executor_class_path', + type=str, + required=True, + help='Python class of executor in format of ..') + parser.add_argument( + '--json_serialized_metadata', + type=str, + required=True, + help='JSON-serialized metadata for this execution. ' + 'See go/mp-alpha-placeholder for details.') + + args, beam_args = parser.parse_known_args(argv) + _run_executor(args, beam_args) + + +if __name__ == '__main__': + app.run(main=main) diff --git a/tfx/scripts/ai_platform_run_executor_test.py b/tfx/scripts/ai_platform_run_executor_test.py new file mode 100644 index 0000000000..e2126c9763 --- /dev/null +++ b/tfx/scripts/ai_platform_run_executor_test.py @@ -0,0 +1,125 @@ +# Lint as: python2, python3 +# Copyright 2020 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for tfx.scripts.run_executor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +from typing import Any, Dict, List, Text +import mock +import tensorflow as tf + +from tfx.components.base import base_executor +from tfx.scripts import ai_platform_run_executor +from tfx.types import artifact + + +class _ArgsCapture(object): + instance = None + + def __enter__(self): + _ArgsCapture.instance = self + return self + + def __exit__(self, exception_type, exception_value, traceback): + _ArgsCapture.instance = None + + +class _FakeExecutor(base_executor.BaseExecutor): + + def Do(self, input_dict: Dict[Text, List[artifact.Artifact]], + output_dict: Dict[Text, List[artifact.Artifact]], + exec_properties: Dict[Text, Any]) -> None: + """Overrides BaseExecutor.Do().""" + args_capture = _ArgsCapture.instance + args_capture.input_dict = input_dict + args_capture.output_dict = output_dict + args_capture.exec_properties = exec_properties + + +_INPUT_DICT = { + "input_1": { + "artifacts": [{ + "uri": "gs://root/input_1/", + "typeId": 1, + "type": "ExternalArtifact", + }] + }, + "input_2": { + "artifacts": [{ + "uri": "gs://root/input_2/", + "typeId": 1, + "type": "ExternalArtifact" + }] + }, +} + +_OUTPUT_DICT = { + "output": { + "artifacts": [{ + "uri": "gs://root/output/", + "typeId": 2, + "type": "Examples" + }] + } +} + +_EXEC_PROPERTIES_PB = { + "key_1": { + "stringValue": "value_1" + }, + "key_2": { + "intValue": 42 + }, +} + +_EXEC_PROPERTIES = {"key_1": "value_1", "key_2": 42} + +_JSON_SERIALIZED_METADATA = { + "inputs": _INPUT_DICT, + "outputs": _OUTPUT_DICT, + "execution_properties": _EXEC_PROPERTIES_PB, + "output_metadata_uri": "gs://root/output_metadata/" +} + + +class RunExecutorTest(tf.test.TestCase): + + @mock.patch.object( + tf.io.gfile.GFile, "write", return_value=True, autospec=True) + def testEntryPoint(self, fake_write_fn): + """Test the entrypoint with toy inputs.""" + with 
_ArgsCapture() as args_capture: + args = [ + "--executor_class_path", + "%s.%s" % + (_FakeExecutor.__module__, _FakeExecutor.__name__), + "--json_serialized_metadata", + json.dumps(_JSON_SERIALIZED_METADATA) + ] + ai_platform_run_executor.main(args) + # TODO(b/131417512): Add equal comparison to types.Artifact class so we + # can use asserters. + self.assertSetEqual( + set(args_capture.input_dict.keys()), set(_INPUT_DICT.keys())) + self.assertSetEqual( + set(args_capture.output_dict.keys()), set(_OUTPUT_DICT.keys())) + self.assertDictEqual(args_capture.exec_properties, _EXEC_PROPERTIES) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tfx/scripts/testdata/artifacts.json b/tfx/scripts/testdata/artifacts.json new file mode 100644 index 0000000000..2b17d3f9d5 --- /dev/null +++ b/tfx/scripts/testdata/artifacts.json @@ -0,0 +1,20 @@ +{ + "input_1": { + "artifacts": [ + { + "uri": "gs://root/string/", + "typeId": 1, + "type": "StringType" + } + ] + }, + "input_2": { + "artifacts": [ + { + "uri": "gs://root/model/", + "typeId": 2, + "type": "ModelBlessing" + } + ] + } +} diff --git a/tfx/scripts/testdata/exec_properties.json b/tfx/scripts/testdata/exec_properties.json new file mode 100644 index 0000000000..dd8752a02b --- /dev/null +++ b/tfx/scripts/testdata/exec_properties.json @@ -0,0 +1,8 @@ +{ + "input_config": { + "string_value": "input config string" + }, + "output_config": { + "string_value": "{ \"split_config\": { \"splits\": [ { \"hash_buckets\": 2, \"name\": \"train\" }, { \"hash_buckets\": 1, \"name\": \"eval\" } ] } }" + } +}