-
Notifications
You must be signed in to change notification settings - Fork 115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace pre-commit linters (flake8, isort, black, ...) with ruff #539
Conversation
Thanks @lmmx , appreciate the contributions! |
This looks really great! You may want to make sure that you have pre-commit correctly set up locally. In short, you need to run Also, we may be able to break up this PR into several smaller PRs like:
And possibly even the bleeding edge:
Note that the rules that Ruff enables by default are very minimal, so you'll want to add custom rules as appropriate. (Not all of these rules will be appropriate for PyTensor!) Note that the pyupgrade target version is configured here. And these lines make the output more verbose. Thanks for all your work on this! |
…nfig to what we have
…flake8`, and `autoflake`
Oops sorry @maresb I didn't see your comment there, that's basically what I was just working on! I asked in the ruff Discord getting-started channel how to do this and ruff's author Charlie Marsh gave me the specific suggestion for the config and indeed confirmed it would replace black, isort, and flake8: I've added a timing comparison (ignoring mypy the speedup is 270x (81s down to 0.3s). Yep I'm a regular user of pre-commit and have been running It's OK to leave the isort config in the project in case anyone remains using it locally (e.g. I have it set up to autorun on save in my vim config via ALE and need to switch to ruff myself!) In summary here's the result of running ruff check and format:
This touched the following files (which I confirmed are not in the exclude list of any of the existing linter TOML configs):
Click to show 57 files
I've dumped the diffs of this into the following block FYI, going to review what it did and why. Agree PR splitting may be in order. Click to show uncommited diff (2420 lines)diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py
index a11aa57bd..d86b19a12 100644
--- a/pytensor/graph/basic.py
+++ b/pytensor/graph/basic.py
@@ -62,6 +62,7 @@ class Node(MetaObject):
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
"""
+
name: Optional[str]
def get_parents(self):
diff --git a/pytensor/graph/destroyhandler.py b/pytensor/graph/destroyhandler.py
index 065c5da26..557a79fa7 100644
--- a/pytensor/graph/destroyhandler.py
+++ b/pytensor/graph/destroyhandler.py
@@ -366,9 +366,7 @@ class DestroyHandler(Bookkeeper): # noqa
OrderedSet()
) # set of Apply instances with non-null destroy_map
self.view_i = {} # variable -> variable used in calculation
- self.view_o = (
- {}
- ) # variable -> set of variables that use this one as a direct input
+ self.view_o = {} # variable -> set of variables that use this one as a direct input
# clients: how many times does an apply use a given variable
self.clients = OrderedDict() # variable -> apply -> ninputs
self.stale_droot = True
diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py
index 1c984ceb0..7fa005c21 100644
--- a/pytensor/graph/fg.py
+++ b/pytensor/graph/fg.py
@@ -6,14 +6,17 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import pytensor
from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, AtomicVariable, Variable, applys_between
-from pytensor.graph.basic import as_string as graph_as_string
from pytensor.graph.basic import (
+ Apply,
+ AtomicVariable,
+ Variable,
+ applys_between,
clone_get_equiv,
graph_inputs,
io_toposort,
vars_between,
)
+from pytensor.graph.basic import as_string as graph_as_string
from pytensor.graph.features import AlreadyThere, Feature, ReplaceValidate
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
from pytensor.misc.ordered_set import OrderedSet
diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py
index 4d4cfe02f..c24cd9935 100644
--- a/pytensor/graph/rewriting/basic.py
+++ b/pytensor/graph/rewriting/basic.py
@@ -10,9 +10,8 @@ import time
import traceback
import warnings
from collections import UserList, defaultdict, deque
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
from collections.abc import Iterable as IterableType
-from collections.abc import Sequence
from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast
diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py
index f7a03962d..31be791a3 100644
--- a/pytensor/link/c/basic.py
+++ b/pytensor/link/c/basic.py
@@ -120,9 +120,7 @@ def failure_code(sub, use_goto=True):
"Unexpected error in an Op's C code. "
"No Python exception was set.");
}
- %(goto_statement)s}""" % dict(
- sub, goto_statement=goto_statement
- )
+ %(goto_statement)s}""" % dict(sub, goto_statement=goto_statement)
def failure_code_init(sub):
@@ -384,8 +382,7 @@ def get_c_init(fgraph, r, name, sub):
"""
py_%(name)s = Py_None;
{Py_XINCREF(py_%(name)s);}
- """
- % locals()
+ """ % locals()
)
return pre + r.type.c_init(name, sub)
@@ -420,13 +417,10 @@ def get_c_extract(fgraph, r, name, sub):
else:
c_extract = r.type.c_extract(name, sub, False)
- pre = (
- """
+ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);}
- """
- % locals()
- )
+ """ % locals()
return pre + c_extract
@@ -452,13 +446,10 @@ def get_c_extract_out(fgraph, r, name, sub):
else:
c_extract = r.type.c_extract_out(name, sub, check_input, check_broadcast=False)
- pre = (
- """
+ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);}
- """
- % locals()
- )
+ """ % locals()
return pre + c_extract
@@ -467,12 +458,9 @@ def get_c_cleanup(fgraph, r, name, sub):
Wrapper around c_cleanup that decrefs py_name.
"""
- post = (
- """
+ post = """
{Py_XDECREF(py_%(name)s);}
- """
- % locals()
- )
+ """ % locals()
return r.type.c_cleanup(name, sub) + post
@@ -489,9 +477,7 @@ def get_c_sync(fgraph, r, name, sub):
PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s);
{Py_XDECREF(old);}
}
- """ % dict(
- sync=r.type.c_sync(name, sub), name=name, **sub
- )
+ """ % dict(sync=r.type.c_sync(name, sub), name=name, **sub)
def apply_policy(fgraph, policy, r, name, sub):
@@ -1595,9 +1581,7 @@ class CLinker(Linker):
{struct_name} *self = ({struct_name} *)PyCapsule_GetContext(capsule);
delete self;
}}
- """.format(
- struct_name=self.struct_name
- )
+ """.format(struct_name=self.struct_name)
# We add all the support code, compile args, headers and libs we need.
for support_code in self.support_code() + self.c_support_code_apply:
diff --git a/pytensor/link/c/cmodule.py b/pytensor/link/c/cmodule.py
index 1f4815043..aee0dbacc 100644
--- a/pytensor/link/c/cmodule.py
+++ b/pytensor/link/c/cmodule.py
@@ -56,7 +56,7 @@ class StdLibDirsAndLibsType(Protocol):
def is_StdLibDirsAndLibsType(
- fn: Callable[[], Optional[tuple[list[str], ...]]]
+ fn: Callable[[], Optional[tuple[list[str], ...]]],
) -> StdLibDirsAndLibsType:
return cast(StdLibDirsAndLibsType, fn)
diff --git a/pytensor/link/c/params_type.py b/pytensor/link/c/params_type.py
index ffa57b094..71500a9ab 100644
--- a/pytensor/link/c/params_type.py
+++ b/pytensor/link/c/params_type.py
@@ -823,9 +823,7 @@ class ParamsType(CType):
def c_declare(self, name, sub, check_input=True):
return """
%(struct_name)s* %(name)s;
- """ % dict(
- struct_name=self.name, name=name
- )
+ """ % dict(struct_name=self.name, name=name)
def c_init(self, name, sub):
# NB: It seems c_init() is not called for an op param.
diff --git a/pytensor/link/c/type.py b/pytensor/link/c/type.py
index 24ced701e..2cd524daa 100644
--- a/pytensor/link/c/type.py
+++ b/pytensor/link/c/type.py
@@ -98,15 +98,12 @@ class Generic(CType, Singleton):
"""
def c_sync(self, name, sub):
- return (
- """
+ return """
assert(py_%(name)s->ob_refcnt > 1);
Py_DECREF(py_%(name)s);
py_%(name)s = %(name)s ? %(name)s : Py_None;
Py_INCREF(py_%(name)s);
- """
- % locals()
- )
+ """ % locals()
def c_code_cache_version(self):
return (1,)
@@ -195,9 +192,7 @@ class CDataType(CType[D]):
def c_declare(self, name, sub, check_input=True):
return """
%(ctype)s %(name)s;
- """ % dict(
- ctype=self.ctype, name=name
- )
+ """ % dict(ctype=self.ctype, name=name)
def c_init(self, name, sub):
return f"{name} = NULL;"
@@ -206,9 +201,7 @@ class CDataType(CType[D]):
return """
%(name)s = (%(ctype)s)PyCapsule_GetPointer(py_%(name)s, NULL);
if (%(name)s == NULL) %(fail)s
- """ % dict(
- name=name, ctype=self.ctype, fail=sub["fail"]
- )
+ """ % dict(name=name, ctype=self.ctype, fail=sub["fail"])
def c_sync(self, name, sub):
freefunc = self.freefunc
@@ -640,9 +633,7 @@ class EnumType(CType, dict):
if (PyErr_Occurred()) {
%(fail)s
}
- """ % dict(
- ctype=self.ctype, name=name, fail=sub["fail"]
- )
+ """ % dict(ctype=self.ctype, name=name, fail=sub["fail"])
def c_code_cache_version(self):
return (2, self.ctype, self.cname, tuple(self.items()))
diff --git a/pytensor/link/jax/dispatch/elemwise.py b/pytensor/link/jax/dispatch/elemwise.py
index 7750607dc..7d9532557 100644
--- a/pytensor/link/jax/dispatch/elemwise.py
+++ b/pytensor/link/jax/dispatch/elemwise.py
@@ -30,7 +30,13 @@ def jax_funcify_CAReduce(op, **kwargs):
acc_dtype = getattr(op, "acc_dtype", None)
def careduce(x):
- nonlocal axis, op_nfunc_spec, scalar_nfunc_spec, scalar_op_name, scalar_op_identity, acc_dtype
+ nonlocal \
+ axis, \
+ op_nfunc_spec, \
+ scalar_nfunc_spec, \
+ scalar_op_name, \
+ scalar_op_identity, \
+ acc_dtype
if axis is None:
axis = list(range(x.ndim))
diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 470df1eed..239eb2e00 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -39,9 +39,9 @@ from pytensor.scalar.basic import (
ScalarMinimum,
Sub,
TrueDiv,
+ scalar_maximum,
)
from pytensor.scalar.basic import add as add_as
-from pytensor.scalar.basic import scalar_maximum
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
diff --git a/pytensor/misc/ordered_set.py b/pytensor/misc/ordered_set.py
index 7b4880311..e59b43dac 100644
--- a/pytensor/misc/ordered_set.py
+++ b/pytensor/misc/ordered_set.py
@@ -68,6 +68,7 @@ class Link:
class OrderedSet(MutableSet):
"Set the remembers the order elements were added"
+
# Big-O running times for all methods are the same as for regular sets.
# The internal self.__map dictionary maps keys to links in a doubly linked list.
# The circular doubly linked list starts and ends with a sentinel element.
diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index 5b5ae0583..e3d4befef 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -499,9 +499,7 @@ class ScalarType(CType, HasDataType, HasShape):
%(fail)s
}
PyArrayScalar_ASSIGN(py_%(name)s, %(cls)s, %(name)s);
- """ % dict(
- sub, name=name, dtype=specs[1], cls=specs[2]
- )
+ """ % dict(sub, name=name, dtype=specs[1], cls=specs[2])
def c_cleanup(self, name, sub):
return ""
diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py
index 1f326b3fa..ce5285219 100644
--- a/pytensor/scalar/math.py
+++ b/pytensor/scalar/math.py
@@ -14,9 +14,10 @@ import scipy.stats
from pytensor.configdefaults import config
from pytensor.gradient import grad_not_implemented, grad_undefined
-from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
-from pytensor.scalar.basic import abs as scalar_abs
from pytensor.scalar.basic import (
+ BinaryScalarOp,
+ ScalarOp,
+ UnaryScalarOp,
as_scalar,
complex_types,
constant,
@@ -42,6 +43,7 @@ from pytensor.scalar.basic import (
upgrade_to_float64,
upgrade_to_float_no_complex,
)
+from pytensor.scalar.basic import abs as scalar_abs
from pytensor.scalar.loop import ScalarLoop
@@ -618,11 +620,8 @@ class Chi2SF(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
- return (
- """%(z)s =
- (%(dtype)s) 1 - GammaP(%(k)s/2., %(x)s/2.);"""
- % locals()
- )
+ return """%(z)s =
+ (%(dtype)s) 1 - GammaP(%(k)s/2., %(x)s/2.);""" % locals()
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
@@ -667,11 +666,8 @@ class GammaInc(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
- return (
- """%(z)s =
- (%(dtype)s) GammaP(%(k)s, %(x)s);"""
- % locals()
- )
+ return """%(z)s =
+ (%(dtype)s) GammaP(%(k)s, %(x)s);""" % locals()
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
@@ -716,11 +712,8 @@ class GammaIncC(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
- return (
- """%(z)s =
- (%(dtype)s) GammaQ(%(k)s, %(x)s);"""
- % locals()
- )
+ return """%(z)s =
+ (%(dtype)s) GammaQ(%(k)s, %(x)s);""" % locals()
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
@@ -967,11 +960,8 @@ class GammaU(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
- return (
- """%(z)s =
- (%(dtype)s) upperGamma(%(k)s, %(x)s);"""
- % locals()
- )
+ return """%(z)s =
+ (%(dtype)s) upperGamma(%(k)s, %(x)s);""" % locals()
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
@@ -1008,11 +998,8 @@ class GammaL(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
- return (
- """%(z)s =
- (%(dtype)s) lowerGamma(%(k)s, %(x)s);"""
- % locals()
- )
+ return """%(z)s =
+ (%(dtype)s) lowerGamma(%(k)s, %(x)s);""" % locals()
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py
index b84cfb144..17f6dbcfa 100644
--- a/pytensor/sparse/basic.py
+++ b/pytensor/sparse/basic.py
@@ -27,13 +27,21 @@ from pytensor.sparse.type import SparseTensorType, _is_sparse
from pytensor.sparse.utils import hash_from_sparse
from pytensor.tensor import basic as at
from pytensor.tensor.basic import Split
-from pytensor.tensor.math import _conj
-from pytensor.tensor.math import add as at_add
-from pytensor.tensor.math import arcsin, arcsinh, arctan, arctanh, ceil, deg2rad
-from pytensor.tensor.math import dot as at_dot
-from pytensor.tensor.math import exp, expm1, floor, log, log1p, maximum, minimum
-from pytensor.tensor.math import pow as at_pow
from pytensor.tensor.math import (
+ _conj,
+ arcsin,
+ arcsinh,
+ arctan,
+ arctanh,
+ ceil,
+ deg2rad,
+ exp,
+ expm1,
+ floor,
+ log,
+ log1p,
+ maximum,
+ minimum,
rad2deg,
round_half_to_even,
sigmoid,
@@ -46,11 +54,13 @@ from pytensor.tensor.math import (
tanh,
trunc,
)
+from pytensor.tensor.math import add as at_add
+from pytensor.tensor.math import dot as at_dot
+from pytensor.tensor.math import pow as at_pow
from pytensor.tensor.shape import shape, specify_broadcastable
-from pytensor.tensor.type import TensorType
+from pytensor.tensor.type import TensorType, iscalar, ivector, scalar, tensor, vector
from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
-from pytensor.tensor.type import iscalar, ivector, scalar, tensor, vector
from pytensor.tensor.variable import (
TensorConstant,
TensorVariable,
@@ -3688,9 +3698,7 @@ class StructuredDotGradCSC(COp):
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
@@ -3824,9 +3832,7 @@ class StructuredDotGradCSR(COp):
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py
index 47ea1284b..c07396273 100644
--- a/pytensor/sparse/rewriting.py
+++ b/pytensor/sparse/rewriting.py
@@ -179,9 +179,7 @@ class AddSD_ccode(_NoPythonCOp):
}
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
return code
def infer_shape(self, fgraph, node, shapes):
@@ -432,9 +430,7 @@ class StructuredDotCSC(COp):
}
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
return rval
@@ -613,9 +609,7 @@ class StructuredDotCSR(COp):
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def c_code_cache_version(self):
return (2,)
@@ -894,9 +888,7 @@ class UsmmCscDense(_NoPythonCOp):
}
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
return rval
@@ -1087,9 +1079,7 @@ class CSMGradC(_NoPythonCOp):
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def c_code_cache_version(self):
return (3,)
@@ -1241,9 +1231,7 @@ class MulSDCSC(_NoPythonCOp):
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def __str__(self):
return self.__class__.__name__
@@ -1380,9 +1368,7 @@ class MulSDCSR(_NoPythonCOp):
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def __str__(self):
return self.__class__.__name__
@@ -1567,9 +1553,7 @@ class MulSVCSR(_NoPythonCOp):
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def __str__(self):
return self.__class__.__name__
@@ -1748,9 +1732,7 @@ class StructuredAddSVCSR(_NoPythonCOp):
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def __str__(self):
return self.__class__.__name__
@@ -2042,9 +2024,7 @@ PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
}
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
return rval
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 434f8b85e..645e1b505 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -592,15 +592,12 @@ class TensorFromScalar(COp):
(z,) = outputs
fail = sub["fail"]
- return (
- """
+ return """
%(z)s = (PyArrayObject*)PyArray_FromScalar(py_%(x)s, NULL);
if(%(z)s == NULL){
%(fail)s;
}
- """
- % locals()
- )
+ """ % locals()
def c_code_cache_version(self):
return (2,)
@@ -645,12 +642,9 @@ class ScalarFromTensor(COp):
(x,) = inputs
(z,) = outputs
fail = sub["fail"]
- return (
- """
+ return """
%(z)s = ((dtype_%(x)s*)(PyArray_DATA(%(x)s)))[0];
- """
- % locals()
- )
+ """ % locals()
def c_code_cache_version(self):
return (1,)
@@ -1397,7 +1391,7 @@ def register_infer_shape(rewrite, *tags, **kwargs):
def infer_static_shape(
- shape: Union[Variable, Sequence[Union[Variable, int]]]
+ shape: Union[Variable, Sequence[Union[Variable, int]]],
) -> tuple[Sequence["TensorLike"], Sequence[Optional[int]]]:
"""Infer the static shapes implied by the potentially symbolic elements in `shape`.
@@ -1807,24 +1801,18 @@ class MakeVector(COp):
assert self.dtype == node.inputs[0].dtype
out_num = f"PyArray_TYPE({inp[0]})"
- ret = (
- """
+ ret = """
npy_intp dims[1];
dims[0] = %(out_shape)s;
if(!%(out)s || PyArray_DIMS(%(out)s)[0] != %(out_shape)s){
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject*)PyArray_EMPTY(1, dims, %(out_num)s, 0);
}
- """
- % locals()
- )
+ """ % locals()
for idx, i in enumerate(inp):
- ret += (
- """
+ ret += """
*((%(out_dtype)s *)PyArray_GETPTR1(%(out)s, %(idx)s)) = *((%(out_dtype)s *) PyArray_DATA(%(i)s));
- """
- % locals()
- )
+ """ % locals()
return ret
def infer_shape(self, fgraph, node, ishapes):
@@ -2141,8 +2129,7 @@ class Split(COp):
splits_dtype = node.inputs[2].type.dtype_specs()[1]
expected_splits_count = self.len_splits
- return (
- """
+ return """
int ndim = PyArray_NDIM(%(x)s);
int axis = (int)(*(%(axis_dtype)s*)PyArray_GETPTR1(%(axis)s, 0));
int splits_count = PyArray_DIM(%(splits)s, 0);
@@ -2239,9 +2226,7 @@ class Split(COp):
}
free(split_dims);
- """
- % locals()
- )
+ """ % locals()
class Join(COp):
@@ -2450,8 +2435,7 @@ class Join(COp):
copy_inputs_to_list = "\n".join(copy_to_list)
n = len(tens)
- code = (
- """
+ code = """
int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
PyObject* list = PyList_New(%(l)s);
%(copy_inputs_to_list)s
@@ -2483,9 +2467,7 @@ class Join(COp):
%(fail)s
}
}
- """
- % locals()
- )
+ """ % locals()
return code
def R_op(self, inputs, eval_points):
@@ -4062,8 +4044,7 @@ class AllocEmpty(COp):
for idx, sh in enumerate(shps):
str += f"||PyArray_DIMS({out})[{idx}]!=dims[{idx}]"
- str += (
- """){
+ str += """){
/* Reference received to invalid output variable.
Decrease received reference's ref count and allocate new
output variable */
@@ -4078,9 +4059,7 @@ class AllocEmpty(COp):
%(fail)s;
}
}
- """
- % locals()
- )
+ """ % locals()
return str
def infer_shape(self, fgraph, node, input_shapes):
diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py
index 78a80bd32..d8f101daa 100644
--- a/pytensor/tensor/blas.py
+++ b/pytensor/tensor/blas.py
@@ -1844,8 +1844,7 @@ class BatchedDot(COp):
)
z_shape = ", ".join(z_dims)
z_contiguous = contiguous(_z, z_ndim)
- allocate = (
- """
+ allocate = """
if (NULL == %(_z)s || !(%(z_shape_correct)s) || !(%(z_contiguous)s))
{
npy_intp dims[%(z_ndim)s] = {%(z_shape)s};
@@ -1858,9 +1857,7 @@ class BatchedDot(COp):
%(fail)s
}
}
- """
- % locals()
- )
+ """ % locals()
# code to reallocate inputs contiguously if necessary
contiguate = []
@@ -1886,8 +1883,7 @@ class BatchedDot(COp):
"1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis)
for axis in shape
)
- return (
- """{
+ return """{
npy_intp dims[3] = {%(_shape)s};
PyArray_Dims newshape = {dims, 3};
%(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER);
@@ -1895,9 +1891,7 @@ class BatchedDot(COp):
%(_fail)s
// make sure we didn't accidentally copy
assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s));
- }"""
- % locals()
- )
+ }""" % locals()
# create tensor3 views for any of x, y, z that are not tensor3, so that
# we only need to implement the tensor3-tensor3 batched dot product.
@@ -1927,8 +1921,7 @@ class BatchedDot(COp):
)
upcast = "\n".join(upcast) % locals()
- return (
- """
+ return """
int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
@@ -1992,9 +1985,7 @@ class BatchedDot(COp):
}
break;
}
- """
- % locals()
- )
+ """ % locals()
def c_code_cache_version(self):
from pytensor.tensor.blas_headers import blas_header_version
diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py
index 704970b5e..ef6d56d39 100644
--- a/pytensor/tensor/blas_c.py
+++ b/pytensor/tensor/blas_c.py
@@ -33,8 +33,7 @@ class BaseBLAS(COp):
def ger_c_code(A, a, x, y, Z, fail, params):
- return (
- """
+ return """
int elemsize ;
@@ -310,9 +309,7 @@ def ger_c_code(A, a, x, y, Z, fail, params):
}
}
- """
- % locals()
- )
+ """ % locals()
class CGer(BaseBLAS, Ger):
diff --git a/pytensor/tensor/blas_headers.py b/pytensor/tensor/blas_headers.py
index 26281c69e..b68db0efe 100644
--- a/pytensor/tensor/blas_headers.py
+++ b/pytensor/tensor/blas_headers.py
@@ -1075,8 +1075,7 @@ def blas_header_version():
def ____gemm_code(check_ab, a_init, b_init):
mod = "%"
- return (
- """
+ return """
const char * error_string = NULL;
int type_num = PyArray_DESCR(_x)->type_num;
@@ -1213,6 +1212,4 @@ def ____gemm_code(check_ab, a_init, b_init):
return -1;
/* v 1 */
- """
- % locals()
- )
+ """ % locals()
diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py
index 24e8e5be4..f7ed120d7 100644
--- a/pytensor/tensor/conv/abstract_conv.py
+++ b/pytensor/tensor/conv/abstract_conv.py
@@ -2284,8 +2284,9 @@ class BaseAbstractConv(Op):
"""
if mode not in ("valid", "full"):
raise ValueError(
- "invalid mode {}, which must be either "
- '"valid" or "full"'.format(mode)
+ "invalid mode {}, which must be either " '"valid" or "full"'.format(
+ mode
+ )
)
if isinstance(dilation, int):
dilation = (dilation,) * self.convdim
@@ -2545,8 +2546,7 @@ class AbstractConv(BaseAbstractConv):
)
if kern.shape[1 : 1 + self.convdim] != out_shape[2 : 2 + self.convdim]:
raise ValueError(
- "Kernel shape {} does not match "
- "computed output size {}".format(
+ "Kernel shape {} does not match " "computed output size {}".format(
kern.shape[1 : 1 + self.convdim],
out_shape[2 : 2 + self.convdim],
)
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 869d40faa..dcf3f3330 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -926,16 +926,13 @@ class Elemwise(OpenMPOp):
# We make the output point to the corresponding input and
# decrease the reference of whatever the output contained
# prior to this
- alloc += (
- """
+ alloc += """
if (%(oname)s) {
Py_XDECREF(%(oname)s);
}
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
- """
- % locals()
- )
+ """ % locals()
# We alias the scalar variables
defines += f"#define {oname}_i {iname}_i\n"
undefs += f"#undef {oname}_i\n"
@@ -958,16 +955,13 @@ class Elemwise(OpenMPOp):
[f"{s}_i" for s in onames],
dict(sub, fail=fail),
)
- code = (
- """
+ code = """
{
%(defines)s
%(task_code)s
%(undefs)s
}
- """
- % locals()
- )
+ """ % locals()
loop_orders = orders + [list(range(nnested))] * len(real_onames)
dtypes = idtypes + list(real_odtypes)
@@ -1011,8 +1005,7 @@ class Elemwise(OpenMPOp):
) % sub
init_array = preloops.get(0, " ")
- loop = (
- """
+ loop = """
{
%(defines)s
%(init_array)s
@@ -1020,9 +1013,7 @@ class Elemwise(OpenMPOp):
%(task_code)s
%(undefs)s
}
- """
- % locals()
- )
+ """ % locals()
else:
loop = cgen.make_loop(
loop_orders=loop_orders,
@@ -1082,37 +1073,25 @@ class Elemwise(OpenMPOp):
index = ""
for x, var in zip(inames + onames, inputs + node.outputs):
if not all(s == 1 for s in var.type.shape):
- contig += (
- """
+ contig += """
dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s);
- """
- % locals()
- )
- index += (
- """
+ """ % locals()
+ index += """
dtype_%(x)s& %(x)s_i = %(x)s_ptr[i];
- """
- % locals()
- )
+ """ % locals()
else:
- contig += (
- """
+ contig += """
dtype_%(x)s& %(x)s_i = ((dtype_%(x)s*) PyArray_DATA(%(x)s))[0];
- """
- % locals()
- )
+ """ % locals()
if self.openmp:
contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
"""
- contig += (
- """
+ contig += """
for(int i=0; i<n; i++){
%(index)s
%(task_code)s;
}
- """
- % locals()
- )
+ """ % locals()
if contig is not None:
z = list(zip(inames + onames, inputs + node.outputs))
all_broadcastable = all(s == 1 for s in var.type.shape)
@@ -1130,16 +1109,13 @@ class Elemwise(OpenMPOp):
if not all_broadcastable
]
)
- loop = (
- """
+ loop = """
if((%(cond1)s) || (%(cond2)s)){
%(contig)s
}else{
%(loop)s
}
- """
- % locals()
- )
+ """ % locals()
return decl, checks, alloc, loop, ""
def c_code(self, node, nodename, inames, onames, sub):
diff --git a/pytensor/tensor/elemwise_cgen.py b/pytensor/tensor/elemwise_cgen.py
index 85d4a93c7..d8b9700a5 100644
--- a/pytensor/tensor/elemwise_cgen.py
+++ b/pytensor/tensor/elemwise_cgen.py
@@ -203,9 +203,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
%(fail)s
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index 8f4158696..9653826d4 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -28,11 +28,9 @@ from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq
-from pytensor.tensor.math import ge, lt
+from pytensor.tensor.math import ge, lt, maximum, minimum, prod, switch
from pytensor.tensor.math import max as pt_max
-from pytensor.tensor.math import maximum, minimum, prod
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import switch
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable
@@ -70,8 +68,7 @@ class CpuContiguous(COp):
def c_code(self, node, name, inames, onames, sub):
(x,) = inames
(y,) = onames
- code = (
- """
+ code = """
if (!PyArray_CHKFLAGS(%(x)s, NPY_ARRAY_C_CONTIGUOUS)){
// check to see if output is contiguous first
if (%(y)s != NULL &&
@@ -89,9 +86,7 @@ class CpuContiguous(COp):
Py_XDECREF(%(y)s);
%(y)s = %(x)s;
}
- """
- % locals()
- )
+ """ % locals()
return code
def c_code_cache_version(self):
@@ -166,16 +161,13 @@ class SearchsortedOp(COp):
def c_init_code_struct(self, node, name, sub):
side = sub["params"]
fail = sub["fail"]
- return (
- """
+ return """
PyObject* tmp_%(name)s = PyUnicode_FromString("right");
if (tmp_%(name)s == NULL)
%(fail)s;
right_%(name)s = PyUnicode_Compare(%(side)s, tmp_%(name)s);
Py_DECREF(tmp_%(name)s);
- """
- % locals()
- )
+ """ % locals()
def c_code(self, node, name, inames, onames, sub):
sorter = None
@@ -188,8 +180,7 @@ class SearchsortedOp(COp):
(z,) = onames
fail = sub["fail"]
- return (
- """
+ return """
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SearchSorted(%(x)s, (PyObject*) %(v)s,
right_%(name)s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) %(sorter)s);
@@ -200,9 +191,7 @@ class SearchsortedOp(COp):
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) tmp;
}
- """
- % locals()
- )
+ """ % locals()
def c_code_cache_version(self):
return (2,)
@@ -361,8 +350,7 @@ class CumOp(COp):
fail = sub["fail"]
params = sub["params"]
- code = (
- """
+ code = """
int axis = %(params)s->c_axis;
if (axis == 0 && PyArray_NDIM(%(x)s) == 1)
axis = NPY_MAXDIMS;
@@ -399,9 +387,7 @@ class CumOp(COp):
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t);
}
- """
- % locals()
- )
+ """ % locals()
return code
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 0f035272a..82c088bdf 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -404,17 +404,14 @@ class Argmax(COp):
if len(self.axis) > 1:
raise NotImplementedError()
# params is only used here for now
- axis_code = (
- """
+ axis_code = """
axis = %(params)s->c_axis;
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument");
%(fail)s
}
- """
- % locals()
- )
+ """ % locals()
ret = """
int axis;
diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py
index 92f1ad3f2..5bf33356d 100644
--- a/pytensor/tensor/random/basic.py
+++ b/pytensor/tensor/random/basic.py
@@ -87,6 +87,7 @@ class UniformRV(RandomVariable):
\end{split}
"""
+
name = "uniform"
ndim_supp = 0
ndims_params = [0, 0]
@@ -141,6 +142,7 @@ class TriangularRV(RandomVariable):
\end{split}
"""
+
name = "triangular"
ndim_supp = 0
ndims_params = [0, 0, 0]
@@ -196,6 +198,7 @@ class BetaRV(RandomVariable):
B(\alpha, \beta) = \int_0^1 t^{\alpha-1} (1-t)^{\beta-1} \mathrm{d}t
"""
+
name = "beta"
ndim_supp = 0
ndims_params = [0, 0]
@@ -242,6 +245,7 @@ class NormalRV(RandomVariable):
for :math:`\sigma > 0`.
"""
+
name = "normal"
ndim_supp = 0
ndims_params = [0, 0]
@@ -322,6 +326,7 @@ class HalfNormalRV(ScipyRandomVariable):
for :math:`x \geq 0` and :math:`\sigma > 0`.
"""
+
name = "halfnormal"
ndim_supp = 0
ndims_params = [0, 0]
@@ -387,6 +392,7 @@ class LogNormalRV(RandomVariable):
for :math:`x > 0` and :math:`\sigma > 0`.
"""
+
name = "lognormal"
ndim_supp = 0
ndims_params = [0, 0]
@@ -438,6 +444,7 @@ class GammaRV(RandomVariable):
\Gamma(x) = \int_0^{\infty} t^{x-1} e^{-t} \mathrm{d}t
"""
+
name = "gamma"
ndim_supp = 0
ndims_params = [0, 0]
@@ -533,6 +540,7 @@ class ParetoRV(ScipyRandomVariable):
and is defined for :math:`x \geq x_m`.
"""
+
name = "pareto"
ndim_supp = 0
ndims_params = [0, 0]
@@ -583,6 +591,7 @@ class GumbelRV(ScipyRandomVariable):
for :math:`\beta > 0`.
"""
+
name = "gumbel"
ndim_supp = 0
ndims_params = [0, 0]
@@ -644,6 +653,7 @@ class ExponentialRV(RandomVariable):
for :math:`x \geq 0` and :math:`\beta > 0`.
"""
+
name = "exponential"
ndim_supp = 0
ndims_params = [0]
@@ -687,6 +697,7 @@ class WeibullRV(RandomVariable):
for :math:`x \geq 0` and :math:`k > 0`.
"""
+
name = "weibull"
ndim_supp = 0
ndims_params = [0]
@@ -731,6 +742,7 @@ class LogisticRV(RandomVariable):
for :math:`s > 0`.
"""
+
name = "logistic"
ndim_supp = 0
ndims_params = [0, 0]
@@ -779,6 +791,7 @@ class VonMisesRV(RandomVariable):
function of order 0.
"""
+
name = "vonmises"
ndim_supp = 0
ndims_params = [0, 0]
@@ -846,6 +859,7 @@ class MvNormalRV(RandomVariable):
where :math:`\Sigma` is a positive semi-definite matrix.
"""
+
name = "multivariate_normal"
ndim_supp = 1
ndims_params = [1, 2]
@@ -932,6 +946,7 @@ class DirichletRV(RandomVariable):
:math:`\sum_{i=1}^k x_i = 1`.
"""
+
name = "dirichlet"
ndim_supp = 1
ndims_params = [1]
@@ -1005,6 +1020,7 @@ class PoissonRV(RandomVariable):
for :math:`\lambda > 0`.
"""
+
name = "poisson"
ndim_supp = 0
ndims_params = [0]
@@ -1050,6 +1066,7 @@ class GeometricRV(RandomVariable):
for :math:`0 \geq p \geq 1`.
"""
+
name = "geometric"
ndim_supp = 0
ndims_params = [0]
@@ -1092,6 +1109,7 @@ class HyperGeometricRV(RandomVariable):
f(k; n, N, K) = \frac{{K \choose k} {N-K \choose n-k}}{{N \choose n}}
"""
+
name = "hypergeometric"
ndim_supp = 0
ndims_params = [0, 0, 0]
@@ -1140,6 +1158,7 @@ class CauchyRV(ScipyRandomVariable):
where :math:`\gamma > 0`.
"""
+
name = "cauchy"
ndim_supp = 0
ndims_params = [0, 0]
@@ -1190,6 +1209,7 @@ class HalfCauchyRV(ScipyRandomVariable):
for :math:`x \geq 0` where :math:`\gamma > 0`.
"""
+
name = "halfcauchy"
ndim_supp = 0
ndims_params = [0, 0]
@@ -1244,6 +1264,7 @@ class InvGammaRV(ScipyRandomVariable):
\Gamma(x) = \int_0^{\infty} t^{x-1} e^{-t} \mathrm{d}t
"""
+
name = "invgamma"
ndim_supp = 0
ndims_params = [0, 0]
@@ -1294,6 +1315,7 @@ class WaldRV(RandomVariable):
for :math:`x > 0`, where :math:`\mu > 0` and :math:`\lambda > 0`.
"""
+
name = "wald"
ndim_supp = 0
ndims_params = [0, 0]
@@ -1341,6 +1363,7 @@ class TruncExponentialRV(ScipyRandomVariable):
for :math:`0 \leq x \leq b` and :math:`\beta > 0`.
"""
+
name = "truncexpon"
ndim_supp = 0
ndims_params = [0, 0, 0]
@@ -1396,6 +1419,7 @@ class StudentTRV(ScipyRandomVariable):
for :math:`\nu > 0`, :math:`\sigma > 0`.
"""
+
name = "t"
ndim_supp = 0
ndims_params = [0, 0, 0]
@@ -1455,6 +1479,7 @@ class BernoulliRV(ScipyRandomVariable):
where :math:`0 \leq p \leq 1`.
"""
+
name = "bernoulli"
ndim_supp = 0
ndims_params = [0]
@@ -1502,6 +1527,7 @@ class LaplaceRV(RandomVariable):
with :math:`\lambda > 0`.
"""
+
name = "laplace"
ndim_supp = 0
ndims_params = [0, 0]
@@ -1548,6 +1574,7 @@ class BinomialRV(RandomVariable):
f(k; p, n) = {n \choose k} p^k (1-p)^{n-k}
"""
+
name = "binomial"
ndim_supp = 0
ndims_params = [0, 0]
@@ -1592,6 +1619,7 @@ class NegBinomialRV(ScipyRandomVariable):
f(k; p, n) = {k+n-1 \choose n-1} p^n (1-p)^{k}
"""
+
name = "nbinom"
ndim_supp = 0
ndims_params = [0, 0]
@@ -1647,6 +1675,7 @@ class BetaBinomialRV(ScipyRandomVariable):
\operatorname{B}(a, b) = \int_0^1 t^{a-1} (1-t)^{b-1} \mathrm{d}t
"""
+
name = "beta_binomial"
ndim_supp = 0
ndims_params = [0, 0, 0]
@@ -1698,6 +1727,7 @@ class GenGammaRV(ScipyRandomVariable):
for :math:`x > 0`, where :math:`\alpha, \lambda, p > 0`.
"""
+
name = "gengamma"
ndim_supp = 0
ndims_params = [0, 0, 0]
@@ -1760,6 +1790,7 @@ class MultinomialRV(RandomVariable):
dimension in the *second* parameter (i.e. the probabilities vector).
"""
+
name = "multinomial"
ndim_supp = 1
ndims_params = [0, 1]
@@ -1943,6 +1974,7 @@ class IntegersRV(RandomVariable):
Only available for `RandomGeneratorType`. Use `randint` with `RandomStateType`\s.
"""
+
name = "integers"
ndim_supp = 0
ndims_params = [0, 0]
diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index 893a93086..ba88d72f6 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -122,7 +122,7 @@ def broadcast_params(params, ndims_params):
def normalize_size_param(
- size: Optional[Union[int, np.ndarray, Variable, Sequence]]
+ size: Optional[Union[int, np.ndarray, Variable, Sequence]],
) -> Variable:
"""Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
if size is None:
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 021660d8e..9bd7efc40 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -67,9 +67,8 @@ from pytensor.tensor.basic import (
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
-from pytensor.tensor.math import Sum, add
+from pytensor.tensor.math import Sum, add, eq
from pytensor.tensor.math import all as at_all
-from pytensor.tensor.math import eq
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.sort import TopKOp
from pytensor.tensor.type import DenseTensorType, TensorType
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index 6f2fcd230..94778d789 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -690,7 +690,8 @@ class FusionOptimizer(GraphRewriter):
UNFUSEABLE_MAPPING = DefaultDict[Variable, set[ApplyOrOutput]]
def initialize_fuseable_mappings(
- *, fg: FunctionGraph
+ *,
+ fg: FunctionGraph,
) -> tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]:
@lru_cache(maxsize=None)
def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
@@ -754,7 +755,7 @@ class FusionOptimizer(GraphRewriter):
VT = TypeVar("VT", list, set)
def shallow_clone_defaultdict(
- d: DefaultDict[KT, VT]
+ d: DefaultDict[KT, VT],
) -> DefaultDict[KT, VT]:
new_dict: DefaultDict[KT, VT] = defaultdict(d.default_factory)
new_dict.update({k: v.copy() for k, v in d.items()})
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index a814ffdf6..409cdb18a 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -49,9 +49,6 @@ from pytensor.tensor.math import (
ProdWithoutZeros,
Sum,
_conj,
-)
-from pytensor.tensor.math import abs as at_abs
-from pytensor.tensor.math import (
add,
digamma,
dot,
@@ -68,11 +65,10 @@ from pytensor.tensor.math import (
log1mexp,
log1p,
makeKeepDims,
-)
-from pytensor.tensor.math import max as at_max
-from pytensor.tensor.math import maximum, mul, neg, polygamma
-from pytensor.tensor.math import pow as at_pow
-from pytensor.tensor.math import (
+ maximum,
+ mul,
+ neg,
+ polygamma,
prod,
reciprocal,
sigmoid,
@@ -81,9 +77,13 @@ from pytensor.tensor.math import (
sqr,
sqrt,
sub,
+ tri_gamma,
+ true_div,
)
+from pytensor.tensor.math import abs as at_abs
+from pytensor.tensor.math import max as at_max
+from pytensor.tensor.math import pow as at_pow
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import tri_gamma, true_div
from pytensor.tensor.rewriting.basic import (
alloc_like,
broadcasted_by,
diff --git a/pytensor/tensor/rewriting/special.py b/pytensor/tensor/rewriting/special.py
index c893439e4..8f5028e3d 100644
--- a/pytensor/tensor/rewriting/special.py
+++ b/pytensor/tensor/rewriting/special.py
@@ -1,8 +1,7 @@
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import Sum, exp, log
+from pytensor.tensor.math import Sum, exp, log, true_div
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import true_div
from pytensor.tensor.rewriting.basic import register_stabilize
from pytensor.tensor.rewriting.math import local_mul_canonizer
from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 4e80d3bb3..93cc62abe 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -31,9 +31,9 @@ from pytensor.tensor.basic import (
)
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import Dot, add
-from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import (
+ Dot,
+ add,
and_,
ceil_intdiv,
dot,
@@ -46,6 +46,7 @@ from pytensor.tensor.math import (
minimum,
or_,
)
+from pytensor.tensor.math import all as at_all
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py
index 1d8efa02c..a84ff0be5 100644
--- a/pytensor/tensor/shape.py
+++ b/pytensor/tensor/shape.py
@@ -15,9 +15,8 @@ from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar import int32
-from pytensor.tensor import _get_vector_length, as_tensor_variable
+from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor import basic as at
-from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from pytensor.tensor.type_other import NoneConst
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index c05e965bf..9313c20d3 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -127,7 +127,7 @@ def indices_from_subtensor(
def as_index_constant(
- a: Optional[Union[slice, int, np.integer, Variable]]
+ a: Optional[Union[slice, int, np.integer, Variable]],
) -> Optional[Union[Variable, slice]]:
r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments.
@@ -150,7 +150,7 @@ def as_index_constant(
def as_index_literal(
- idx: Optional[Union[Variable, slice]]
+ idx: Optional[Union[Variable, slice]],
) -> Optional[Union[int, slice]]:
"""Convert a symbolic index element to its Python equivalent.
@@ -1028,8 +1028,7 @@ class Subtensor(COp):
"""
- rval += (
- """
+ rval += """
// One more argument of the view
npy_intp xview_offset = 0;
@@ -1156,9 +1155,7 @@ class Subtensor(COp):
inner_ii += 1;
outer_ii += 1;
}
- """
- % locals()
- )
+ """ % locals()
# print rval
return rval
@@ -1178,23 +1175,19 @@ class Subtensor(COp):
decl = "PyArrayObject * xview = NULL;"
- checkNDim = (
- """
+ checkNDim = """
if (PyArray_NDIM(%(x)s) != %(ndim)s){
PyErr_SetString(PyExc_ValueError,
"Expected %(ndim)s dimensions input"
);
%(fail)s
}
- """
- % locals()
- )
+ """ % locals()
get_xview = self.helper_c_code(
node, name, inputs, outputs, sub, self.idx_list, view_ndim
)
- build_view = (
- """
+ build_view = """
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR(%(x)s));
@@ -1212,9 +1205,7 @@ class Subtensor(COp):
{
%(fail)s;
}
- """
- % locals()
- )
+ """ % locals()
finish_view = f"""
Py_XDECREF({z});
@@ -1646,8 +1637,7 @@ class IncSubtensor(COp):
copy_of_x = self.copy_of_x(x)
- copy_input_if_necessary = (
- """
+ copy_input_if_necessary = """
if (%(inplace)s)
{
if (%(x)s != %(z)s)
@@ -1666,9 +1656,7 @@ class IncSubtensor(COp):
%(fail)s
}
}
- """
- % locals()
- )
+ """ % locals()
# get info needed to make zview: a view of %(z)s
helper_args = self.get_helper_c_code_args()
@@ -1687,8 +1675,7 @@ class IncSubtensor(COp):
# Make a view on the output, as we will write into it.
alloc_zview = self.make_view_array(z, view_ndim)
- build_view = (
- """
+ build_view = """
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
%(alloc_zview)s;
@@ -1696,16 +1683,13 @@ class IncSubtensor(COp):
{
%(fail)s;
}
- """
- % locals()
- )
+ """ % locals()
copy_into = self.copy_into("zview", y)
add_to_zview = self.add_to_zview(name, y, fail)
- make_modification = (
- """
+ make_modification = """
if (%(op_is_set)s)
{
if (%(copy_into)s) // does broadcasting
@@ -1718,9 +1702,7 @@ class IncSubtensor(COp):
{
%(add_to_zview)s
}
- """
- % locals()
- )
+ """ % locals()
return (
self.decl_view()
+ copy_input_if_necessary
@@ -1790,8 +1772,7 @@ class IncSubtensor(COp):
"""
- return (
- """Py_INCREF(PyArray_DESCR(%(x)s));
+ return """Py_INCREF(PyArray_DESCR(%(x)s));
zview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR(%(x)s),
@@ -1801,9 +1782,7 @@ class IncSubtensor(COp):
PyArray_BYTES(%(x)s) + xview_offset, //PyArray_DATA(%(x)s),
PyArray_FLAGS(%(x)s),
NULL);
- """
- % locals()
- )
+ """ % locals()
def get_helper_c_code_args(self):
"""
@@ -1836,8 +1815,7 @@ class IncSubtensor(COp):
"""
- return (
- """
+ return """
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
(PyObject*)zview, py_%(x)s);
if (add_rval)
@@ -1850,9 +1828,7 @@ class IncSubtensor(COp):
{
Py_DECREF(zview);
%(fail)s;
- }"""
- % locals()
- )
+ }""" % locals()
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
@@ -2069,8 +2045,7 @@ class AdvancedSubtensor1(COp):
a_name, i_name = input_names[0], input_names[1]
output_name = output_names[0]
fail = sub["fail"]
- return (
- """
+ return """
PyArrayObject *indices;
int i_type = PyArray_TYPE(%(i_name)s);
if (i_type != NPY_INTP) {
@@ -2145,9 +2120,7 @@ class AdvancedSubtensor1(COp):
%(a_name)s, (PyObject*)indices, 0, %(output_name)s, NPY_RAISE);
Py_DECREF(indices);
if (%(output_name)s == NULL) %(fail)s;
- """
- % locals()
- )
+ """ % locals()
def c_code_cache_version(self):
return (0, 1, 2)
diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py
index 7392d3f42..1c1146aed 100644
--- a/pytensor/tensor/type.py
+++ b/pytensor/tensor/type.py
@@ -477,25 +477,19 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
if check_input:
check = """
typedef %(dtype)s dtype_%(name)s;
- """ % dict(
- sub, name=name, dtype=self.dtype_specs()[1]
- )
+ """ % dict(sub, name=name, dtype=self.dtype_specs()[1])
else:
check = ""
declaration = """
PyArrayObject* %(name)s;
- """ % dict(
- sub, name=name, dtype=self.dtype_specs()[1]
- )
+ """ % dict(sub, name=name, dtype=self.dtype_specs()[1])
return declaration + check
def c_init(self, name, sub):
return """
%(name)s = NULL;
- """ % dict(
- sub, name=name, type_num=self.dtype_specs()[2]
- )
+ """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_extract(self, name, sub, check_input=True, **kwargs):
if check_input:
@@ -547,9 +541,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
%(type_num)s, PyArray_TYPE((PyArrayObject*) py_%(name)s));
%(fail)s
}
- """ % dict(
- sub, name=name, type_num=self.dtype_specs()[2]
- )
+ """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
else:
check = ""
return (
@@ -562,20 +554,16 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
)
def c_cleanup(self, name, sub):
- return (
- """
+ return """
if (%(name)s) {
Py_XDECREF(%(name)s);
}
- """
- % locals()
- )
+ """ % locals()
def c_sync(self, name, sub):
fail = sub["fail"]
type_num = self.dtype_specs()[2]
- return (
- """
+ return """
{Py_XDECREF(py_%(name)s);}
if (!%(name)s) {
Py_INCREF(Py_None);
@@ -610,9 +598,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
);
%(fail)s
}
- """
- % locals()
- )
+ """ % locals()
def c_headers(self, **kwargs):
return aes.get_scalar_type(self.dtype).c_headers(**kwargs)
diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py
index ed0066338..b27a259e6 100644
--- a/pytensor/tensor/utils.py
+++ b/pytensor/tensor/utils.py
@@ -138,7 +138,7 @@ def import_func_from_string(func_string: str): # -> Optional[Callable]:
def broadcast_static_dim_lengths(
- dim_lengths: Sequence[Union[int, None]]
+ dim_lengths: Sequence[Union[int, None]],
) -> Union[int, None]:
"""Apply static broadcast given static dim length of inputs (obtained from var.type.shape).
diff --git a/pytensor/tensor/var.py b/pytensor/tensor/var.py
index ab69af528..19880ff05 100644
--- a/pytensor/tensor/var.py
+++ b/pytensor/tensor/var.py
@@ -1,8 +1,8 @@
import warnings
-
from pytensor.tensor.variable import * # noqa
+
warnings.warn(
"The module 'pytensor.tensor.var' has been deprecated. "
"Use 'pytensor.tensor.variable' instead.",
diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py
index 54e41124b..5dc983b54 100644
--- a/pytensor/typed_list/basic.py
+++ b/pytensor/typed_list/basic.py
@@ -102,16 +102,13 @@ class GetItem(COp):
x_name, index = inp[0], inp[1]
output_name = out[0]
fail = sub["fail"]
- return (
- """
+ return """
%(output_name)s = (typeof %(output_name)s) PyList_GetItem( (PyObject*) %(x_name)s, *((npy_int64 *) PyArray_DATA(%(index)s)));
if(%(output_name)s == NULL){
%(fail)s
}
Py_INCREF(%(output_name)s);
- """
- % locals()
- )
+ """ % locals()
def c_code_cache_version(self):
return (1,)
@@ -172,12 +169,9 @@ class Append(COp):
output_name = out[0]
fail = sub["fail"]
if not self.inplace:
- init = (
- """
+ init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
- """
- % locals()
- )
+ """ % locals()
else:
init = f"""
{output_name} = {x_name};
@@ -257,12 +251,9 @@ class Extend(COp):
output_name = out[0]
fail = sub["fail"]
if not self.inplace:
- init = (
- """
+ init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
- """
- % locals()
- )
+ """ % locals()
else:
init = f"""
{output_name} = {x_name};
@@ -349,12 +340,9 @@ class Insert(COp):
output_name = out[0]
fail = sub["fail"]
if not self.inplace:
- init = (
- """
+ init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
- """
- % locals()
- )
+ """ % locals()
else:
init = f"""
{output_name} = {x_name};
@@ -481,12 +469,9 @@ class Reverse(COp):
output_name = out[0]
fail = sub["fail"]
if not self.inplace:
- init = (
- """
+ init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
- """
- % locals()
- )
+ """ % locals()
else:
init = f"""
{output_name} = {x_name};
@@ -616,15 +601,12 @@ class Length(COp):
x_name = inp[0]
output_name = out[0]
fail = sub["fail"]
- return (
- """
+ return """
if(!%(output_name)s)
%(output_name)s=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA(%(output_name)s))[0]=PyList_Size((PyObject*)%(x_name)s);
Py_INCREF(%(output_name)s);
- """
- % locals()
- )
+ """ % locals()
def c_code_cache_version(self):
return (1,)
diff --git a/pytensor/typed_list/type.py b/pytensor/typed_list/type.py
index 7b4325265..136e1b6b2 100644
--- a/pytensor/typed_list/type.py
+++ b/pytensor/typed_list/type.py
@@ -114,9 +114,7 @@ class TypedListType(CType):
if (!PyList_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a list");
%(fail)s
- }""" % dict(
- name=name, fail=sub["fail"]
- )
+ }""" % dict(name=name, fail=sub["fail"])
else:
pre = ""
return (
@@ -132,9 +130,7 @@ class TypedListType(CType):
Py_XDECREF(py_%(name)s);
py_%(name)s = (PyObject*)(%(name)s);
Py_INCREF(py_%(name)s);
- """ % dict(
- name=name
- )
+ """ % dict(name=name)
def c_cleanup(self, name, sub):
return ""
diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py
index 59472b9cd..8611d0aa0 100644
--- a/tests/compile/function/test_types.py
+++ b/tests/compile/function/test_types.py
@@ -16,9 +16,8 @@ from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter
from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker
-from pytensor.tensor.math import dot
+from pytensor.tensor.math import dot, tanh
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import tanh
from pytensor.tensor.type import (
dmatrix,
dscalar,
diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py
index 8fdeb1847..337e1f783 100644
--- a/tests/compile/test_builders.py
+++ b/tests/compile/test_builders.py
@@ -16,9 +16,8 @@ from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
from pytensor.tensor.basic import as_tensor
-from pytensor.tensor.math import dot, exp
+from pytensor.tensor.math import dot, exp, sigmoid
from pytensor.tensor.math import round as at_round
-from pytensor.tensor.math import sigmoid
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.random.utils import RandomStream
from pytensor.tensor.rewriting.shape import ShapeOptimizer
diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py
index 0719f093b..4e29564dc 100644
--- a/tests/compile/test_debugmode.py
+++ b/tests/compile/test_debugmode.py
@@ -96,9 +96,7 @@ class BROKEN_ON_PURPOSE_Add(COp):
+ ((double*)PyArray_GETPTR1(%(b)s, m))[0] ;
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
# inconsistent is a invalid op, whose perform and c_code do not match
@@ -692,9 +690,7 @@ class BrokenCImplementationAdd(COp):
}
}
}
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
class VecAsRowAndCol(Op):
diff --git a/tests/link/c/test_basic.py b/tests/link/c/test_basic.py
index b8ea0d340..34cdb86ea 100644
--- a/tests/link/c/test_basic.py
+++ b/tests/link/c/test_basic.py
@@ -47,9 +47,7 @@ class TDouble(CType):
%(name)s = PyFloat_AsDouble(py_%(name)s);
%(name)s_bad_thing = NULL;
//printf("Extracting %(name)s\\n");
- """ % dict(
- locals(), **sub
- )
+ """ % dict(locals(), **sub)
def c_sync(self, name, sub):
return f"""
diff --git a/tests/link/c/test_op.py b/tests/link/c/test_op.py
index c9f40bbb7..cf4176626 100644
--- a/tests/link/c/test_op.py
+++ b/tests/link/c/test_op.py
@@ -82,9 +82,7 @@ class StructOp(COp):
return """
%(out)s = counter%(name)s;
counter%(name)s++;
-""" % dict(
- out=outputs_names[0], name=name
- )
+""" % dict(out=outputs_names[0], name=name)
def c_code_cache_version(self):
return (1,)
diff --git a/tests/link/c/test_type.py b/tests/link/c/test_type.py
index d12570351..9a87b1c54 100644
--- a/tests/link/c/test_type.py
+++ b/tests/link/c/test_type.py
@@ -29,9 +29,7 @@ Py_XDECREF((PyObject *)p);
Py_XDECREF(%(out)s);
%(out)s = (void *)%(inp)s;
Py_INCREF(%(inp)s);
-""" % dict(
- out=outs[0], inp=inps[0]
- )
+""" % dict(out=outs[0], inp=inps[0])
def c_code_cache_version(self):
return (0,)
@@ -58,9 +56,7 @@ Py_XDECREF((PyObject *)p);
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)%(inp)s;
Py_INCREF(%(out)s);
-""" % dict(
- out=outs[0], inp=inps[0]
- )
+""" % dict(out=outs[0], inp=inps[0])
def c_code_cache_version(self):
return (0,)
diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py
index 98bfbb610..1e240eab1 100644
--- a/tests/link/jax/test_nlinalg.py
+++ b/tests/link/jax/test_nlinalg.py
@@ -11,9 +11,8 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as at_blas
from pytensor.tensor import nlinalg as at_nlinalg
-from pytensor.tensor.math import MaxAndArgmax
+from pytensor.tensor.math import MaxAndArgmax, maximum
from pytensor.tensor.math import max as at_max
-from pytensor.tensor.math import maximum
from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
from tests.link.jax.test_basic import compare_jax_and_py
diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py
index 14bb06b77..0b4c34eeb 100644
--- a/tests/link/numba/test_tensor_basic.py
+++ b/tests/link/numba/test_tensor_basic.py
@@ -20,7 +20,7 @@ from tests.tensor.test_basic import TestAlloc
pytest.importorskip("numba")
-from pytensor.link.numba.dispatch import numba_funcify # noqa: E402
+from pytensor.link.numba.dispatch import numba_funcify # noqa: E402
rng = np.random.default_rng(42849)
diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py
index 698b1ff0b..6e3b925d9 100644
--- a/tests/scan/test_basic.py
+++ b/tests/scan/test_basic.py
@@ -39,9 +39,8 @@ from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
from pytensor.scan.utils import until
from pytensor.tensor.math import all as at_all
-from pytensor.tensor.math import dot, exp, mean, sigmoid
+from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import tanh
from pytensor.tensor.random import normal
from pytensor.tensor.random.utils import RandomStream
from pytensor.tensor.shape import Shape_i, reshape, specify_shape
diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py
index 9dc6e698c..111ca37ac 100644
--- a/tests/scan/test_rewriting.py
+++ b/tests/scan/test_rewriting.py
@@ -18,9 +18,8 @@ from pytensor.scan.utils import until
from pytensor.tensor import stack
from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.math import Dot, dot, sigmoid
+from pytensor.tensor.math import Dot, dot, sigmoid, tanh
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import tanh
from pytensor.tensor.shape import reshape, shape, specify_shape
from pytensor.tensor.type import (
dmatrix,
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index 736019066..6b3967b08 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -50,11 +50,13 @@ from pytensor.tensor.math import (
minimum,
mul,
neq,
+ softplus,
+ sqrt,
+ sub,
+ true_div,
)
from pytensor.tensor.math import pow as at_pow
-from pytensor.tensor.math import softplus, sqrt, sub
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import true_div
from pytensor.tensor.rewriting.basic import (
assert_op,
local_alloc_sink_dimshuffle,
diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py
index ac4d293f1..a82bfaeed 100644
--- a/tests/tensor/rewriting/test_elemwise.py
+++ b/tests/tensor/rewriting/test_elemwise.py
@@ -4,9 +4,8 @@ import numpy as np
import pytest
import pytensor
-from pytensor import In
+from pytensor import In, shared
from pytensor import scalar as aes
-from pytensor import shared
from pytensor import tensor as at
from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode
@@ -23,9 +22,8 @@ from pytensor.scalar.basic import Composite, float64
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import abs as at_abs
-from pytensor.tensor.math import add
-from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import (
+ add,
bitwise_and,
bitwise_or,
cos,
@@ -44,13 +42,20 @@ from pytensor.tensor.math import (
mul,
neg,
neq,
+ reciprocal,
+ sin,
+ sinh,
+ sqr,
+ sqrt,
+ tan,
+ tanh,
+ true_div,
+ xor,
)
+from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import pow as at_pow
-from pytensor.tensor.math import reciprocal
from pytensor.tensor.math import round as at_round
-from pytensor.tensor.math import sin, sinh, sqr, sqrt
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import tan, tanh, true_div, xor
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
from pytensor.tensor.shape import reshape
@@ -1386,12 +1391,9 @@ class TimesN(aes.basic.UnaryScalarOp):
def c_support_code_apply(self, node, nodename):
n = str(self.n)
- return (
- """
+ return """
float %(nodename)s_timesn(float x) { return x * %(n)s; }
- """
- % locals()
- )
+ """ % locals()
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 184879ba6..0490271d6 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -35,12 +35,13 @@ from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
-from pytensor.tensor.math import abs as at_abs
-from pytensor.tensor.math import add
-from pytensor.tensor.math import all as at_all
-from pytensor.tensor.math import any as at_any
from pytensor.tensor.math import (
+ Dot,
+ MaxAndArgmax,
+ Prod,
+ Sum,
+ _conj,
+ add,
arccosh,
arcsinh,
arctanh,
@@ -65,13 +66,12 @@ from pytensor.tensor.math import (
log1mexp,
log1p,
lt,
-)
-from pytensor.tensor.math import max as at_max
-from pytensor.tensor.math import maximum
-from pytensor.tensor.math import min as at_min
-from pytensor.tensor.math import minimum, mul, neg, neq, polygamma
-from pytensor.tensor.math import pow as pt_pow
-from pytensor.tensor.math import (
+ maximum,
+ minimum,
+ mul,
+ neg,
+ neq,
+ polygamma,
prod,
rad2deg,
reciprocal,
@@ -82,9 +82,17 @@ from pytensor.tensor.math import (
sqr,
sqrt,
sub,
+ tanh,
+ true_div,
+ xor,
)
+from pytensor.tensor.math import abs as at_abs
+from pytensor.tensor.math import all as at_all
+from pytensor.tensor.math import any as at_any
+from pytensor.tensor.math import max as at_max
+from pytensor.tensor.math import min as at_min
+from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import tanh, true_div, xor
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.math import (
compute_mul,
@@ -590,7 +598,7 @@ class TestAlgebraicCanonizer:
# must broadcast as there is a dimshuffle in the computation
((dx * dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
# topo: [Elemwise{second,no_inplace}(x, <TensorType(float64, row)>)]
- ((fx * fv) / fx, [fx, fv], [fxv, fvv], 1, "float32")
+ ((fx * fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
# topo: [Elemwise{second,no_inplace}(x, <TensorType(float32, row)>)]
]
):
diff --git a/tests/tensor/rewriting/test_uncanonicalize.py b/tests/tensor/rewriting/test_uncanonicalize.py
index 865da8313..7cf950436 100644
--- a/tests/tensor/rewriting/test_uncanonicalize.py
+++ b/tests/tensor/rewriting/test_uncanonicalize.py
@@ -9,9 +9,8 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import out2in
from pytensor.link.basic import PerformLinker
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import MaxAndArgmax
+from pytensor.tensor.math import MaxAndArgmax, max_and_argmax
from pytensor.tensor.math import max as at_max
-from pytensor.tensor.math import max_and_argmax
from pytensor.tensor.math import min as at_min
from pytensor.tensor.rewriting.uncanonicalize import (
local_alloc_dimshuffle,
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index 2c2b82d1b..bb24cd88f 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -3267,9 +3267,7 @@ def test_autocast_numpy():
def ok(z):
assert constant(z).dtype == np.asarray(z).dtype
- for x in (
- [2**i for i in range(63)] + [0, 0, 1, 2**63 - 1] + [0.0, 1.0, 1.1, 1.5]
- ):
+ for x in [2**i for i in range(63)] + [0, 0, 1, 2**63 - 1] + [0.0, 1.0, 1.1, 1.5]:
n_x = np.asarray(x)
# Make sure the data type is the same as the one found by numpy.
ok(x)
@@ -3298,9 +3296,7 @@ def test_autocast_numpy_floatX():
# into int64, as that is the maximal integer type that PyTensor
# supports, and that is the maximal type in Python indexing.
for x in (
- [2**i - 1 for i in range(64)]
- + [0, 0, 1, 2**63 - 1]
- + [0.0, 1.0, 1.1, 1.5]
+ [2**i - 1 for i in range(64)] + [0, 0, 1, 2**63 - 1] + [0.0, 1.0, 1.1, 1.5]
):
with config.change_flags(floatX=floatX):
ok(x, floatX)
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index b4a8cc9ae..8a9dddfb7 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -19,10 +19,9 @@ from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import second
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import Any, Sum
+from pytensor.tensor.math import Any, Sum, exp
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import any as pt_any
-from pytensor.tensor.math import exp
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.type import (
TensorType,
diff --git a/tests/tensor/test_keepdims.py b/tests/tensor/test_keepdims.py
index b3c7d1bb7..675ca2cd1 100644
--- a/tests/tensor/test_keepdims.py
+++ b/tests/tensor/test_keepdims.py
@@ -7,13 +7,10 @@ from pytensor.compile.mode import Mode
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import any as at_any
-from pytensor.tensor.math import argmax, argmin
+from pytensor.tensor.math import argmax, argmin, max_and_argmax, mean, prod, std, var
from pytensor.tensor.math import max as at_max
-from pytensor.tensor.math import max_and_argmax, mean
from pytensor.tensor.math import min as at_min
-from pytensor.tensor.math import prod, std
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import var
from pytensor.tensor.type import dtensor3
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index af653c2f5..d2d77deef 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -110,9 +110,14 @@ from pytensor.tensor.math import (
sqr,
sqrt,
sub,
+ tan,
+ tanh,
+ tensordot,
+ true_div,
+ trunc,
+ var,
)
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import tan, tanh, tensordot, true_div, trunc, var
from pytensor.tensor.type import (
TensorType,
complex_dtypes,
@@ -1414,7 +1419,7 @@ TestClip = makeTester(
np.array(2, dtype="uint16"),
np.array(4, dtype="uint16"),
),
- )
+ ),
# I can't think of any way to make this fail at runtime
)
diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index 739289c3c..428720918 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -30,9 +30,8 @@ from pytensor.gradient import (
from pytensor.graph.basic import Apply, graph_inputs
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
-from pytensor.tensor.math import add, dot, exp, sigmoid, sqr
+from pytensor.tensor.math import add, dot, exp, sigmoid, sqr, tanh
from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import tanh
from pytensor.tensor.random import RandomStream
from pytensor.tensor.type import (
discrete_dtypes, Notes
- self.view_o = (
- {}
- ) # variable -> set of variables that use this one as a direct input
+ self.view_o = {} # variable -> set of variables that use this one as a direct input
[tool.ruff.lint.pycodestyle]
max-line-length = 88
|
I'm looking through the diff now and
# exclude: |
# (?x)^(
# .*/?__init__\.py|
# pytensor/graph/toolbox\.py|
# pytensor/link/jax/jax_dispatch\.py|
# pytensor/link/jax/jax_linker\.py|
# pytensor/scalar/basic_scipy\.py|
# pytensor/tensor/linalg\.py
# )$
Click to show Claude (LLM) reviewBased on the diff provided, here is a summary of some of the key changes from switching to ruff for formatting and linting:
So in summary, the main changes were due to differences in import sorting, line length limits, spacing rules, and other style conventions between the previous linter config and ruff's defaults. Let me know if you would like any specific area summarized in more detail! |
Hey @lmmx, great work! Sorry for questioning your pre-commit skills; I saw failing code style checks and was concerned maybe something was misconfigured. 😂 Looks like you're definitely on the right track, especially getting advice directly from the man himself! |
Superseded by #586 |
Motivation for these changes
As requested in #298 by @ricardoV94 and co-signed by @jessegrabowski and @cluhmann on the PyData Global code sprint Thursday call tonight
It looks like the previous PR to "add support for ruff" (#295) only fixed the issues raised by ruff, this PR will put it in the pre-commit.
Implementation details
ruff check
by addingnoqa
for unavoidable/'structural' ones (like import orders with dynamic imports) and fixed some to do with not duck typing (type(...) is ...
->isinstance(..., ...)
)Timing comparison
Here's the time it takes to run pre-commit (run on a low spec netbook!)
Total time 78 seconds
Here's the timing for pre-commit with ruff instead of black, flake8, and isort hooks.
Total time 20 seconds
Checklist
Major / Breaking Changes
New features
Bugfixes
Documentation
Maintenance