diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe0c7013..c7ff4bc1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,5 +74,6 @@ jobs: run: | pip install . pip install pandoc + pre-commit install python tools/generate_opset.py git diff --exit-code diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1ca0958a..619eea75 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,20 +1,15 @@ repos: + - repo: https://github.com/Quantco/pre-commit-mirrors-ruff + rev: 0.2.1 + hooks: + - id: ruff-conda - repo: https://github.com/Quantco/pre-commit-mirrors-black - rev: 23.12.1 + rev: 24.1.1 hooks: - id: black-conda args: - --safe - --target-version=py38 - - repo: https://github.com/Quantco/pre-commit-mirrors-flake8 - rev: 6.1.0 - hooks: - - id: flake8-conda - - repo: https://github.com/Quantco/pre-commit-mirrors-isort - rev: 5.13.2 - hooks: - - id: isort-conda - additional_dependencies: [-c, conda-forge, toml=0.10.2] - repo: https://github.com/Quantco/pre-commit-mirrors-mypy rev: "1.8.0" hooks: @@ -27,7 +22,7 @@ repos: args: - --py38 - repo: https://github.com/Quantco/pre-commit-mirrors-prettier - rev: 3.1.1 + rev: 3.2.4 hooks: - id: prettier-conda files: "\\.md$" diff --git a/src/spox/_schemas.py b/src/spox/_schemas.py index 95bf7a80..dba3274a 100644 --- a/src/spox/_schemas.py +++ b/src/spox/_schemas.py @@ -1,4 +1,5 @@ """Exposes information related to reference ONNX operator schemas, used by StandardOpNode.""" + import itertools from typing import ( Callable, @@ -16,11 +17,9 @@ class _Comparable(Protocol): - def __lt__(self, other): - ... + def __lt__(self, other): ... - def __gt__(self, other): - ... + def __gt__(self, other): ... S = TypeVar("S") diff --git a/src/spox/_scope.py b/src/spox/_scope.py index 2d08a4e0..a8a793d9 100644 --- a/src/spox/_scope.py +++ b/src/spox/_scope.py @@ -59,12 +59,10 @@ def __contains__(self, item: Union[str, H]) -> bool: ) @overload - def __getitem__(self, item: H) -> str: - ... + def __getitem__(self, item: H) -> str: ... @overload - def __getitem__(self, item: str) -> H: - ... + def __getitem__(self, item: str) -> H: ... def __getitem__(self, item: Union[str, H]): """Access the name of an object or an object with a given name in this (or outer) namespace.""" @@ -76,12 +74,10 @@ def __getitem__(self, item: Union[str, H]): return self.name_of[item] @overload - def __setitem__(self, key: str, value: H): - ... + def __setitem__(self, key: str, value: H): ... @overload - def __setitem__(self, key: H, value: str): - ... + def __setitem__(self, key: H, value: str): ... def __setitem__(self, _key, _value): """Set the name of an object in exactly this namespace. Both ``[name] = obj`` and ``[obj] = name`` work.""" diff --git a/src/spox/_standard.py b/src/spox/_standard.py index f39801ee..b71960cb 100644 --- a/src/spox/_standard.py +++ b/src/spox/_standard.py @@ -1,4 +1,5 @@ """Module implementing a base for standard ONNX operators, which use the functionality of ONNX node-level inference.""" + from typing import TYPE_CHECKING, Callable, Dict, Tuple import numpy diff --git a/src/spox/_type_system.py b/src/spox/_type_system.py index ac111493..1427f9ff 100644 --- a/src/spox/_type_system.py +++ b/src/spox/_type_system.py @@ -40,9 +40,11 @@ def _from_onnx(cls, proto: onnx.TypeProto) -> "Type": if proto.HasField("tensor_type"): return Tensor( tensor_type_to_dtype(proto.tensor_type.elem_type), - Shape.from_onnx(proto.tensor_type.shape).to_simple() - if proto.tensor_type.HasField("shape") - else None, + ( + Shape.from_onnx(proto.tensor_type.shape).to_simple() + if proto.tensor_type.HasField("shape") + else None + ), ) elif proto.HasField("sequence_type"): return Sequence(Type._from_onnx(proto.sequence_type.elem_type)) diff --git a/tests/test_custom_operator.py b/tests/test_custom_operator.py index 4c25acbb..b3b8bbfa 100644 --- a/tests/test_custom_operator.py +++ b/tests/test_custom_operator.py @@ -5,6 +5,7 @@ for the respective fields ``attrs/inputs/outputs`` and ``infer_output_types`` will be useful as well. Of these, ``propagate_values`` is probably least common. """ + from dataclasses import dataclass from typing import Dict