Skip to content

Commit

Permalink
Add unit tests pipelineparam (#975)
Browse files Browse the repository at this point in the history
* add unit test to the pipelineparam with types
* create TypeMeta deserialize function, add comments
* strongly typed pipelineparamtuple
* addressing pr comments
  • Loading branch information
gaoning777 authored Mar 19, 2019
1 parent 754db1f commit 2accf41
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 38 deletions.
14 changes: 5 additions & 9 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,12 @@ def _process_args(self, raw_args, argument_inputs):
processed_args = list(map(str, raw_args))
for i, _ in enumerate(processed_args):
# unsanitized_argument_inputs stores a dict: string of sanitized param -> string of unsanitized param
matches = []
matches += _match_serialized_pipelineparam(str(processed_args[i]))
param_tuples = []
param_tuples += _match_serialized_pipelineparam(str(processed_args[i]))
unsanitized_argument_inputs = {}
for x in list(set(matches)):
if len(x) == 3 or (len(x) == 4 and x[3] == ''):
sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(x[1]), K8sHelper.sanitize_k8s_name(x[0]), x[2]))
unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(x[1], x[0], x[2]))
elif len(x) == 4:
sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(x[1]), K8sHelper.sanitize_k8s_name(x[0]), x[2], TypeMeta.from_dict_or_str(x[3])))
unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(x[1], x[0], x[2], TypeMeta.from_dict_or_str(x[3])))
for param_tuple in list(set(param_tuples)):
sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(param_tuple.name), K8sHelper.sanitize_k8s_name(param_tuple.op), param_tuple.value, TypeMeta.deserialize(param_tuple.type)))
unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(param_tuple.name, param_tuple.op, param_tuple.value, TypeMeta.deserialize(param_tuple.type)))
if argument_inputs:
for param in argument_inputs:
if str(param) in unsanitized_argument_inputs:
Expand Down
40 changes: 27 additions & 13 deletions sdk/python/kfp/dsl/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,38 @@ def to_dict_or_str(self):
return {self.name: self.properties}

@staticmethod
def from_dict_or_str(json):
def from_dict_or_str(payload):
'''from_dict_or_str accepts a payload object and returns a TypeMeta instance
Args:
payload (str/dict): the payload could be a str or a dict
'''

type_meta = TypeMeta()
if isinstance(json, str) and '{' in json:
import ast
json = ast.literal_eval(json)
if isinstance(json, dict):
if not _check_valid_type_dict(json):
raise ValueError(json + ' is not a valid type string')
type_meta.name, type_meta.properties = list(json.items())[0]
if isinstance(payload, dict):
if not _check_valid_type_dict(payload):
raise ValueError(payload + ' is not a valid type string')
type_meta.name, type_meta.properties = list(payload.items())[0]
# Convert possible OrderedDict to dict
type_meta.properties = dict(type_meta.properties)
elif isinstance(json, str):
type_meta.name = json
elif isinstance(payload, str):
type_meta.name = payload
else:
raise ValueError('from_dict_or_str is expecting either dict or str.')
return type_meta

def serialize(self):
return str(self.to_dict_or_str())

@staticmethod
def deserialize(payload):
# If the payload is a string of a dict serialization, convert it back to a dict
try:
import ast
payload = ast.literal_eval(payload)
except:
pass
return TypeMeta.from_dict_or_str(payload)

