Skip to content

Commit

Permalink
add core types and type checking function (#886)
Browse files Browse the repository at this point in the history
* add core types and type checking function

* fix unit test bug

* avoid defining dynamic classes

* typo fix

* use python struct for the openapi schema

* update param name in the check_type functions
remove schema validators for GCRPath, and adjust for GCRPath, GCSPath
change _check_valid_dict to _check_valid_type_dict to avoid confusion
fix typo in the comments
adjust function order for readability
  • Loading branch information
gaoning777 authored and k8s-ci-robot committed Mar 5, 2019
1 parent 806e123 commit 02ab7b7
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 1 deletion.
166 changes: 166 additions & 0 deletions sdk/python/kfp/dsl/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

class BaseType:
'''MetaType is a base type for all scalar and artifact types.
'''
pass

# Primitive Types
class Integer(BaseType):
openapi_schema_validator = {
"type": "integer"
}

class String(BaseType):
openapi_schema_validator = {
"type": "string"
}

class Float(BaseType):
openapi_schema_validator = {
"type": "number"
}

class Bool(BaseType):
openapi_schema_validator = {
"type": "boolean"
}

class List(BaseType):
openapi_schema_validator = {
"type": "array"
}

class Dict(BaseType):
openapi_schema_validator = {
"type": "object",
}

# GCP Types
class GCSPath(BaseType):
openapi_schema_validator = {
"type": "string",
"pattern": "^gs://.*$"
}

def __init__(self, path_type='', file_type=''):
'''
Args
:param path_type: describes the paths, for example, bucket, directory, file, etc
:param file_type: describes the files, for example, JSON, CSV, etc.
'''
self.path_type = path_type
self.file_type = file_type

class GCRPath(BaseType):
openapi_schema_validator = {
"type": "string",
"pattern": "^.*gcr\\.io/.*$"
}

class GCPRegion(BaseType):
openapi_schema_validator = {
"type": "string"
}

class GCPProjectID(BaseType):
'''MetaGCPProjectID: GCP project id'''
openapi_schema_validator = {
"type": "string"
}

# General Types
class LocalPath(BaseType):
#TODO: add restriction to path
openapi_schema_validator = {
"type": "string"
}

class InconsistentTypeException(Exception):
'''InconsistencyTypeException is raised when two types are not consistent'''
pass

def check_types(checked_type, expected_type):
'''check_types checks the type consistency.
For each of the attribute in checked_type, there is the same attribute in expected_type with the same value.
However, expected_type could contain more attributes that checked_type does not contain.
Args:
checked_type (BaseType/str/dict): it describes a type from the upstream component output
expected_type (BaseType/str/dict): it describes a type from the downstream component input
'''
if isinstance(checked_type, BaseType):
checked_type = _instance_to_dict(checked_type)
elif isinstance(checked_type, str):
checked_type = _str_to_dict(checked_type)
if isinstance(expected_type, BaseType):
expected_type = _instance_to_dict(expected_type)
elif isinstance(expected_type, str):
expected_type = _str_to_dict(expected_type)
return _check_dict_types(checked_type, expected_type)

def _check_valid_type_dict(payload):
'''_check_valid_type_dict checks whether a dict is a correct serialization of a type
Args:
payload(dict)
'''
if not isinstance(payload, dict) or len(payload) != 1:
return False
for type_name in payload:
if not isinstance(payload[type_name], dict):
return False
property_types = (int, str, float, bool)
for property_name in payload[type_name]:
if not isinstance(property_name, property_types) or not isinstance(payload[type_name][property_name], property_types):
return False
return True

def _instance_to_dict(instance):
'''_instance_to_dict serializes the type instance into a python dictionary
Args:
instance(BaseType): An instance that describes a type
Return:
dict
'''
return {type(instance).__name__: instance.__dict__}

def _str_to_dict(payload):
import json
json_dict = json.loads(payload)
if not _check_valid_type_dict(json_dict):
raise ValueError(payload + ' is not a valid type string')
return json_dict

def _check_dict_types(checked_type, expected_type):
'''_check_type_types checks the type consistency.
Args:
checked_type (dict): A dict that describes a type from the upstream component output
expected_type (dict): A dict that describes a type from the downstream component input
'''
checked_type_name,_ = list(checked_type.items())[0]
expected_type_name,_ = list(expected_type.items())[0]
if checked_type_name != expected_type_name:
return False
type_name = checked_type_name
for type_property in checked_type[type_name]:
if type_property not in expected_type[type_name]:
print(type_name + ' has a property ' + str(type_property) + ' that the latter does not.')
return False
if checked_type[type_name][type_property] != expected_type[type_name][type_property]:
print(type_name + ' has a property ' + str(type_property) + ' with value: ' +
str(checked_type[type_name][type_property]) + ' and ' +
str(expected_type[type_name][type_property]))
return False
return True
3 changes: 2 additions & 1 deletion sdk/python/tests/dsl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import pipeline_param_tests
import container_op_tests
import ops_group_tests

import type_tests

if __name__ == '__main__':
suite = unittest.TestSuite()
suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(pipeline_param_tests))
suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(pipeline_tests))
suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(container_op_tests))
suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(ops_group_tests))
suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(type_tests))
runner = unittest.TextTestRunner()
if not runner.run(suite).wasSuccessful():
sys.exit(1)
Expand Down
81 changes: 81 additions & 0 deletions sdk/python/tests/dsl/type_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from kfp.dsl._types import _instance_to_dict, _str_to_dict, check_types, GCSPath
import unittest

class TestTypes(unittest.TestCase):

def test_class_to_dict(self):
"""Test _class_to_dict function."""
gcspath_dict = _instance_to_dict(GCSPath(path_type='file', file_type='csv'))
golden_dict = {
'GCSPath': {
'path_type': 'file',
'file_type': 'csv',
}
}
self.assertEqual(golden_dict, gcspath_dict)

def test_str_to_dict(self):
gcspath_str = '{"GCSPath": {"file_type": "csv", "path_type": "file"}}'
gcspath_dict = _str_to_dict(gcspath_str)
golden_dict = {
'GCSPath': {
'path_type': 'file',
'file_type': 'csv'
}
}
self.assertEqual(golden_dict, gcspath_dict)
gcspath_str = '{"file_type": "csv", "path_type": "file"}'
with self.assertRaises(ValueError):
_str_to_dict(gcspath_str)

def test_check_types(self):
#Core types
typeA = GCSPath(path_type='file', file_type='csv')
typeB = GCSPath(path_type='file', file_type='csv')
self.assertTrue(check_types(typeA, typeB))
typeC = GCSPath(path_type='file', file_type='tsv')
self.assertFalse(check_types(typeA, typeC))

# Custom types
typeA = {
'A':{
'X': 'value1',
'Y': 'value2'
}
}
typeB = {
'B':{
'X': 'value1',
'Y': 'value2'
}
}
typeC = {
'A':{
'X': 'value1'
}
}
typeD = {
'A':{
'X': 'value1',
'Y': 'value3'
}
}
self.assertFalse(check_types(typeA, typeB))
self.assertFalse(check_types(typeA, typeC))
self.assertTrue(check_types(typeC, typeA))
self.assertFalse(check_types(typeA, typeD))

0 comments on commit 02ab7b7

Please sign in to comment.