From e4d207bf62d48fc675f2e37a25adee8ee7e86400 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Tue, 9 Jan 2024 16:36:46 +0100 Subject: [PATCH] Fix E721: do not compare types, for exact checks use is / is not --- pytensor/compile/debugmode.py | 4 ++-- pytensor/compile/ops.py | 2 +- pytensor/graph/basic.py | 2 +- pytensor/graph/null_type.py | 2 +- pytensor/graph/rewriting/unify.py | 4 ++-- pytensor/graph/utils.py | 2 +- pytensor/ifelse.py | 2 +- pytensor/link/c/params_type.py | 4 ++-- pytensor/link/c/type.py | 2 +- pytensor/raise_op.py | 4 ++-- pytensor/scalar/basic.py | 6 +++--- pytensor/scalar/math.py | 10 +++++----- pytensor/scan/op.py | 2 +- pytensor/sparse/basic.py | 2 +- pytensor/tensor/random/type.py | 2 +- pytensor/tensor/rewriting/math.py | 2 +- pytensor/tensor/type.py | 4 ++-- pytensor/tensor/type_other.py | 2 +- pytensor/tensor/variable.py | 6 +++--- pytensor/typed_list/type.py | 2 +- tests/graph/rewriting/test_unify.py | 2 +- tests/graph/test_fg.py | 4 ++-- tests/graph/test_op.py | 2 +- tests/link/c/test_basic.py | 2 +- tests/sparse/test_basic.py | 2 +- tests/tensor/test_basic.py | 2 +- tests/tensor/test_subtensor.py | 2 +- 27 files changed, 41 insertions(+), 41 deletions(-) diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index da52f33761..92f4865e69 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -687,7 +687,7 @@ def _lessbroken_deepcopy(a): else: rval = copy.deepcopy(a) - assert type(rval) == type(a), (type(rval), type(a)) + assert type(rval) is type(a), (type(rval), type(a)) if isinstance(rval, np.ndarray): assert rval.dtype == a.dtype @@ -1154,7 +1154,7 @@ def __str__(self): return str(self.__dict__) def __eq__(self, other): - rval = type(self) == type(other) + rval = type(self) is type(other) if rval: # nodes are not compared because this comparison is # supposed to be true for corresponding events that happen diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 
18f25d6078..170ea399cd 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -246,7 +246,7 @@ def __init__(self, fn, itypes, otypes, infer_shape): self.infer_shape = self._infer_shape def __eq__(self, other): - return type(self) == type(other) and self.__fn == other.__fn + return type(self) is type(other) and self.__fn == other.__fn def __hash__(self): return hash(type(self)) ^ hash(self.__fn) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index b651b218d4..70e36ab60c 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -748,7 +748,7 @@ def __eq__(self, other): return True return ( - type(self) == type(other) + type(self) is type(other) and self.id == other.id and self.type == other.type ) diff --git a/pytensor/graph/null_type.py b/pytensor/graph/null_type.py index d2a77c67df..66f5c18fd1 100644 --- a/pytensor/graph/null_type.py +++ b/pytensor/graph/null_type.py @@ -33,7 +33,7 @@ def values_eq(self, a, b, force_same_dtype=True): raise ValueError("NullType has no values to compare") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/pytensor/graph/rewriting/unify.py b/pytensor/graph/rewriting/unify.py index d3f802af49..e9361d62c2 100644 --- a/pytensor/graph/rewriting/unify.py +++ b/pytensor/graph/rewriting/unify.py @@ -57,8 +57,8 @@ def __new__(cls, constraint, token=None, prefix=""): return obj def __eq__(self, other): - if type(self) == type(other): - return self.token == other.token and self.constraint == other.constraint + if type(self) is type(other): + return self.token == other.token and self.constraint == other.constraint return NotImplemented def __hash__(self): diff --git a/pytensor/graph/utils.py b/pytensor/graph/utils.py index 9a25cd6d8b..29f8327632 100644 --- a/pytensor/graph/utils.py +++ b/pytensor/graph/utils.py @@ -229,7 +229,7 @@ def __hash__(self): if "__eq__" not in dct: def __eq__(self, other): - 
return type(self) == type(other) and tuple( + return type(self) is type(other) and tuple( getattr(self, a) for a in props ) == tuple(getattr(other, a) for a in props) diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 1383cea263..2960c13d4d 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -78,7 +78,7 @@ def __init__(self, n_outs, as_view=False, name=None): self.name = name def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return False if self.as_view != other.as_view: return False diff --git a/pytensor/link/c/params_type.py b/pytensor/link/c/params_type.py index d0f09d82b7..6e8710b12f 100644 --- a/pytensor/link/c/params_type.py +++ b/pytensor/link/c/params_type.py @@ -301,7 +301,7 @@ def __hash__(self): def __eq__(self, other): return ( - type(self) == type(other) + type(self) is type(other) and self.__params_type__ == other.__params_type__ and all( # NB: Params object should have been already filtered. @@ -435,7 +435,7 @@ def __repr__(self): def __eq__(self, other): return ( - type(self) == type(other) + type(self) is type(other) and self.fields == other.fields and self.types == other.types ) diff --git a/pytensor/link/c/type.py b/pytensor/link/c/type.py index 10e7a166e0..f6ee5600d4 100644 --- a/pytensor/link/c/type.py +++ b/pytensor/link/c/type.py @@ -515,7 +515,7 @@ def __hash__(self): def __eq__(self, other): return ( - type(self) == type(other) + type(self) is type(other) and self.ctype == other.ctype and len(self) == len(other) and len(self.aliases) == len(other.aliases) diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index 1da5cc24dd..554e8f9b4c 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -16,7 +16,7 @@ class ExceptionType(Generic): def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -51,7 +51,7 @@ def __str__(self): return f"CheckAndRaise{{{self.exc_type}({self.msg})}}" def 
__eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return False if self.msg == other.msg and self.exc_type == other.exc_type: diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 56a3629dc5..2a3db168ba 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1074,7 +1074,7 @@ def __call__(self, *types): return [rval] def __eq__(self, other): - return type(self) == type(other) and self.tbl == other.tbl + return type(self) is type(other) and self.tbl == other.tbl def __hash__(self): return hash(type(self)) # ignore hash of table @@ -1160,7 +1160,7 @@ def L_op(self, inputs, outputs, output_gradients): return self.grad(inputs, output_gradients) def __eq__(self, other): - test = type(self) == type(other) and getattr( + test = type(self) is type(other) and getattr( self, "output_types_preference", None ) == getattr(other, "output_types_preference", None) return test @@ -4133,7 +4133,7 @@ def __eq__(self, other): if self is other: return True if ( - type(self) != type(other) + type(self) is not type(other) or self.nin != other.nin or self.nout != other.nout ): diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index edf03a393d..ac66fbd698 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -626,7 +626,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -679,7 +679,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -732,7 +732,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return 
type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -1045,7 +1045,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -1083,7 +1083,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 445b55e13f..4235220e81 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -1246,7 +1246,7 @@ def is_cpu_vector(s): return apply_node def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return False if self.info != other.info: diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 7c89d81cb5..957b96037e 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -462,7 +462,7 @@ def __eq__(self, other): return ( a == x and (b.dtype == y.dtype) - and (type(b) == type(y)) + and (type(b) is type(y)) and (b.shape == y.shape) and (abs(b - y).sum() < 1e-6 * b.nnz) ) diff --git a/pytensor/tensor/random/type.py b/pytensor/tensor/random/type.py index 7f2a156271..88d5e6197f 100644 --- a/pytensor/tensor/random/type.py +++ b/pytensor/tensor/random/type.py @@ -107,7 +107,7 @@ def _eq(sa, sb): return _eq(sa, sb) def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index eb2bd31770..630ba87900 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1742,7 +1742,7 @@ def local_reduce_broadcastable(fgraph, node): 
ii += 1 new_reduced = reduced.dimshuffle(*pattern) if new_axis: - if type(node.op) == CAReduce: + if type(node.op) is CAReduce: # This case handles `CAReduce` instances # (e.g. generated by `scalar_elemwise`), and not the # scalar `Op`-specific subclasses diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index b55d226471..730ae9b07b 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -370,7 +370,7 @@ def values_eq_approx( return values_eq_approx(a, b, allow_remove_inf, allow_remove_nan, rtol, atol) def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return NotImplemented return other.dtype == self.dtype and other.shape == self.shape @@ -624,7 +624,7 @@ def c_code_cache_version(self): class DenseTypeMeta(MetaType): def __instancecheck__(self, o): - if type(o) == TensorType or isinstance(o, DenseTypeMeta): + if type(o) is TensorType or isinstance(o, DenseTypeMeta): return True return False diff --git a/pytensor/tensor/type_other.py b/pytensor/tensor/type_other.py index 5704b43859..593204b1ef 100644 --- a/pytensor/tensor/type_other.py +++ b/pytensor/tensor/type_other.py @@ -64,7 +64,7 @@ def __str__(self): return "slice" def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index e881331017..ca66689d2f 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -945,7 +945,7 @@ class TensorConstantSignature(tuple): """ def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return False try: (t0, d0), (t1, d1) = self, other @@ -1105,7 +1105,7 @@ def __deepcopy__(self, memo): class DenseVariableMeta(MetaType): def __instancecheck__(self, o): - if type(o) == TensorVariable or isinstance(o, DenseVariableMeta): + if type(o) is TensorVariable or isinstance(o, DenseVariableMeta): return True 
return False @@ -1120,7 +1120,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta): class DenseConstantMeta(MetaType): def __instancecheck__(self, o): - if type(o) == TensorConstant or isinstance(o, DenseConstantMeta): + if type(o) is TensorConstant or isinstance(o, DenseConstantMeta): return True return False diff --git a/pytensor/typed_list/type.py b/pytensor/typed_list/type.py index 863849cbbe..9b842e9f4e 100644 --- a/pytensor/typed_list/type.py +++ b/pytensor/typed_list/type.py @@ -55,7 +55,7 @@ def __eq__(self, other): Two lists are equal if they contain the same type. """ - return type(self) == type(other) and self.ttype == other.ttype + return type(self) is type(other) and self.ttype == other.ttype def __hash__(self): return hash((type(self), self.ttype)) diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index a152fcee17..da430a1587 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -42,7 +42,7 @@ def perform(self, node, inputs, outputs): class CustomOpNoProps(CustomOpNoPropsNoEq): def __eq__(self, other): - return type(self) == type(other) and self.a == other.a + return type(self) is type(other) and self.a == other.a def __hash__(self): return hash((type(self), self.a)) diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 1d2af0c7f0..9d08ec2fa0 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -31,8 +31,8 @@ def test_pickle(self): s = pickle.dumps(func) new_func = pickle.loads(s) - assert all(type(a) == type(b) for a, b in zip(func.inputs, new_func.inputs)) - assert all(type(a) == type(b) for a, b in zip(func.outputs, new_func.outputs)) + assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs)) + assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs)) assert all( type(a.op) is type(b.op) # noqa: E721 for a, b in zip(func.apply_nodes, new_func.apply_nodes) diff --git 
a/tests/graph/test_op.py b/tests/graph/test_op.py index 73f612c2f5..5ec545015b 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -25,7 +25,7 @@ def __init__(self, thingy): self.thingy = thingy def __eq__(self, other): - return type(other) == type(self) and other.thingy == self.thingy + return type(other) is type(self) and other.thingy == self.thingy def __str__(self): return str(self.thingy) diff --git a/tests/link/c/test_basic.py b/tests/link/c/test_basic.py index 124763c2bd..ffbec1a533 100644 --- a/tests/link/c/test_basic.py +++ b/tests/link/c/test_basic.py @@ -71,7 +71,7 @@ def c_code_cache_version(self): return (1,) def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index e252df2887..b065f4342a 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -348,7 +348,7 @@ def __init__(self, structured): self.structured = structured def __eq__(self, other): - return (type(self) == type(other)) and self.structured == other.structured + return (type(self) is type(other)) and self.structured == other.structured def __hash__(self): return hash(type(self)) ^ hash(self.structured) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 8a014396a5..8ccffbd3bd 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3163,7 +3163,7 @@ def test_stack(): sx, sy = dscalar(), dscalar() rval = inplace_func([sx, sy], stack([sx, sy]))(-4.0, -2.0) - assert type(rval) == np.ndarray + assert type(rval) is np.ndarray assert [-4, -2] == list(rval) diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 50fed4c61a..84f5a58022 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -819,7 +819,7 @@ def test_ok_list(self): assert np.allclose(val, good), (val, good) # Test reuse of output memory - if 
type(AdvancedSubtensor1) == AdvancedSubtensor1: + if type(AdvancedSubtensor1) is AdvancedSubtensor1: op = AdvancedSubtensor1() # When idx is a TensorConstant. if hasattr(idx, "data"):