Skip to content

Commit

Permalink
Merge pull request #427 from Pale-Blue-Dot-97/pre-commit-ci-update-co…
Browse files Browse the repository at this point in the history
…nfig

[pre-commit.ci] pre-commit autoupdate
  • Loading branch information
Pale-Blue-Dot-97 authored Jan 29, 2024
2 parents c37c3c4 + f0490ec commit 0a83023
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 53 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
- id: requirements-txt-fixer

- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.1.1
hooks:
- id: black

Expand Down Expand Up @@ -59,7 +59,7 @@ repos:
- id: isort

- repo: https://github.com/PyCQA/bandit
rev: 1.7.6
rev: 1.7.7
hooks:
- id: bandit
args: [-c, pyproject.toml]
Expand Down
12 changes: 6 additions & 6 deletions minerva/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,9 @@ def make_loaders(
if not isinstance(dataset_params["mask"].get("transforms"), dict):
dataset_params["mask"]["transforms"] = class_transform
else:
dataset_params["mask"]["transforms"][
"ClassTransform"
] = class_transform["ClassTransform"]
dataset_params["mask"]["transforms"]["ClassTransform"] = (
class_transform["ClassTransform"]
)

sampler_params: Dict[str, Any] = dataset_params["sampler"]

Expand Down Expand Up @@ -578,9 +578,9 @@ def make_loaders(
if type(dataset_params[mode]["mask"].get("transforms")) != dict:
dataset_params[mode]["mask"]["transforms"] = class_transform
else:
dataset_params[mode]["mask"]["transforms"][
"ClassTransform"
] = class_transform["ClassTransform"]
dataset_params[mode]["mask"]["transforms"]["ClassTransform"] = (
class_transform["ClassTransform"]
)

mode_sampler_params: Dict[str, Any] = dataset_params[mode]["sampler"]

Expand Down
10 changes: 6 additions & 4 deletions minerva/datasets/paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,14 @@ def __getnewargs__(self):
return self.dataset, self._args, self._kwargs

@overload
def __init__(self, dataset: Callable[..., GeoDataset], *args, **kwargs) -> None:
... # pragma: no cover
def __init__(
self, dataset: Callable[..., GeoDataset], *args, **kwargs
) -> None: ... # pragma: no cover

@overload
def __init__(self, dataset: GeoDataset, *args, **kwargs) -> None:
... # pragma: no cover
def __init__(
self, dataset: GeoDataset, *args, **kwargs
) -> None: ... # pragma: no cover

def __init__(
self,
Expand Down
6 changes: 2 additions & 4 deletions minerva/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,12 @@ def update_n_classes(self, n_classes: int) -> None:
@overload
def step(
self, x: Tensor, y: Tensor, train: bool = False
) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]:
... # pragma: no cover
) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]: ... # pragma: no cover

@overload
def step(
self, x: Tensor, *, train: bool = False
) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]:
... # pragma: no cover
) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]: ... # pragma: no cover

