diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 8abe0154..82f1b8ad 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -651,7 +651,7 @@ def __raw_get(self, name: str) -> Any: def __eq__(self, other) -> bool: if type(self) is not type(other): - return False + return NotImplemented for field_name in self._betterproto.meta_by_field_name: self_val = self.__raw_get(field_name) diff --git a/tests/test_features.py b/tests/test_features.py index 322a310f..630ca6b6 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -18,6 +18,7 @@ List, Optional, ) +from unittest.mock import ANY import pytest @@ -727,3 +728,15 @@ class Spam(betterproto.Message): assert not Spam().is_set("bar") assert Spam(foo=True).is_set("foo") assert Spam(foo=True, bar=0).is_set("bar") + + +def test_equality_comparison(): + from tests.output_betterproto.bool import Test as TestMessage + + msg = TestMessage(value=True) + + assert msg == msg + assert msg == ANY + assert msg == TestMessage(value=True) + assert msg != 1 + assert msg != TestMessage(value=False)