Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Commit

Permalink
Adds union support check when passing in inputs
Browse files Browse the repository at this point in the history
This should fix #170. If an input function was annotated
with a union, we would barf. Now it wont.

This still means downstream nodes need to have the same
type signature -- this doesn't change any behavior there.
  • Loading branch information
skrawcz authored and elijahbenizzy committed Aug 15, 2022
1 parent 751fce0 commit e7be742
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
3 changes: 3 additions & 0 deletions hamilton/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool:
return True
elif typing_inspect.is_generic_type(node_type) and typing_inspect.get_origin(node_type) == type(input_value):
return True
elif typing_inspect.is_union_type(node_type):
union_types = typing_inspect.get_args(node_type)
return any([SimplePythonDataFrameGraphAdapter.check_input_type(ut, input_value) for ut in union_types])
elif node_type == type(input_value):
return True
return False
Expand Down
10 changes: 9 additions & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def build_result(**outputs: typing.Dict[str, typing.Any]) -> typing.Any:
(int, 1),
(float, 1.0),
(str, 'abc'),
(typing.Union[int, pd.Series], pd.Series([1,2,3])),
(typing.Union[int, pd.Series], 1),
], ids=[
'test-any',
'test-subclass',
Expand All @@ -72,7 +74,9 @@ def build_result(**outputs: typing.Dict[str, typing.Any]) -> typing.Any:
'test-type-match-list',
'test-type-match-int',
'test-type-match-float',
'test-type-match-str'
'test-type-match-str',
'test-union-match-series',
'test-union-match-int',
])
def test_SimplePythonDataFrameGraphAdapter_check_input_type_match(node_type, input_value):
"""Tests check_input_type of SimplePythonDataFrameGraphAdapter"""
Expand All @@ -90,6 +94,8 @@ def test_SimplePythonDataFrameGraphAdapter_check_input_type_match(node_type, inp
(int, 1.0),
(float, 1),
(str, 0),
(typing.Union[int, pd.Series], pd.DataFrame({'a': [1, 2, 3]})),
(typing.Union[int, pd.Series], 1.0),
], ids=[
'test-subclass',
'test-generic-list',
Expand All @@ -99,6 +105,8 @@ def test_SimplePythonDataFrameGraphAdapter_check_input_type_match(node_type, inp
'test-type-match-int',
'test-type-match-float',
'test-type-match-str',
'test-union-mismatch-dataframe',
'test-union-mismatch-float',
])
def test_SimplePythonDataFrameGraphAdapter_check_input_type_mismatch(node_type, input_value):
"""Tests check_input_type of SimplePythonDataFrameGraphAdapter"""
Expand Down

0 comments on commit e7be742

Please sign in to comment.