diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py index 21a9401c9e7..f5155194837 100644 --- a/sdk/python/kfp/compiler/compiler.py +++ b/sdk/python/kfp/compiler/compiler.py @@ -219,16 +219,12 @@ def _process_args(self, raw_args, argument_inputs): processed_args = list(map(str, raw_args)) for i, _ in enumerate(processed_args): # unsanitized_argument_inputs stores a dict: string of sanitized param -> string of unsanitized param - matches = [] - matches += _match_serialized_pipelineparam(str(processed_args[i])) + param_tuples = [] + param_tuples += _match_serialized_pipelineparam(str(processed_args[i])) unsanitized_argument_inputs = {} - for x in list(set(matches)): - if len(x) == 3 or (len(x) == 4 and x[3] == ''): - sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(x[1]), K8sHelper.sanitize_k8s_name(x[0]), x[2])) - unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(x[1], x[0], x[2])) - elif len(x) == 4: - sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(x[1]), K8sHelper.sanitize_k8s_name(x[0]), x[2], TypeMeta.from_dict_or_str(x[3]))) - unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(x[1], x[0], x[2], TypeMeta.from_dict_or_str(x[3]))) + for param_tuple in list(set(param_tuples)): + sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(param_tuple.name), K8sHelper.sanitize_k8s_name(param_tuple.op), param_tuple.value, TypeMeta.deserialize(param_tuple.type))) + unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(param_tuple.name, param_tuple.op, param_tuple.value, TypeMeta.deserialize(param_tuple.type))) if argument_inputs: for param in argument_inputs: if str(param) in unsanitized_argument_inputs: diff --git a/sdk/python/kfp/dsl/_metadata.py b/sdk/python/kfp/dsl/_metadata.py index ce4633ccd42..2188e9086a9 100644 --- a/sdk/python/kfp/dsl/_metadata.py +++ b/sdk/python/kfp/dsl/_metadata.py @@ -46,24 +46,38 @@ def to_dict_or_str(self): return {self.name: self.properties} @staticmethod - def from_dict_or_str(json): + def from_dict_or_str(payload): + '''from_dict_or_str accepts a payload object and returns a TypeMeta instance + Args: + payload (str/dict): the payload could be a str or a dict + ''' + type_meta = TypeMeta() - if isinstance(json, str) and '{' in json: - import ast - json = ast.literal_eval(json) - if isinstance(json, dict): - if not _check_valid_type_dict(json): - raise ValueError(json + ' is not a valid type string') - type_meta.name, type_meta.properties = list(json.items())[0] + if isinstance(payload, dict): + if not _check_valid_type_dict(payload): + raise ValueError(payload + ' is not a valid type string') + type_meta.name, type_meta.properties = list(payload.items())[0] # Convert possible OrderedDict to dict type_meta.properties = dict(type_meta.properties) - elif isinstance(json, str): - type_meta.name = json + elif isinstance(payload, str): + type_meta.name = payload + else: + raise ValueError('from_dict_or_str is expecting either dict or str.') return type_meta def serialize(self): return str(self.to_dict_or_str()) + @staticmethod + def deserialize(payload): + # If the payload is a string of a dict serialization, convert it back to a dict + try: + import ast + payload = ast.literal_eval(payload) + except: + pass + return TypeMeta.from_dict_or_str(payload) + class ParameterMeta(BaseMeta): def __init__(self, name: str, @@ -128,13 +142,13 @@ def _annotation_to_typemeta(annotation): TypeMeta ''' if isinstance(annotation, BaseType): - arg_type = TypeMeta.from_dict_or_str(_instance_to_dict(annotation)) + arg_type = TypeMeta.deserialize(_instance_to_dict(annotation)) elif isinstance(annotation, str): - arg_type = TypeMeta.from_dict_or_str(annotation) + arg_type = TypeMeta.deserialize(annotation) elif isinstance(annotation, dict): if not _check_valid_type_dict(annotation): raise ValueError('Annotation ' + str(annotation) + ' is not a valid type dictionary.') - arg_type = TypeMeta.from_dict_or_str(annotation) + arg_type = TypeMeta.deserialize(annotation) else: return TypeMeta() return arg_type diff --git a/sdk/python/kfp/dsl/_pipeline_param.py b/sdk/python/kfp/dsl/_pipeline_param.py index 77295cd8732..9ed513465ad 100644 --- a/sdk/python/kfp/dsl/_pipeline_param.py +++ b/sdk/python/kfp/dsl/_pipeline_param.py @@ -21,6 +21,7 @@ # TODO: Move this to a separate class # For now, this identifies a condition with only "==" operator supported. ConditionOperator = namedtuple('ConditionOperator', 'operator operand1 operand2') +PipelineParamTuple = namedtuple('PipelineParamTuple', 'name op value type') def _match_serialized_pipelineparam(payload: str): """_match_serialized_pipelineparam matches the serialized pipelineparam. @@ -28,11 +29,18 @@ def _match_serialized_pipelineparam(payload: str): payloads (str): a string that contains the serialized pipelineparam. Returns: - List(tuple())""" - match = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?);type=(.*?);}}', payload) - if len(match) == 0: - match = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?)}}', payload) - return match + PipelineParamTuple + """ + matches = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?);type=(.*?);}}', payload) + if len(matches) == 0: + matches = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?)}}', payload) + param_tuples = [] + for match in matches: + if len(match) == 3: + param_tuples.append(PipelineParamTuple(name=match[1], op=match[0], value=match[2], type='')) + elif len(match) == 4: + param_tuples.append(PipelineParamTuple(name=match[1], op=match[0], value=match[2], type=match[3])) + return param_tuples def _extract_pipelineparams(payloads: str or list[str]): """_extract_pipelineparam extract a list of PipelineParam instances from the payload string. @@ -45,15 +53,12 @@ def _extract_pipelineparams(payloads: str or list[str]): """ if isinstance(payloads, str): payloads = [payloads] - matches = [] + param_tuples = [] for payload in payloads: - matches += _match_serialized_pipelineparam(payload) + param_tuples += _match_serialized_pipelineparam(payload) pipeline_params = [] - for x in list(set(matches)): - if len(x) == 3 or (len(x) == 4 and x[3] == ''): - pipeline_params.append(PipelineParam(x[1], x[0], x[2])) - elif len(x) == 4: - pipeline_params.append(PipelineParam(x[1], x[0], x[2], TypeMeta.from_dict_or_str(x[3]))) + for param_tuple in list(set(param_tuples)): + pipeline_params.append(PipelineParam(param_tuple.name, param_tuple.op, param_tuple.value, TypeMeta.deserialize(param_tuple.type))) return pipeline_params class PipelineParam(object): diff --git a/sdk/python/tests/dsl/metadata_tests.py b/sdk/python/tests/dsl/metadata_tests.py index 24a174ce257..52c5c161471 100644 --- a/sdk/python/tests/dsl/metadata_tests.py +++ b/sdk/python/tests/dsl/metadata_tests.py @@ -16,7 +16,7 @@ import unittest class TestTypeMeta(unittest.TestCase): - def test_from_dict_or_str(self): + def test_deserialize(self): component_dict = { 'GCSPath': { 'bucket_type': 'directory', @@ -25,11 +25,11 @@ def test_from_dict_or_str(self): } golden_type_meta = TypeMeta(name='GCSPath', properties={'bucket_type': 'directory', 'file_type': 'csv'}) - self.assertEqual(TypeMeta.from_dict_or_str(component_dict), golden_type_meta) + self.assertEqual(TypeMeta.deserialize(component_dict), golden_type_meta) component_str = 'GCSPath' golden_type_meta = TypeMeta(name='GCSPath') - self.assertEqual(TypeMeta.from_dict_or_str(component_str), golden_type_meta) + self.assertEqual(TypeMeta.deserialize(component_str), golden_type_meta) def test_eq(self): diff --git a/sdk/python/tests/dsl/pipeline_param_tests.py b/sdk/python/tests/dsl/pipeline_param_tests.py index 7d18147445a..1025bdf9231 100644 --- a/sdk/python/tests/dsl/pipeline_param_tests.py +++ b/sdk/python/tests/dsl/pipeline_param_tests.py @@ -15,6 +15,7 @@ from kfp.dsl import PipelineParam from kfp.dsl._pipeline_param import _extract_pipelineparams +from kfp.dsl._metadata import TypeMeta import unittest @@ -51,4 +52,16 @@ def test_extract_pipelineparam(self): params = _extract_pipelineparams(payload) self.assertListEqual([p1, p2, p3], params) - #TODO: add more unit tests to cover real type instances \ No newline at end of file + def test_extract_pipelineparam_with_types(self): + """Test _extract_pipelineparams. """ + p1 = PipelineParam(name='param1', op_name='op1', param_type=TypeMeta(name='customized_type_a', properties={'property_a': 'value_a'})) + p2 = PipelineParam(name='param2', param_type=TypeMeta(name='customized_type_b')) + p3 = PipelineParam(name='param3', value='value3', param_type=TypeMeta(name='customized_type_c', properties={'property_c': 'value_c'})) + stuff_chars = ' between ' + payload = str(p1) + stuff_chars + str(p2) + stuff_chars + str(p3) + params = _extract_pipelineparams(payload) + self.assertListEqual([p1, p2, p3], params) + # Expecting the _extract_pipelineparam to dedup the pipelineparams among all the payloads. + payload = [str(p1) + stuff_chars + str(p2), str(p2) + stuff_chars + str(p3)] + params = _extract_pipelineparams(payload) + self.assertListEqual([p1, p2, p3], params) \ No newline at end of file