Skip to content

Commit

Permalink
SDK - Components - Fixed ModelBase comparison bug (kubeflow#1874)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ark-kun authored Aug 21, 2019
1 parent 2622c67 commit 553885f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sdk/python/kfp/components/modelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def __repr__(self):
return self.__class__.__name__ + '(' + ', '.join(param + '=' + repr(getattr(self, param)) for param in self._get_field_names()) + ')'

def __eq__(self, other):
return self.__class__ == other.__class__ and {k: getattr(self, k) for k in self._get_field_names()} == {k: getattr(self, k) for k in other._get_field_names()}
return self.__class__ == other.__class__ and {k: getattr(self, k) for k in self._get_field_names()} == {k: getattr(other, k) for k in other._get_field_names()}

def __ne__(self, other):
return not self == other
15 changes: 15 additions & 0 deletions sdk/python/tests/components/test_structure_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,5 +230,20 @@ def test_handle_from_to_dict_for_union_dict_class(self):
TestModel1.from_dict({'prop_0': '', 'prop_5': [val5.to_dict(), None]})


def test_handle_comparisons(self):
class A(ModelBase):
def __init__(self, a, b):
super().__init__(locals())

self.assertEqual(A(1, 2), A(1, 2))
self.assertNotEqual(A(1, 2), A(1, 3))

class B(ModelBase):
def __init__(self, a, b):
super().__init__(locals())

self.assertNotEqual(A(1, 2), B(1, 2))


if __name__ == '__main__':
unittest.main()

0 comments on commit 553885f

Please sign in to comment.