def step(
self,
Expand Down
6 changes: 3 additions & 3 deletions minerva/tasks/epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class StandardEpoch(MinervaTask):

def step(self) -> None:
# Initialises a progress bar for the epoch.
with alive_bar(
self.n_batches, bar="blocks"
) if self.gpu == 0 else nullcontext() as bar:
with (
alive_bar(self.n_batches, bar="blocks") if self.gpu == 0 else nullcontext()
) as bar:
# Sets the model up for training or evaluation modes.
if self.train:
self.model.train()
Expand Down
8 changes: 4 additions & 4 deletions minerva/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,12 @@ def __init__(
self.key = key

@overload
def __call__(self, sample: Tensor) -> Tensor:
... # pragma: no cover
def __call__(self, sample: Tensor) -> Tensor: ... # pragma: no cover

@overload
def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
... # pragma: no cover
def __call__(
self, sample: Dict[str, Any]
) -> Dict[str, Any]: ... # pragma: no cover

def __call__(
self, sample: Union[Tensor, Dict[str, Any]]
Expand Down
46 changes: 18 additions & 28 deletions minerva/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,12 @@ def __init__(self, *args, **kwargs) -> None:
self.keys = keys

@overload
def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
... # pragma: no cover
def __call__(
self, batch: Dict[str, Any]
) -> Dict[str, Any]: ... # pragma: no cover

@overload
def __call__(self, batch: Tensor) -> Dict[str, Any]:
... # pragma: no cover
def __call__(self, batch: Tensor) -> Dict[str, Any]: ... # pragma: no cover

def __call__(self, batch: Union[Dict[str, Any], Tensor]) -> Dict[str, Any]:
if isinstance(batch, Tensor):
Expand Down Expand Up @@ -417,8 +417,7 @@ def _optional_import(
*,
name: None,
package: str,
) -> ModuleType:
... # pragma: no cover
) -> ModuleType: ... # pragma: no cover


@overload
Expand All @@ -427,8 +426,7 @@ def _optional_import(
*,
name: str,
package: str,
) -> Callable[..., Any]:
... # pragma: no cover
) -> Callable[..., Any]: ... # pragma: no cover


@overload
Expand All @@ -437,8 +435,7 @@ def _optional_import(
*,
name: None,
package: None,
) -> ModuleType:
... # pragma: no cover
) -> ModuleType: ... # pragma: no cover


@overload
Expand All @@ -447,26 +444,23 @@ def _optional_import(
*,
name: str,
package: None,
) -> Callable[..., Any]:
... # pragma: no cover
) -> Callable[..., Any]: ... # pragma: no cover


@overload
def _optional_import(
module: str,
*,
name: str,
) -> Callable[..., Any]:
... # pragma: no cover
) -> Callable[..., Any]: ... # pragma: no cover


@overload
def _optional_import(
module: str,
*,
package: str,
) -> ModuleType:
... # pragma: no cover
) -> ModuleType: ... # pragma: no cover


def _optional_import(
Expand Down Expand Up @@ -679,8 +673,7 @@ def transform_coordinates(
y: Sequence[float],
src_crs: CRS,
new_crs: CRS = WGS84,
) -> Tuple[Sequence[float], Sequence[float]]:
... # pragma: no cover
) -> Tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover


@overload
Expand All @@ -689,8 +682,7 @@ def transform_coordinates(
y: float,
src_crs: CRS,
new_crs: CRS = WGS84,
) -> Tuple[Sequence[float], Sequence[float]]:
... # pragma: no cover
) -> Tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover


@overload
Expand All @@ -699,15 +691,13 @@ def transform_coordinates(
y: Sequence[float],
src_crs: CRS,
new_crs: CRS = WGS84,
) -> Tuple[Sequence[float], Sequence[float]]:
... # pragma: no cover
) -> Tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover


@overload
def transform_coordinates(
x: float, y: float, src_crs: CRS, new_crs: CRS = WGS84
) -> Tuple[float, float]:
... # pragma: no cover
) -> Tuple[float, float]: ... # pragma: no cover


def transform_coordinates(
Expand Down Expand Up @@ -1135,13 +1125,13 @@ def class_transform(label: int, matrix: Dict[int, int]) -> int:
@overload
def mask_transform( # type: ignore[overload-overlap]
array: NDArray[Any, Int], matrix: Dict[int, int]
) -> NDArray[Any, Int]:
... # pragma: no cover
) -> NDArray[Any, Int]: ... # pragma: no cover


@overload
def mask_transform(array: LongTensor, matrix: Dict[int, int]) -> LongTensor:
... # pragma: no cover
def mask_transform(
array: LongTensor, matrix: Dict[int, int]
) -> LongTensor: ... # pragma: no cover


def mask_transform(
Expand Down
5 changes: 3 additions & 2 deletions tests/test_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,9 @@ def test_lat_lon_to_loc(
except GeocoderUnavailable:
pass

with no_connection(), pytest.raises(
GeocoderUnavailable, match="Geocoder unavailable"
with (
no_connection(),
pytest.raises(GeocoderUnavailable, match="Geocoder unavailable"),
):
_ = utils.lat_lon_to_loc(lat, lon)

Expand Down

0 comments on commit 0a83023

Please sign in to comment.