diff --git a/thrift/lib/py3/test/auto_migrate/lists.py b/thrift/lib/py3/test/auto_migrate/lists.py index 1b309316c98..caa4a0c7e04 100644 --- a/thrift/lib/py3/test/auto_migrate/lists.py +++ b/thrift/lib/py3/test/auto_migrate/lists.py @@ -165,6 +165,14 @@ def test_comparisons(self) -> None: # got `List__i32`. self.assertGreaterEqual(x, x2) + @brokenInAutoMigrate() + def test_no_raise_on_type_error(self) -> None: + t_list = I32List([1, 2, 3, 4]) + self.assertFalse(t_list == easy()) + self.assertFalse(easy() == t_list) + self.assertNotEqual(t_list, easy()) + self.assertNotEqual(easy(), t_list) + @brokenInAutoMigrate() def test_is_container(self) -> None: self.assertIsInstance(int_list, Container) diff --git a/thrift/lib/py3/types.pyx b/thrift/lib/py3/types.pyx index c20cf7983a5..c69d54a8a85 100644 --- a/thrift/lib/py3/types.pyx +++ b/thrift/lib/py3/types.pyx @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict -from collections.abc import Iterable, Mapping, Set as pySet +from collections.abc import Iterable, Mapping, Set as pySet, Sized import enum import itertools import warnings @@ -46,29 +46,29 @@ __all__ = ['Struct', 'BadEnum', 'Union', 'Enum', 'Flag', 'EnumMeta'] Object = cython.fused_type(Struct, GeneratedError) +cdef list_eq(List self, object other): + if ( + not isinstance(other, Iterable) or + not isinstance(other, Sized) or + len(self) != len(other) + ): + return False -cdef list_compare(object first, object second, int op): - """ Take either Py_EQ or Py_LT, everything else is derived """ - if not (isinstance(first, Iterable) and isinstance(second, Iterable)): - if op == Py_EQ: + for x, y in zip(self, other): + if x != y: return False - else: - return NotImplemented + return True - if op == Py_EQ: - if len(first) != len(second): - return False + +cdef list_lt(object first, object second): + if not (isinstance(first, Iterable) and isinstance(second, Iterable)): + return NotImplemented for x, y in zip(first, second): if x != y: - if op == Py_LT: - return x < y - else: - return False + return x < y - if op == Py_LT: - return len(first) < len(second) - return True + return len(first) < len(second) @cython.internal @@ -318,24 +318,30 @@ cdef class List(Container): return type(other)(itertools.chain(other, self)) def __eq__(self, other): - return list_compare(self, other, Py_EQ) + return list_eq(self, other) def __ne__(self, other): - return not list_compare(self, other, Py_EQ) + return not list_eq(self, other) def __lt__(self, other): - return list_compare(self, other, Py_LT) + return list_lt(self, other) def __gt__(self, other): - return list_compare(other, self, Py_LT) + return list_lt(other, self) def __le__(self, other): - result = list_compare(other, self, Py_LT) - return not result if result is not NotImplemented else NotImplemented + result = list_lt(other, self) + if result is NotImplemented: + return NotImplemented + + return not result def __ge__(self, other): - result = list_compare(self, other, Py_LT) - return not result if result is not NotImplemented else NotImplemented + result = list_lt(self, other) + if result is NotImplemented: + return NotImplemented + + return not result def __repr__(self): if not self: