Skip to content

Commit

Permalink
Fix sample test failure because of the type information in the pipeli…
Browse files Browse the repository at this point in the history
…neparam (#972)

* fix bug: op_to_template resolve the raw arguments by mapping to the argument_inputs but the argument_inputs lost the type information

* fix type pattern matching

* convert orderedDict to dict from the component module
  • Loading branch information
gaoning777 authored Mar 15, 2019
1 parent cc3f214 commit 754db1f
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
10 changes: 6 additions & 4 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,12 @@ def _process_args(self, raw_args, argument_inputs):
matches += _match_serialized_pipelineparam(str(processed_args[i]))
unsanitized_argument_inputs = {}
for x in list(set(matches)):
sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(x[1]), K8sHelper.sanitize_k8s_name(x[0]), x[2]))
unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(x[1], x[0], x[2]))

if len(x) == 3 or (len(x) == 4 and x[3] == ''):
sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(x[1]), K8sHelper.sanitize_k8s_name(x[0]), x[2]))
unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(x[1], x[0], x[2]))
elif len(x) == 4:
sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(x[1]), K8sHelper.sanitize_k8s_name(x[0]), x[2], TypeMeta.from_dict_or_str(x[3])))
unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(x[1], x[0], x[2], TypeMeta.from_dict_or_str(x[3])))
if argument_inputs:
for param in argument_inputs:
if str(param) in unsanitized_argument_inputs:
Expand Down Expand Up @@ -257,7 +260,6 @@ def _build_conventional_artifact(name, path):
}
},
}

processed_arguments = self._process_args(op.arguments, op.argument_inputs)
processed_command = self._process_args(op.command, op.argument_inputs)

Expand Down
5 changes: 5 additions & 0 deletions sdk/python/kfp/dsl/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,15 @@ def to_dict_or_str(self):
@staticmethod
def from_dict_or_str(json):
type_meta = TypeMeta()
if isinstance(json, str) and '{' in json:
import ast
json = ast.literal_eval(json)
if isinstance(json, dict):
if not _check_valid_type_dict(json):
raise ValueError(json + ' is not a valid type string')
type_meta.name, type_meta.properties = list(json.items())[0]
# Convert possible OrderedDict to dict
type_meta.properties = dict(type_meta.properties)
elif isinstance(json, str):
type_meta.name = json
return type_meta
Expand Down
12 changes: 9 additions & 3 deletions sdk/python/kfp/dsl/_pipeline_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _match_serialized_pipelineparam(payload: str):
Returns:
List(tuple())"""
match = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?);type=(.*?)}}', payload)
match = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?);type=(.*?);}}', payload)
if len(match) == 0:
match = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?)}}', payload)
return match
Expand All @@ -48,7 +48,13 @@ def _extract_pipelineparams(payloads: str or list[str]):
matches = []
for payload in payloads:
matches += _match_serialized_pipelineparam(payload)
return [PipelineParam(x[1], x[0], x[2]) for x in list(set(matches))]
pipeline_params = []
for x in list(set(matches)):
if len(x) == 3 or (len(x) == 4 and x[3] == ''):
pipeline_params.append(PipelineParam(x[1], x[0], x[2]))
elif len(x) == 4:
pipeline_params.append(PipelineParam(x[1], x[0], x[2], TypeMeta.from_dict_or_str(x[3])))
return pipeline_params

class PipelineParam(object):
"""Representing a future value that is passed between pipeline components.
Expand Down Expand Up @@ -104,7 +110,7 @@ def __str__(self):
if self.param_type is None:
return '{{pipelineparam:op=%s;name=%s;value=%s}}' % (op_name, self.name, value)
else:
return '{{pipelineparam:op=%s;name=%s;value=%s;type=%s}}' % (op_name, self.name, value, self.param_type.serialize())
return '{{pipelineparam:op=%s;name=%s;value=%s;type=%s;}}' % (op_name, self.name, value, self.param_type.serialize())

def __repr__(self):
return str({self.__class__.__name__: self.__dict__})
Expand Down
10 changes: 6 additions & 4 deletions sdk/python/tests/dsl/pipeline_param_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def test_str_repr(self):
"""Test string representation."""

p = PipelineParam(name='param1', op_name='op1')
self.assertEqual('{{pipelineparam:op=op1;name=param1;value=;type=}}', str(p))
self.assertEqual('{{pipelineparam:op=op1;name=param1;value=;type=;}}', str(p))

p = PipelineParam(name='param2')
self.assertEqual('{{pipelineparam:op=;name=param2;value=;type=}}', str(p))
self.assertEqual('{{pipelineparam:op=;name=param2;value=;type=;}}', str(p))

p = PipelineParam(name='param3', value='value3')
self.assertEqual('{{pipelineparam:op=;name=param3;value=value3;type=}}', str(p))
self.assertEqual('{{pipelineparam:op=;name=param3;value=value3;type=;}}', str(p))

def test_extract_pipelineparam(self):
"""Test _extract_pipeleineparam."""
Expand All @@ -49,4 +49,6 @@ def test_extract_pipelineparam(self):
self.assertListEqual([p1, p2, p3], params)
payload = [str(p1) + stuff_chars + str(p2), str(p2) + stuff_chars + str(p3)]
params = _extract_pipelineparams(payload)
self.assertListEqual([p1, p2, p3], params)
self.assertListEqual([p1, p2, p3], params)

#TODO: add more unit tests to cover real type instances

0 comments on commit 754db1f

Please sign in to comment.