From c01315a89da6a0eee4a12d3bb1c095c28931aee3 Mon Sep 17 00:00:00 2001
From: Alexey Volkov
Date: Thu, 22 Aug 2019 15:31:24 -0700
Subject: [PATCH] SDK - Refactoring - Replaced the TypeMeta class (#1930)

* SDK - Refactoring - Replaced the TypeMeta class

PipelineParam no longer exposes the private TypeMeta class.

Fixes #1420

This PR is part of a series of PRs that unify the metadata and specification types.
---
 sdk/python/kfp/compiler/compiler.py          |  4 +-
 sdk/python/kfp/components/_components.py     |  4 +-
 sdk/python/kfp/components/_dsl_bridge.py     |  6 +-
 sdk/python/kfp/dsl/_component.py             | 12 ++--
 sdk/python/kfp/dsl/_metadata.py              | 71 +++++---------------
 sdk/python/kfp/dsl/_pipeline_param.py        |  7 +-
 sdk/python/kfp/dsl/types.py                  |  3 +
 sdk/python/tests/dsl/component_tests.py      | 10 +--
 sdk/python/tests/dsl/metadata_tests.py       | 59 ++++------------
 sdk/python/tests/dsl/pipeline_param_tests.py |  7 +-
 sdk/python/tests/dsl/pipeline_tests.py       |  6 +-
 11 files changed, 58 insertions(+), 131 deletions(-)

diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py
index 92d8658dcd6..810a7bf75f0 100644
--- a/sdk/python/kfp/compiler/compiler.py
+++ b/sdk/python/kfp/compiler/compiler.py
@@ -25,7 +25,7 @@
 from ._op_to_template import _op_to_template
 from ._default_transformers import add_pod_env
 
-from ..dsl._metadata import TypeMeta, _extract_pipeline_metadata
+from ..dsl._metadata import _extract_pipeline_metadata
 from ..dsl._ops_group import OpsGroup
 
 class Compiler(object):
@@ -596,7 +596,7 @@ def _compile(self, pipeline_func):
 
     args_list = []
     for arg_name in argspec.args:
-      arg_type = TypeMeta()
+      arg_type = None
      for input in pipeline_meta.inputs:
        if arg_name == input.name:
          arg_type = input.param_type
diff --git a/sdk/python/kfp/components/_components.py b/sdk/python/kfp/components/_components.py
index 04338a688ae..331c94f6267 100644
--- a/sdk/python/kfp/components/_components.py
+++ b/sdk/python/kfp/components/_components.py
@@ -235,8 +235,8 @@ def create_task_from_component_and_arguments(pythonic_arguments):
         if kfp.TYPE_CHECK:
             for input_spec in component_spec.inputs:
                 if input_spec.name == key:
-                    if arguments[key].param_type is not None and not check_types(arguments[key].param_type.to_dict_or_str(), '' if input_spec.type is None else input_spec.type):
-                        raise InconsistentTypeException('Component "' + name + '" is expecting ' + key + ' to be type(' + str(input_spec.type) + '), but the passed argument is type(' + arguments[key].param_type.serialize() + ')')
+                    if arguments[key].param_type is not None and not check_types(arguments[key].param_type, '' if input_spec.type is None else input_spec.type):
+                        raise InconsistentTypeException('Component "' + name + '" is expecting ' + key + ' to be type(' + str(input_spec.type) + '), but the passed argument is type(' + str(arguments[key].param_type) + ')')
 
         arguments[key] = str(arguments[key])
 
     task = TaskSpec(
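Note (illustration, not part of the patch): after this change check_types receives the plain str/dict type specs directly instead of TypeMeta instances. A minimal sketch of the comparison semantics the call site above relies on, assuming kfp.dsl.types.check_types normalizes a bare type name to a one-entry {name: properties} dict:

    # Hedged sketch mirroring the call sites above; not new API surface.
    from kfp.dsl.types import check_types

    # A type spec is now either a plain name string ...
    assert check_types('GCSPath', 'GCSPath')
    # ... or a one-entry {name: properties} dict.
    assert check_types({'GCSPath': {'file_type': 'csv'}},
                       {'GCSPath': {'file_type': 'csv'}})
    # Mismatched properties still fail the check and trigger the
    # InconsistentTypeException path above.
    assert not check_types({'GCSPath': {'file_type': 'csv'}},
                           {'GCSPath': {'file_type': 'tsv'}})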
diff --git a/sdk/python/kfp/components/_dsl_bridge.py b/sdk/python/kfp/components/_dsl_bridge.py
index 32388a094c4..750789f95da 100644
--- a/sdk/python/kfp/components/_dsl_bridge.py
+++ b/sdk/python/kfp/components/_dsl_bridge.py
@@ -16,7 +16,7 @@
 from typing import Mapping
 from ._structures import ContainerImplementation, ConcatPlaceholder, IfPlaceholder, InputValuePlaceholder, InputPathPlaceholder, IsPresentPlaceholder, OutputPathPlaceholder, TaskSpec
 from ._components import _generate_output_file_name, _default_component_name
-from kfp.dsl._metadata import ComponentMeta, ParameterMeta, TypeMeta, _annotation_to_typemeta
+from kfp.dsl._metadata import ComponentMeta, ParameterMeta
 
 def create_container_op_from_task(task_spec: TaskSpec):
     argument_values = task_spec.arguments
@@ -143,10 +143,10 @@ def _create_container_op_from_resolved_task(name:str, container_image:str, comma
     # Inputs
     if component_spec.inputs is not None:
         for input in component_spec.inputs:
-            component_meta.inputs.append(ParameterMeta(name=input.name, description=input.description, param_type=_annotation_to_typemeta(input.type), default=input.default))
+            component_meta.inputs.append(ParameterMeta(name=input.name, description=input.description, param_type=input.type, default=input.default))
     if component_spec.outputs is not None:
         for output in component_spec.outputs:
-            component_meta.outputs.append(ParameterMeta(name=output.name, description=output.description, param_type=_annotation_to_typemeta(output.type)))
+            component_meta.outputs.append(ParameterMeta(name=output.name, description=output.description, param_type=output.type))
 
     task = dsl.ContainerOp(
         name=name,
diff --git a/sdk/python/kfp/dsl/_component.py b/sdk/python/kfp/dsl/_component.py
index 917077b0f7e..aa618d030e1 100644
--- a/sdk/python/kfp/dsl/_component.py
+++ b/sdk/python/kfp/dsl/_component.py
@@ -71,19 +71,19 @@ def _component(*args, **kargs):
     if kfp.TYPE_CHECK:
       arg_index = 0
       for arg in args:
-        if isinstance(arg, PipelineParam) and not check_types(arg.param_type.to_dict_or_str(), component_meta.inputs[arg_index].param_type.to_dict_or_str()):
+        if isinstance(arg, PipelineParam) and not check_types(arg.param_type, component_meta.inputs[arg_index].param_type):
           raise InconsistentTypeException('Component "' + component_meta.name + '" is expecting ' + component_meta.inputs[arg_index].name +
-                                          ' to be type(' + component_meta.inputs[arg_index].param_type.serialize() +
-                                          '), but the passed argument is type(' + arg.param_type.serialize() + ')')
+                                          ' to be type(' + str(component_meta.inputs[arg_index].param_type) +
+                                          '), but the passed argument is type(' + str(arg.param_type) + ')')
         arg_index += 1
       if kargs is not None:
         for key in kargs:
          if isinstance(kargs[key], PipelineParam):
            for input_spec in component_meta.inputs:
-             if input_spec.name == key and not check_types(kargs[key].param_type.to_dict_or_str(), input_spec.param_type.to_dict_or_str()):
+             if input_spec.name == key and not check_types(kargs[key].param_type, input_spec.param_type):
                raise InconsistentTypeException('Component "' + component_meta.name + '" is expecting ' + input_spec.name +
-                                               ' to be type(' + input_spec.param_type.serialize() +
-                                               '), but the passed argument is type(' + kargs[key].param_type.serialize() + ')')
+                                               ' to be type(' + str(input_spec.param_type) +
+                                               '), but the passed argument is type(' + str(kargs[key].param_type) + ')')
 
     container_op = func(*args, **kargs)
     container_op._set_metadata(component_meta)
diff --git a/sdk/python/kfp/dsl/_metadata.py b/sdk/python/kfp/dsl/_metadata.py
index c4836d7ddf7..c798b534831 100644
--- a/sdk/python/kfp/dsl/_metadata.py
+++ b/sdk/python/kfp/dsl/_metadata.py
@@ -32,64 +32,22 @@ def serialize(self):
   def __eq__(self, other):
     return self.__dict__ == other.__dict__
 
-class TypeMeta(BaseMeta):
-  def __init__(self,
-      name: str = '',
-      properties: Dict = None):
-    self.name = name
-    self.properties = {} if properties is None else properties
-
-  def to_dict_or_str(self):
-    if self.properties is None or len(self.properties) == 0:
-      return self.name
-    else:
-      return {self.name: self.properties}
-
-  @staticmethod
-  def from_dict_or_str(payload):
-    '''from_dict_or_str accepts a payload object and returns a TypeMeta instance
-    Args:
-      payload (str/dict): the payload could be a str or a dict
-    '''
-
-    type_meta = TypeMeta()
-    if isinstance(payload, dict):
-      if not _check_valid_type_dict(payload):
-        raise ValueError(payload + ' is not a valid type string')
-      type_meta.name, type_meta.properties = list(payload.items())[0]
-      # Convert possible OrderedDict to dict
-      type_meta.properties = dict(type_meta.properties)
-    elif isinstance(payload, str):
-      type_meta.name = payload
-    else:
-      raise ValueError('from_dict_or_str is expecting either dict or str.')
-    return type_meta
-
-  def serialize(self):
-    return str(self.to_dict_or_str())
-
-  @staticmethod
-  def deserialize(payload):
-    '''deserialize expects two types of input: dict and str
-    1) If the payload is a string, the type is named as such with no properties.
-    2) If the payload is a dict, the type name and properties are extracted.
-    '''
-    return TypeMeta.from_dict_or_str(payload)
 
 class ParameterMeta(BaseMeta):
   def __init__(self,
       name: str,
       description: str = '',
-      param_type: TypeMeta = None,
+      param_type = None,
       default = None):
     self.name = name
     self.description = description
-    self.param_type = TypeMeta() if param_type is None else param_type
+    self.param_type = param_type
     self.default = default
 
   def to_dict(self):
     return {'name': self.name,
             'description': self.description,
-            'type': self.param_type.to_dict_or_str(),
+            'type': self.param_type or '',
             'default': self.default}
 
 class ComponentMeta(BaseMeta):
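Note (illustration, not part of the patch): ParameterMeta now stores whatever type representation it is given, a name string, a {name: properties} dict, or None, and to_dict serializes a missing type as an empty string:

    # Sketch of the new ParameterMeta behavior from the hunk above.
    from kfp.dsl._metadata import ParameterMeta

    typed = ParameterMeta(name='x', param_type={'GCSPath': {'file_type': 'csv'}})
    assert typed.to_dict()['type'] == {'GCSPath': {'file_type': 'csv'}}

    untyped = ParameterMeta(name='y')        # param_type defaults to None
    assert untyped.to_dict()['type'] == ''   # 'type': self.param_type or ''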
@@ -132,25 +90,25 @@ def to_dict(self):
     }
 
 def _annotation_to_typemeta(annotation):
-  '''_annotation_to_type_meta converts an annotation to an instance of TypeMeta
+  '''_annotation_to_type_meta converts an annotation to a type structure
   Args:
     annotation(BaseType/str/dict): input/output annotations
       BaseType: registered in kfp.dsl.types
       str: either a string of a dict serialization or a string of the type name
       dict: type name and properties. note that the properties values can be dict.
   Returns:
-    TypeMeta
+    dict or string representing the type
   '''
   if isinstance(annotation, BaseType):
-    arg_type = TypeMeta.deserialize(_instance_to_dict(annotation))
+    arg_type = _instance_to_dict(annotation)
   elif isinstance(annotation, str):
-    arg_type = TypeMeta.deserialize(annotation)
+    arg_type = annotation
   elif isinstance(annotation, dict):
     if not _check_valid_type_dict(annotation):
       raise ValueError('Annotation ' + str(annotation) + ' is not a valid type dictionary.')
-    arg_type = TypeMeta.deserialize(annotation)
+    arg_type = annotation
   else:
-    return TypeMeta()
+    return None
   return arg_type
@@ -174,7 +132,7 @@ def _extract_component_metadata(func):
   # Inputs
   inputs = []
   for arg in fullargspec.args:
-    arg_type = TypeMeta()
+    arg_type = None
     arg_default = arg_defaults[arg] if arg in arg_defaults else None
     if isinstance(arg_default, PipelineParam):
       arg_default = arg_default.value
@@ -227,19 +185,20 @@ def _extract_pipeline_metadata(func):
   )
   # Inputs
   for arg in args:
-    arg_type = TypeMeta()
+    arg_type = None
     arg_default = arg_defaults[arg] if arg in arg_defaults else None
     if isinstance(arg_default, PipelineParam):
       arg_default = arg_default.value
     if arg in annotations:
       arg_type = _annotation_to_typemeta(annotations[arg])
-    if 'openapi_schema_validator' in arg_type.properties and arg_default is not None:
+    arg_type_properties = list(arg_type.values())[0] if isinstance(arg_type, dict) else {}
+    if 'openapi_schema_validator' in arg_type_properties and arg_default is not None:
       from jsonschema import validate
       import json
-      schema_object = arg_type.properties['openapi_schema_validator']
+      schema_object = arg_type_properties['openapi_schema_validator']
       if isinstance(schema_object, str):
         # In case the property value for the schema validator is a string instead of a dict.
-        schema_object = json.loads(arg_type.properties['openapi_schema_validator'])
+        schema_object = json.loads(schema_object)
       validate(instance=arg_default, schema=schema_object)
     pipeline_meta.inputs.append(ParameterMeta(name=arg, description='', param_type=arg_type, default=arg_default))
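Note (illustration, not part of the patch): with TypeMeta gone, the validation branch above pulls the property bag out of the dict type's single top-level key. The same lookup in isolation:

    # Standalone sketch of the schema-validation path in _extract_pipeline_metadata.
    import json
    from jsonschema import validate

    arg_type = {'Integer': {'openapi_schema_validator': {'type': 'integer'}}}
    arg_type_properties = list(arg_type.values())[0] if isinstance(arg_type, dict) else {}
    schema_object = arg_type_properties.get('openapi_schema_validator')
    if isinstance(schema_object, str):
        # The property value may arrive JSON-encoded rather than as a dict.
        schema_object = json.loads(schema_object)
    validate(instance=12, schema=schema_object)  # passes; a non-integer default would raise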
diff --git a/sdk/python/kfp/dsl/_pipeline_param.py b/sdk/python/kfp/dsl/_pipeline_param.py
index 3d4875e8689..3ab69fe2521 100644
--- a/sdk/python/kfp/dsl/_pipeline_param.py
+++ b/sdk/python/kfp/dsl/_pipeline_param.py
@@ -15,8 +15,7 @@
 
 import re
 from collections import namedtuple
-from typing import List
-from ._metadata import TypeMeta
+from typing import Dict, List, Union
 
 
 # TODO: Move this to a separate class
@@ -136,7 +135,7 @@ class PipelineParam(object):
   value passed between components.
   """
 
-  def __init__(self, name: str, op_name: str=None, value: str=None, param_type: TypeMeta=TypeMeta(), pattern: str=None):
+  def __init__(self, name: str, op_name: str=None, value: str=None, param_type : Union[str, Dict] = None, pattern: str=None):
     """Create a new instance of PipelineParam.
     Args:
       name: name of the pipeline parameter.
@@ -218,6 +217,6 @@ def __hash__(self):
 
   def ignore_type(self):
     """ignore_type ignores the type information such that type checking would also pass"""
-    self.param_type = TypeMeta()
+    self.param_type = None
     return self
diff --git a/sdk/python/kfp/dsl/types.py b/sdk/python/kfp/dsl/types.py
index 63ce4ca25fe..16c5eb4db3a 100644
--- a/sdk/python/kfp/dsl/types.py
+++ b/sdk/python/kfp/dsl/types.py
@@ -145,6 +145,9 @@ def _check_dict_types(checked_type, expected_type):
     checked_type (dict): A dict that describes a type from the upstream component output
     expected_type (dict): A dict that describes a type from the downstream component input
   '''
+  if not checked_type or not expected_type:
+    # If the type is empty, it matches any types
+    return True
   checked_type_name,_ = list(checked_type.items())[0]
   expected_type_name,_ = list(expected_type.items())[0]
   if checked_type_name == '' or expected_type_name == '':
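Note (illustration, not part of the patch): the early return added to _check_dict_types is what lets an untyped parameter, for example one reset by ignore_type(), pass type checking against anything. A sketch, assuming check_types routes None/empty specs through _check_dict_types:

    from kfp.dsl import PipelineParam
    from kfp.dsl.types import check_types

    p = PipelineParam(name='data', param_type={'CustomType': {'fmt': 'csv'}})
    p.ignore_type()      # now clears the type to None instead of an empty TypeMeta
    assert p.param_type is None
    # An empty/None type matches any expected type.
    assert check_types(p.param_type, {'CustomType': {'fmt': 'csv'}})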
diff --git a/sdk/python/tests/dsl/component_tests.py b/sdk/python/tests/dsl/component_tests.py
index 73599c2faf8..46c7f01dcf2 100644
--- a/sdk/python/tests/dsl/component_tests.py
+++ b/sdk/python/tests/dsl/component_tests.py
@@ -15,7 +15,7 @@
 import kfp
 import kfp.dsl as dsl
 from kfp.dsl import component, graph_component
-from kfp.dsl._metadata import ComponentMeta, ParameterMeta, TypeMeta
+from kfp.dsl._metadata import ComponentMeta, ParameterMeta
 from kfp.dsl.types import Integer, GCSPath, InconsistentTypeException
 from kfp.dsl import ContainerOp, Pipeline, PipelineParam
 import unittest
@@ -36,10 +36,10 @@ def componentA(a: {'ArtifactA': {'file_type': 'csv'}}, b: Integer() = 12, c: {'A
     containerOp = componentA(1,2,c=3)
 
     golden_meta = ComponentMeta(name='componentA', description='')
-    golden_meta.inputs.append(ParameterMeta(name='a', description='', param_type=TypeMeta(name='ArtifactA', properties={'file_type': 'csv'})))
-    golden_meta.inputs.append(ParameterMeta(name='b', description='', param_type=TypeMeta(name='Integer', properties={'openapi_schema_validator': {"type": "integer"}}), default=12))
-    golden_meta.inputs.append(ParameterMeta(name='c', description='', param_type=TypeMeta(name='ArtifactB', properties={'path_type':'file', 'file_type': 'tsv'}), default='gs://hello/world'))
-    golden_meta.outputs.append(ParameterMeta(name='model', description='', param_type=TypeMeta(name='Integer', properties={'openapi_schema_validator': {"type": "integer"}})))
+    golden_meta.inputs.append(ParameterMeta(name='a', description='', param_type={'ArtifactA': {'file_type': 'csv'}}))
+    golden_meta.inputs.append(ParameterMeta(name='b', description='', param_type={'Integer': {'openapi_schema_validator': {"type": "integer"}}}, default=12))
+    golden_meta.inputs.append(ParameterMeta(name='c', description='', param_type={'ArtifactB': {'path_type':'file', 'file_type': 'tsv'}}, default='gs://hello/world'))
+    golden_meta.outputs.append(ParameterMeta(name='model', description='', param_type={'Integer': {'openapi_schema_validator': {"type": "integer"}}}))
 
     self.assertEqual(containerOp._metadata, golden_meta)
diff --git a/sdk/python/tests/dsl/metadata_tests.py b/sdk/python/tests/dsl/metadata_tests.py
index 52c5c161471..99bf74f687d 100644
--- a/sdk/python/tests/dsl/metadata_tests.py
+++ b/sdk/python/tests/dsl/metadata_tests.py
@@ -12,39 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from kfp.dsl._metadata import ComponentMeta, ParameterMeta, TypeMeta
+from kfp.dsl._metadata import ComponentMeta, ParameterMeta
 import unittest
 
-class TestTypeMeta(unittest.TestCase):
-  def test_deserialize(self):
-    component_dict = {
-      'GCSPath': {
-        'bucket_type': 'directory',
-        'file_type': 'csv'
-      }
-    }
-    golden_type_meta = TypeMeta(name='GCSPath', properties={'bucket_type': 'directory',
-                                                            'file_type': 'csv'})
-    self.assertEqual(TypeMeta.deserialize(component_dict), golden_type_meta)
-
-    component_str = 'GCSPath'
-    golden_type_meta = TypeMeta(name='GCSPath')
-    self.assertEqual(TypeMeta.deserialize(component_str), golden_type_meta)
-
-
-  def test_eq(self):
-    type_a = TypeMeta(name='GCSPath', properties={'bucket_type': 'directory',
-                                                  'file_type': 'csv'})
-    type_b = TypeMeta(name='GCSPath', properties={'bucket_type': 'directory',
-                                                  'file_type': 'tsv'})
-    type_c = TypeMeta(name='GCSPatha', properties={'bucket_type': 'directory',
-                                                   'file_type': 'csv'})
-    type_d = TypeMeta(name='GCSPath', properties={'bucket_type': 'directory',
-                                                  'file_type': 'csv'})
-    self.assertNotEqual(type_a, type_b)
-    self.assertNotEqual(type_a, type_c)
-    self.assertEqual(type_a, type_d)
 
 class TestComponentMeta(unittest.TestCase):
 
@@ -53,34 +23,31 @@ def test_to_dict(self):
                               description='foobar example',
                               inputs=[ParameterMeta(name='input1',
                                                     description='input1 desc',
-                                                    param_type=TypeMeta(name='GCSPath',
-                                                                        properties={'bucket_type': 'directory',
-                                                                                    'file_type': 'csv'
-                                                                                    }
-                                                                        ),
+                                                    param_type={'GCSPath': {
+                                                        'bucket_type': 'directory',
+                                                        'file_type': 'csv'
+                                                    }},
                                                     default='default1'
                                                     ),
                                       ParameterMeta(name='input2',
                                                     description='input2 desc',
-                                                    param_type=TypeMeta(name='TFModel',
-                                                                        properties={'input_data': 'tensor',
-                                                                                    'version': '1.8.0'
-                                                                                    }
-                                                                        ),
+                                                    param_type={'TFModel': {
+                                                        'input_data': 'tensor',
+                                                        'version': '1.8.0'
+                                                    }},
                                                     default='default2'
                                                     ),
                                       ParameterMeta(name='input3',
                                                     description='input3 desc',
-                                                    param_type=TypeMeta(name='Integer'),
+                                                    param_type='Integer',
                                                     default='default3'
                                                     ),
                                       ],
                               outputs=[ParameterMeta(name='output1',
                                                      description='output1 desc',
-                                                     param_type=TypeMeta(name='Schema',
-                                                                         properties={'file_type': 'tsv'
-                                                                                     }
-                                                                         ),
+                                                     param_type={'Schema': {
+                                                         'file_type': 'tsv'
+                                                     }},
                                                      default='default_output1'
                                                      )
                                       ]
""" - p1 = PipelineParam(name='param1', op_name='op1', param_type=TypeMeta(name='customized_type_a', properties={'property_a': 'value_a'})) - p2 = PipelineParam(name='param2', param_type=TypeMeta(name='customized_type_b')) - p3 = PipelineParam(name='param3', value='value3', param_type=TypeMeta(name='customized_type_c', properties={'property_c': 'value_c'})) + p1 = PipelineParam(name='param1', op_name='op1', param_type={'customized_type_a': {'property_a': 'value_a'}}) + p2 = PipelineParam(name='param2', param_type='customized_type_b') + p3 = PipelineParam(name='param3', value='value3', param_type={'customized_type_c': {'property_c': 'value_c'}}) stuff_chars = ' between ' payload = str(p1) + stuff_chars + str(p2) + stuff_chars + str(p3) params = _extract_pipelineparams(payload) diff --git a/sdk/python/tests/dsl/pipeline_tests.py b/sdk/python/tests/dsl/pipeline_tests.py index c8875bb2b47..89758d3b89d 100644 --- a/sdk/python/tests/dsl/pipeline_tests.py +++ b/sdk/python/tests/dsl/pipeline_tests.py @@ -14,7 +14,7 @@ import kfp from kfp.dsl import Pipeline, PipelineParam, ContainerOp, pipeline -from kfp.dsl._metadata import PipelineMeta, ParameterMeta, TypeMeta, _extract_pipeline_metadata +from kfp.dsl._metadata import PipelineMeta, ParameterMeta, _extract_pipeline_metadata from kfp.dsl.types import GCSPath, Integer import unittest @@ -70,8 +70,8 @@ def my_pipeline1(a: {'Schema': {'file_type': 'csv'}}='good', b: Integer()=12): pass golden_meta = PipelineMeta(name='p1', description='description1') - golden_meta.inputs.append(ParameterMeta(name='a', description='', param_type=TypeMeta(name='Schema', properties={'file_type': 'csv'}), default='good')) - golden_meta.inputs.append(ParameterMeta(name='b', description='', param_type=TypeMeta(name='Integer', properties={'openapi_schema_validator': {"type": "integer"}}), default=12)) + golden_meta.inputs.append(ParameterMeta(name='a', description='', param_type={'Schema': {'file_type': 'csv'}}, default='good')) + golden_meta.inputs.append(ParameterMeta(name='b', description='', param_type={'Integer': {'openapi_schema_validator': {"type": "integer"}}}, default=12)) pipeline_meta = _extract_pipeline_metadata(my_pipeline1) self.assertEqual(pipeline_meta, golden_meta) \ No newline at end of file