Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add default value type checking #1407

Merged
merged 7 commits into from
Jun 3, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ matrix:
- language: python
python: "3.6"
env: TOXENV=py36
install: pip3 install jsonschema==3.0.1
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
script:
# DSL tests
- cd $TRAVIS_BUILD_DIR/sdk/python
Expand Down
16 changes: 15 additions & 1 deletion sdk/python/kfp/dsl/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def serialize(self):

@staticmethod
def deserialize(payload):
# If the payload is a string of a dict serialization, convert it back to a dict
'''deserialize accepts a payload that is either a dict or a str, handled as one of three cases:
1) If the payload is a string holding a dict serialization, it is converted back to a dict.
2) If the payload is any other string, it is used as the type name, with no properties.
3) If the payload is a dict, the type name and properties are extracted from it. '''
try:
import ast
payload = ast.literal_eval(payload)
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -138,6 +141,9 @@ def _annotation_to_typemeta(annotation):
'''_annotation_to_typemeta converts an annotation to an instance of TypeMeta
Args:
annotation(BaseType/str/dict): input/output annotations
BaseType: registered in kfp.dsl.types
str: either a string of a dict serialization or a string of the type name
dict: type name and properties. Note that the property values can themselves be dicts.
Returns:
TypeMeta
'''
Expand Down Expand Up @@ -221,6 +227,14 @@ def _extract_pipeline_metadata(func):
arg_default = arg_defaults[arg] if arg in arg_defaults else None
if arg in annotations:
arg_type = _annotation_to_typemeta(annotations[arg])
if 'openapi_schema_validator' in arg_type.properties and arg_default is not None:
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
from jsonschema import validate
import json
schema_object = arg_type.properties['openapi_schema_validator']
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(schema_object, str):
# In case the property value for the schema validator is a string instead of a dict.
schema_object = json.loads(arg_type.properties['openapi_schema_validator'])
validate(instance=arg_default, schema=schema_object)
pipeline_meta.inputs.append(ParameterMeta(name=arg, description='', param_type=arg_type, default=arg_default))

#TODO: add descriptions to the metadata
Expand Down
81 changes: 46 additions & 35 deletions sdk/python/kfp/dsl/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,65 +19,76 @@ class BaseType:

# Primitive Types
class Integer(BaseType):
openapi_schema_validator = {
"type": "integer"
}
def __init__(self):
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
self.openapi_schema_validator = {
"type": "integer"
}

class String(BaseType):
openapi_schema_validator = {
"type": "string"
}
def __init__(self):
self.openapi_schema_validator = {
"type": "string"
}

class Float(BaseType):
openapi_schema_validator = {
"type": "number"
}
def __init__(self):
self.openapi_schema_validator = {
"type": "number"
}

class Bool(BaseType):
openapi_schema_validator = {
"type": "boolean"
}
def __init__(self):
self.openapi_schema_validator = {
"type": "boolean"
}

class List(BaseType):
openapi_schema_validator = {
"type": "array"
}
def __init__(self):
self.openapi_schema_validator = {
"type": "array"
}

class Dict(BaseType):
openapi_schema_validator = {
"type": "object",
}
def __init__(self):
self.openapi_schema_validator = {
"type": "object",
}

# GCP Types
class GCSPath(BaseType):
openapi_schema_validator = {
"type": "string",
"pattern": "^gs://.*$"
}
def __init__(self):
self.openapi_schema_validator = {
"type": "string",
"pattern": "^gs://.*$"
}

class GCRPath(BaseType):
openapi_schema_validator = {
"type": "string",
"pattern": "^.*gcr\\.io/.*$"
}
def __init__(self):
self.openapi_schema_validator = {
"type": "string",
"pattern": "^.*gcr\\.io/.*$"
}

class GCPRegion(BaseType):
openapi_schema_validator = {
"type": "string"
}
def __init__(self):
self.openapi_schema_validator = {
"type": "string"
}

class GCPProjectID(BaseType):
'''GCPProjectID: GCP project id'''
openapi_schema_validator = {
"type": "string"
}
def __init__(self):
self.openapi_schema_validator = {
"type": "string"
}

# General Types
class LocalPath(BaseType):
#TODO: add restriction to path
openapi_schema_validator = {
"type": "string"
}
def __init__(self):
self.openapi_schema_validator = {
"type": "string"
}

class InconsistentTypeException(Exception):
'''InconsistentTypeException is raised when two types are not consistent'''
Expand Down
1 change: 1 addition & 0 deletions sdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
'cloudpickle',
'kfp-server-api >= 0.1.18, < 0.1.19', #Update the upper version whenever a new version of the kfp-server-api package is released. Update the lower version when there is a breaking change in kfp-server-api.
'argo-models == 2.2.1a', #2.2.1a is equivalent to argo 2.2.1
'jsonschema >= 3.0.1'
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
]

