From fb180f035fffa69311bec856e04afc8f30da5d47 Mon Sep 17 00:00:00 2001 From: Andrew Hilger Date: Mon, 28 Oct 2024 08:28:59 -0700 Subject: [PATCH] List.__eq__ doesn't raise on non-Sized Summary: In general, comparison operators shouldn't raise on a type mismatch; instead `__eq__` should return `False` and `__ne__` should return `True`. The behavior for `__lt__` etc. is ambiguous, so leaving that unchanged. This diff adds a `Sized` check to return `False` early in equality check. It also splits the `list_compare` method into `list_eq` and `list_lt` to make these easier to understand and improve performance (e.g., less branching inside `for` loop). V2: had # / buildall Reviewed By: yoney Differential Revision: D64998005 fbshipit-source-id: d99e45784b4e16d2b45e434ef30e89a167ece778 --- thrift/lib/py3/test/auto_migrate/lists.py | 8 ++++ thrift/lib/py3/types.pyx | 56 +++++++++++++---------- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/thrift/lib/py3/test/auto_migrate/lists.py b/thrift/lib/py3/test/auto_migrate/lists.py index 1b309316c98..caa4a0c7e04 100644 --- a/thrift/lib/py3/test/auto_migrate/lists.py +++ b/thrift/lib/py3/test/auto_migrate/lists.py @@ -165,6 +165,14 @@ def test_comparisons(self) -> None: # got `List__i32`. self.assertGreaterEqual(x, x2) + @brokenInAutoMigrate() + def test_no_raise_on_type_error(self) -> None: + t_list = I32List([1, 2, 3, 4]) + self.assertFalse(t_list == easy()) + self.assertFalse(easy() == t_list) + self.assertNotEqual(t_list, easy()) + self.assertNotEqual(easy(), t_list) + @brokenInAutoMigrate() def test_is_container(self) -> None: self.assertIsInstance(int_list, Container) diff --git a/thrift/lib/py3/types.pyx b/thrift/lib/py3/types.pyx index c20cf7983a5..c69d54a8a85 100644 --- a/thrift/lib/py3/types.pyx +++ b/thrift/lib/py3/types.pyx @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict -from collections.abc import Iterable, Mapping, Set as pySet +from collections.abc import Iterable, Mapping, Set as pySet, Sized import enum import itertools import warnings @@ -46,29 +46,29 @@ __all__ = ['Struct', 'BadEnum', 'Union', 'Enum', 'Flag', 'EnumMeta'] Object = cython.fused_type(Struct, GeneratedError) +cdef list_eq(List self, object other): + if ( + not isinstance(other, Iterable) or + not isinstance(other, Sized) or + len(self) != len(other) + ): + return False -cdef list_compare(object first, object second, int op): - """ Take either Py_EQ or Py_LT, everything else is derived """ - if not (isinstance(first, Iterable) and isinstance(second, Iterable)): - if op == Py_EQ: + for x, y in zip(self, other): + if x != y: return False - else: - return NotImplemented + return True - if op == Py_EQ: - if len(first) != len(second): - return False + +cdef list_lt(object first, object second): + if not (isinstance(first, Iterable) and isinstance(second, Iterable)): + return NotImplemented for x, y in zip(first, second): if x != y: - if op == Py_LT: - return x < y - else: - return False + return x < y - if op == Py_LT: - return len(first) < len(second) - return True + return len(first) < len(second) @cython.internal @@ -318,24 +318,30 @@ cdef class List(Container): return type(other)(itertools.chain(other, self)) def __eq__(self, other): - return list_compare(self, other, Py_EQ) + return list_eq(self, other) def __ne__(self, other): - return not list_compare(self, other, Py_EQ) + return not list_eq(self, other) def __lt__(self, other): - return list_compare(self, other, Py_LT) + return list_lt(self, other) def __gt__(self, other): - return list_compare(other, self, Py_LT) + return list_lt(other, self) def __le__(self, other): - result = list_compare(other, self, Py_LT) - return not result if result is not NotImplemented else NotImplemented + result = list_lt(other, self) + if result is NotImplemented: + return NotImplemented + + return not result def __ge__(self, other): - result = list_compare(self, other, Py_LT) - return not result if result is not NotImplemented else NotImplemented + result = list_lt(self, other) + if result is NotImplemented: + return NotImplemented + + return not result def __repr__(self): if not self: