Showing 6 changed files with 451 additions and 0 deletions.
@@ -0,0 +1,79 @@
# Lint as: python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions used in mp_run_executor.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
from typing import Any, Dict, List, Text
from google.protobuf import json_format
from ml_metadata.proto import metadata_store_pb2
from tfx.types import artifact
from tfx.types import artifact_utils

def parse_raw_artifact_dict(
    inputs_dict: Dict[Text, Any]) -> Dict[Text, List[artifact.Artifact]]:
  """Parses a nested dict into a dict mapping each key to a list of Artifacts."""
  result = {}
  for k, v in inputs_dict.items():
    result[k] = [
        _parse_raw_artifact(single_artifact)
        for single_artifact in v['artifacts']
    ]
  return result

def parse_execution_properties(dict_data: Dict[Text, Any]) -> Dict[Text, Any]:
  """Parses a dict from key to Value proto as execution properties."""
  result = {}
  for k, v in dict_data.items():
    # Translate each field from Value pb to plain value.
    value_pb = metadata_store_pb2.Value()
    json_format.Parse(json.dumps(v), value_pb)
    # Check the oneof before getattr: getattr(value_pb, None) would raise a
    # bare TypeError instead of the intended error message.
    which = value_pb.WhichOneof('value')
    if which is None:
      raise TypeError('Unrecognized type encountered at field %s of execution'
                      ' properties %s' % (k, dict_data))
    result[k] = getattr(value_pb, which)

  return result

def _parse_raw_artifact(dict_data: Dict[Text, Any]) -> artifact.Artifact:
  """Parses a JSON-serialized artifact that lacks an artifact_type."""
  # This parser can only preserve what's inside the Artifact pb message.
  artifact_pb = metadata_store_pb2.Artifact()
  # TODO(b/152444458): For compatibility, current TFX serialization assumes
  # there is no type field in Artifact pb message.
  type_name = dict_data.pop('type')
  json_format.Parse(json.dumps(dict_data), artifact_pb)

  # Make an ArtifactType pb according to artifact_pb.
  type_pb = metadata_store_pb2.ArtifactType()
  type_pb.name = type_name
  for k, v in artifact_pb.properties.items():
    if v.HasField('int_value'):
      type_pb.properties[k] = metadata_store_pb2.PropertyType.INT
    elif v.HasField('string_value'):
      type_pb.properties[k] = metadata_store_pb2.PropertyType.STRING
    elif v.HasField('double_value'):
      type_pb.properties[k] = metadata_store_pb2.PropertyType.DOUBLE
    else:
      raise ValueError('Unrecognized type encountered at field %s' % (k))

  result = artifact_utils.deserialize_artifact(type_pb, artifact_pb)
  return result
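
A minimal usage sketch of the two parsers above. The payload shapes are inferred from the parsing code itself; the key names, type name, URIs, and property values are illustrative assumptions, not values from this commit.

from tfx.scripts import ai_platform_entrypoint_utils

# Hypothetical input: each key maps to {'artifacts': [<Artifact pb as JSON>]}.
# 'type' is popped off and used as the ArtifactType name; the remaining
# fields must be valid Artifact proto JSON.
raw_artifacts = {
    'examples': {
        'artifacts': [{
            'type': 'Examples',
            'uri': 'gs://bucket/examples/',
            'properties': {'span': {'intValue': '1'}},
        }],
    },
}
inputs = ai_platform_entrypoint_utils.parse_raw_artifact_dict(raw_artifacts)
# inputs['examples'] is a singleton list holding a deserialized Artifact.

# Hypothetical execution properties: each value is an MLMD Value proto in
# its JSON form, so the oneof field name selects the Python type.
raw_properties = {
    'input_config': {'stringValue': 'input config string'},
    'num_steps': {'intValue': '100'},
}
exec_properties = ai_platform_entrypoint_utils.parse_execution_properties(
    raw_properties)
# exec_properties == {'input_config': 'input config string', 'num_steps': 100}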
@@ -0,0 +1,89 @@
# Lint as: python2, python3
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.scripts.ai_platform_entrypoint_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
from typing import Any, Dict, Text
import tensorflow as tf

from tfx.scripts import ai_platform_entrypoint_utils
from tfx.types import standard_artifacts

_ARTIFACT_1 = standard_artifacts.StringType()
_KEY_1 = 'input_1'

_ARTIFACT_2 = standard_artifacts.ModelBlessing()
_KEY_2 = 'input_2'

_EXEC_PROPERTIES = {
    'input_config': 'input config string',
    'output_config':
        '{ "split_config": { "splits": [ { "hash_buckets": 2, "name": '
        '"train" }, { "hash_buckets": 1, "name": "eval" } ] } }',
}

class EntrypointUtilsTest(tf.test.TestCase):

  def setUp(self):
    super(EntrypointUtilsTest, self).setUp()
    _ARTIFACT_1.type_id = 1
    _ARTIFACT_1.uri = 'gs://root/string/'
    _ARTIFACT_2.type_id = 2
    _ARTIFACT_2.uri = 'gs://root/model/'
    self._expected_dict = {
        _KEY_1: [_ARTIFACT_1],
        _KEY_2: [_ARTIFACT_2],
    }
    source_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
    with open(os.path.join(source_data_dir,
                           'artifacts.json')) as artifact_json_file:
      self._artifacts = json.load(artifact_json_file)

    with open(os.path.join(source_data_dir,
                           'exec_properties.json')) as properties_json_file:
      self._properties = json.load(properties_json_file)

  def testParseRawArtifactDict(self):
    # TODO(b/131417512): Add equal comparison to types.Artifact class so we
    # can use asserters.
    def _convert_artifact_to_str(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
      """Converts each artifact to its string representation."""
      result = {}
      for k, artifacts in inputs.items():
        result[k] = [str(artifact.to_json_dict()) for artifact in artifacts]

      return result

    self.assertDictEqual(
        _convert_artifact_to_str(self._expected_dict),
        _convert_artifact_to_str(
            ai_platform_entrypoint_utils.parse_raw_artifact_dict(
                self._artifacts)))

  def testParseExecutionProperties(self):
    self.assertDictEqual(
        _EXEC_PROPERTIES,
        ai_platform_entrypoint_utils.parse_execution_properties(
            self._properties))


if __name__ == '__main__':
  tf.test.main()
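
The testdata files read in setUp() are not part of this excerpt. Based on the expected values above and the parsers under test, they plausibly look like the following sketch; the 'type' names and exact field spellings are guesses, labeled as such.

# Hypothetical reconstruction of testdata/exec_properties.json: each value
# is an MLMD Value proto in JSON form, matching what
# parse_execution_properties() expects to receive.
exec_properties_json = {
    'input_config': {'stringValue': 'input config string'},
    'output_config': {
        'stringValue':
            '{ "split_config": { "splits": [ { "hash_buckets": 2, "name": '
            '"train" }, { "hash_buckets": 1, "name": "eval" } ] } }'
    },
}

# Hypothetical reconstruction of testdata/artifacts.json: the keys, type_ids,
# and URIs come from setUp() above; the 'type' names are assumptions.
artifacts_json = {
    'input_1': {
        'artifacts': [{'type': 'String', 'typeId': '1',
                       'uri': 'gs://root/string/'}],
    },
    'input_2': {
        'artifacts': [{'type': 'ModelBlessing', 'typeId': '2',
                       'uri': 'gs://root/model/'}],
    },
}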
@@ -0,0 +1,130 @@
# Lint as: python2, python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entrypoint for invoking TFX components in CAIP managed pipelines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
from typing import List, Text
import absl

import tensorflow as tf
from google.protobuf import json_format
from tensorflow.python.platform import app  # pylint: disable=g-direct-tensorflow-import
from tfx.components.base import base_executor
from tfx.proto.orchestration import execution_result_pb2
from tfx.scripts import ai_platform_entrypoint_utils
from tfx.types import artifact_utils
from tfx.utils import import_utils

def _run_executor(args: argparse.Namespace, beam_args: List[Text]) -> None:
  """Selects a particular executor and runs it based on name.

  Args:
    args:
      --executor_class_path: The import path of the executor class.
      --json_serialized_metadata: Full JSON-serialized metadata for this
        execution. See go/mp-alpha-placeholder for details.
    beam_args: Optional parameter that maps to the optional_pipeline_args
      parameter in the pipeline, which provides additional configuration
      options for apache-beam and tensorflow.logging. For more about the beam
      arguments please refer to:
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params
  """
  absl.logging.set_verbosity(absl.logging.INFO)

  # Rehydrate inputs/outputs/exec_properties from the serialized metadata.
  full_metadata_dict = json.loads(args.json_serialized_metadata)

  inputs_dict = full_metadata_dict['inputs']
  outputs_dict = full_metadata_dict['outputs']
  exec_properties_dict = full_metadata_dict['execution_properties']

  inputs = ai_platform_entrypoint_utils.parse_raw_artifact_dict(inputs_dict)
  outputs = ai_platform_entrypoint_utils.parse_raw_artifact_dict(outputs_dict)
  exec_properties = ai_platform_entrypoint_utils.parse_execution_properties(
      exec_properties_dict)
  absl.logging.info(
      'Executor %s invoked with inputs: %s, outputs: %s, exec_properties: %s' %
      (args.executor_class_path, inputs, outputs, exec_properties))
  executor_cls = import_utils.import_class_by_path(args.executor_class_path)
  executor_context = base_executor.BaseExecutor.Context(
      beam_pipeline_args=beam_args, unique_id='')
  executor = executor_cls(executor_context)
  absl.logging.info('Starting executor')
  executor.Do(inputs, outputs, exec_properties)

  # Log the output metadata to a file so that it can be picked up by MP.
  metadata_uri = full_metadata_dict['output_metadata_uri']
  output_metadata = execution_result_pb2.ExecutorOutput()
  for key, output_artifacts in outputs.items():
    # Assuming each output is a singleton artifact.
    output_metadata.output_dict[key].CopyFrom(
        artifact_utils.get_single_instance(output_artifacts).mlmd_artifact)

  tf.io.gfile.GFile(metadata_uri,
                    'wb').write(json_format.MessageToJson(output_metadata))

||
def main(argv): | ||
"""Parses the arguments for _run_executor() then invokes it. | ||
Args: | ||
argv: Unparsed arguments for run_executor.py. Known argument names include | ||
--executor_class_path: Python class of executor in format of | ||
<module>.<class>. | ||
--json_serialized_metadata: Full JSON-serialized metadata for this | ||
execution. See go/mp-alpha-placeholder for details. | ||
The remaining part of the arguments will be parsed as the beam args used | ||
by each component executors. Some commonly used beam args are as follows: | ||
--runner: The beam pipeline runner environment. Can be DirectRunner (for | ||
running locally) or DataflowRunner (for running on GCP Dataflow | ||
service). | ||
--project: The GCP project ID. Neede when runner==DataflowRunner | ||
--direct_num_workers: Number of threads or subprocesses executing the work | ||
load. | ||
For more about the beam arguments please refer to: | ||
https://cloud.google.com/dataflow/docs/guides/specifying-exec-params | ||
Returns: | ||
None | ||
Raises: | ||
None | ||
""" | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
'--executor_class_path', | ||
type=str, | ||
required=True, | ||
help='Python class of executor in format of <module>.<class>.') | ||
parser.add_argument( | ||
'--json_serialized_metadata', | ||
type=str, | ||
required=True, | ||
help='JSON-serialized metadata for this execution. ' | ||
'See go/mp-alpha-placeholder for details.') | ||
|
||
args, beam_args = parser.parse_known_args(argv) | ||
_run_executor(args, beam_args) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(main=main) |
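
To make the expected payload concrete, here is a hedged sketch of how the --json_serialized_metadata argument might be assembled. The top-level keys come from _run_executor() above; the executor class path, artifact contents, and URIs are illustrative assumptions.

import json

# Top-level keys ('inputs', 'outputs', 'execution_properties',
# 'output_metadata_uri') are the ones _run_executor() reads; everything
# nested inside them is illustrative only.
metadata = {
    'inputs': {
        'examples': {'artifacts': [{'type': 'Examples',
                                    'uri': 'gs://bucket/examples/'}]},
    },
    'outputs': {
        'model': {'artifacts': [{'type': 'Model',
                                 'uri': 'gs://bucket/model/'}]},
    },
    'execution_properties': {
        'train_args': {'stringValue': '{"num_steps": 100}'},
    },
    'output_metadata_uri': 'gs://bucket/metadata/output.json',
}

# Hypothetical argv for this script's main(): the two known flags are
# consumed by argparse, and anything after them (e.g. --runner) is passed
# through to Beam via parse_known_args().
argv = [
    '--executor_class_path', 'tfx.components.trainer.executor.Executor',
    '--json_serialized_metadata', json.dumps(metadata),
    '--runner', 'DirectRunner',
]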