class ParameterMeta(BaseMeta):
def __init__(self,
name: str,
Expand Down Expand Up @@ -128,13 +142,13 @@ def _annotation_to_typemeta(annotation):
TypeMeta
'''
if isinstance(annotation, BaseType):
arg_type = TypeMeta.from_dict_or_str(_instance_to_dict(annotation))
arg_type = TypeMeta.deserialize(_instance_to_dict(annotation))
elif isinstance(annotation, str):
arg_type = TypeMeta.from_dict_or_str(annotation)
arg_type = TypeMeta.deserialize(annotation)
elif isinstance(annotation, dict):
if not _check_valid_type_dict(annotation):
raise ValueError('Annotation ' + str(annotation) + ' is not a valid type dictionary.')
arg_type = TypeMeta.from_dict_or_str(annotation)
arg_type = TypeMeta.deserialize(annotation)
else:
return TypeMeta()
return arg_type
29 changes: 17 additions & 12 deletions sdk/python/kfp/dsl/_pipeline_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,26 @@
# TODO: Move this to a separate class
# For now, this identifies a condition with only "==" operator supported.
ConditionOperator = namedtuple('ConditionOperator', 'operator operand1 operand2')
PipelineParamTuple = namedtuple('PipelineParamTuple', 'name op value type')

def _match_serialized_pipelineparam(payload: str):
"""_match_serialized_pipelineparam matches the serialized pipelineparam.
Args:
payloads (str): a string that contains the serialized pipelineparam.
Returns:
List(tuple())"""
match = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?);type=(.*?);}}', payload)
if len(match) == 0:
match = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?)}}', payload)
return match
PipelineParamTuple
"""
matches = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?);type=(.*?);}}', payload)
if len(matches) == 0:
matches = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?)}}', payload)
param_tuples = []
for match in matches:
if len(match) == 3:
param_tuples.append(PipelineParamTuple(name=match[1], op=match[0], value=match[2], type=''))
elif len(match) == 4:
param_tuples.append(PipelineParamTuple(name=match[1], op=match[0], value=match[2], type=match[3]))
return param_tuples

def _extract_pipelineparams(payloads: str or list[str]):
"""_extract_pipelineparam extract a list of PipelineParam instances from the payload string.
Expand All @@ -45,15 +53,12 @@ def _extract_pipelineparams(payloads: str or list[str]):
"""
if isinstance(payloads, str):
payloads = [payloads]
matches = []
param_tuples = []
for payload in payloads:
matches += _match_serialized_pipelineparam(payload)
param_tuples += _match_serialized_pipelineparam(payload)
pipeline_params = []
for x in list(set(matches)):
if len(x) == 3 or (len(x) == 4 and x[3] == ''):
pipeline_params.append(PipelineParam(x[1], x[0], x[2]))
elif len(x) == 4:
pipeline_params.append(PipelineParam(x[1], x[0], x[2], TypeMeta.from_dict_or_str(x[3])))
for param_tuple in list(set(param_tuples)):
pipeline_params.append(PipelineParam(param_tuple.name, param_tuple.op, param_tuple.value, TypeMeta.deserialize(param_tuple.type)))
return pipeline_params

class PipelineParam(object):
Expand Down
6 changes: 3 additions & 3 deletions sdk/python/tests/dsl/metadata_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import unittest

class TestTypeMeta(unittest.TestCase):
def test_from_dict_or_str(self):
def test_deserialize(self):
component_dict = {
'GCSPath': {
'bucket_type': 'directory',
Expand All @@ -25,11 +25,11 @@ def test_from_dict_or_str(self):
}
golden_type_meta = TypeMeta(name='GCSPath', properties={'bucket_type': 'directory',
'file_type': 'csv'})
self.assertEqual(TypeMeta.from_dict_or_str(component_dict), golden_type_meta)
self.assertEqual(TypeMeta.deserialize(component_dict), golden_type_meta)

component_str = 'GCSPath'
golden_type_meta = TypeMeta(name='GCSPath')
self.assertEqual(TypeMeta.from_dict_or_str(component_str), golden_type_meta)
self.assertEqual(TypeMeta.deserialize(component_str), golden_type_meta)


def test_eq(self):
Expand Down
15 changes: 14 additions & 1 deletion sdk/python/tests/dsl/pipeline_param_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from kfp.dsl import PipelineParam
from kfp.dsl._pipeline_param import _extract_pipelineparams
from kfp.dsl._metadata import TypeMeta
import unittest


Expand Down Expand Up @@ -51,4 +52,16 @@ def test_extract_pipelineparam(self):
params = _extract_pipelineparams(payload)
self.assertListEqual([p1, p2, p3], params)

#TODO: add more unit tests to cover real type instances
def test_extract_pipelineparam_with_types(self):
"""Test _extract_pipelineparams. """
p1 = PipelineParam(name='param1', op_name='op1', param_type=TypeMeta(name='customized_type_a', properties={'property_a': 'value_a'}))
p2 = PipelineParam(name='param2', param_type=TypeMeta(name='customized_type_b'))
p3 = PipelineParam(name='param3', value='value3', param_type=TypeMeta(name='customized_type_c', properties={'property_c': 'value_c'}))
stuff_chars = ' between '
payload = str(p1) + stuff_chars + str(p2) + stuff_chars + str(p3)
params = _extract_pipelineparams(payload)
self.assertListEqual([p1, p2, p3], params)
# Expecting the _extract_pipelineparam to dedup the pipelineparams among all the payloads.
payload = [str(p1) + stuff_chars + str(p2), str(p2) + stuff_chars + str(p3)]
params = _extract_pipelineparams(payload)
self.assertListEqual([p1, p2, p3], params)

0 comments on commit 2accf41

Please sign in to comment.