diff --git a/sdk/python/kfp/components/modelbase.py b/sdk/python/kfp/components/modelbase.py index a3b2e3c92e9..7748bdc9936 100644 --- a/sdk/python/kfp/components/modelbase.py +++ b/sdk/python/kfp/components/modelbase.py @@ -282,7 +282,7 @@ def __repr__(self): return self.__class__.__name__ + '(' + ', '.join(param + '=' + repr(getattr(self, param)) for param in self._get_field_names()) + ')' def __eq__(self, other): - return self.__class__ == other.__class__ and {k: getattr(self, k) for k in self._get_field_names()} == {k: getattr(self, k) for k in other._get_field_names()} + return self.__class__ == other.__class__ and {k: getattr(self, k) for k in self._get_field_names()} == {k: getattr(other, k) for k in other._get_field_names()} def __ne__(self, other): return not self == other \ No newline at end of file diff --git a/sdk/python/tests/components/test_structure_model_base.py b/sdk/python/tests/components/test_structure_model_base.py index 5077d68e83e..95bd8bad0df 100644 --- a/sdk/python/tests/components/test_structure_model_base.py +++ b/sdk/python/tests/components/test_structure_model_base.py @@ -230,5 +230,20 @@ def test_handle_from_to_dict_for_union_dict_class(self): TestModel1.from_dict({'prop_0': '', 'prop_5': [val5.to_dict(), None]}) + def test_handle_comparisons(self): + class A(ModelBase): + def __init__(self, a, b): + super().__init__(locals()) + + self.assertEqual(A(1, 2), A(1, 2)) + self.assertNotEqual(A(1, 2), A(1, 3)) + + class B(ModelBase): + def __init__(self, a, b): + super().__init__(locals()) + + self.assertNotEqual(A(1, 2), B(1, 2)) + + if __name__ == '__main__': unittest.main()