setup(
Expand Down
36 changes: 34 additions & 2 deletions sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def test_py_param_substitutions(self):
def test_type_checking_with_consistent_types(self):
"""Test type check pipeline parameters against component metadata."""
@component
def a_op(field_m: {'GCSPath': {'path_type': 'file', 'file_type':'tsv'}}, field_o: 'Integer'):
def a_op(field_m: {'GCSPath': {'path_type': 'file', 'file_type':'tsv'}}, field_o: Integer()):
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
return ContainerOp(
name = 'operator a',
image = 'gcr.io/ml-pipeline/component-b',
Expand Down Expand Up @@ -394,7 +394,7 @@ def my_pipeline(a: {'GCSPath': {'path_type':'file', 'file_type': 'tsv'}}='good',
def test_type_checking_with_inconsistent_types(self):
"""Test type check pipeline parameters against component metadata."""
@component
def a_op(field_m: {'GCSPath': {'path_type': 'file', 'file_type':'tsv'}}, field_o: 'Integer'):
def a_op(field_m: {'GCSPath': {'path_type': 'file', 'file_type':'tsv'}}, field_o: Integer()):
return ContainerOp(
name = 'operator a',
image = 'gcr.io/ml-pipeline/component-b',
Expand Down Expand Up @@ -423,6 +423,38 @@ def my_pipeline(a: {'GCSPath': {'path_type':'file', 'file_type': 'csv'}}='good',
finally:
shutil.rmtree(tmpdir)

def test_type_checking_with_json_schema(self):
"""Test type check pipeline parameters against the json schema."""
@component
def a_op(field_m: {'GCRPath': {'openapi_schema_validator': {"type": "string", "pattern": "^.*gcr\\.io/.*$"}}}, field_o: 'Integer'):
return ContainerOp(
name = 'operator a',
image = 'gcr.io/ml-pipeline/component-b',
arguments = [
'--field-l', field_m,
'--field-o', field_o,
],
)

@pipeline(
name='p1',
description='description1'
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
)
def my_pipeline(a: {'GCRPath': {'openapi_schema_validator': {"type": "string", "pattern": "^.*gcr\\.io/.*$"}}}='good', b: 'Integer'=12):
a_op(field_m=a, field_o=b)

test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
sys.path.append(test_data_dir)
tmpdir = tempfile.mkdtemp()
try:
simple_package_path = os.path.join(tmpdir, 'simple.tar.gz')
import jsonschema
with self.assertRaises(jsonschema.exceptions.ValidationError):
compiler.Compiler().compile(my_pipeline, simple_package_path, type_check=True)
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved

finally:
shutil.rmtree(tmpdir)

def test_compile_pipeline_with_after(self):
def op():
return dsl.ContainerOp(
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/tests/dsl/component_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def componentA(a: {'ArtifactA': {'file_type': 'csv'}}, b: Integer() = 12, c: {'A

golden_meta = ComponentMeta(name='componentA', description='')
golden_meta.inputs.append(ParameterMeta(name='a', description='', param_type=TypeMeta(name='ArtifactA', properties={'file_type': 'csv'})))
golden_meta.inputs.append(ParameterMeta(name='b', description='', param_type=TypeMeta(name='Integer'), default=12))
golden_meta.inputs.append(ParameterMeta(name='b', description='', param_type=TypeMeta(name='Integer', properties={'openapi_schema_validator': {"type": "integer"}}), default=12))
golden_meta.inputs.append(ParameterMeta(name='c', description='', param_type=TypeMeta(name='ArtifactB', properties={'path_type':'file', 'file_type': 'tsv'}), default='gs://hello/world'))
golden_meta.outputs.append(ParameterMeta(name='model', description='', param_type=TypeMeta(name='Integer')))
golden_meta.outputs.append(ParameterMeta(name='model', description='', param_type=TypeMeta(name='Integer', properties={'openapi_schema_validator': {"type": "integer"}})))

self.assertEqual(containerOp._metadata, golden_meta)

Expand Down
2 changes: 1 addition & 1 deletion sdk/python/tests/dsl/pipeline_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def my_pipeline1(a: {'Schema': {'file_type': 'csv'}}='good', b: Integer()=12):

golden_meta = PipelineMeta(name='p1', description='description1')
golden_meta.inputs.append(ParameterMeta(name='a', description='', param_type=TypeMeta(name='Schema', properties={'file_type': 'csv'}), default='good'))
golden_meta.inputs.append(ParameterMeta(name='b', description='', param_type=TypeMeta(name='Integer'), default=12))
golden_meta.inputs.append(ParameterMeta(name='b', description='', param_type=TypeMeta(name='Integer', properties={'openapi_schema_validator': {"type": "integer"}}), default=12))

pipeline_meta = _extract_pipeline_metadata(my_pipeline1)
self.assertEqual(pipeline_meta, golden_meta)
5 changes: 4 additions & 1 deletion sdk/python/tests/dsl/type_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def test_class_to_dict(self):
gcspath_dict = _instance_to_dict(GCSPath())
golden_dict = {
'GCSPath': {

'openapi_schema_validator': {
"type": "string",
"pattern": "^gs://.*$"
}
}
}
self.assertEqual(golden_dict, gcspath_dict)
Expand Down